diff --git a/src/peft/tuners/lora/torchao.py b/src/peft/tuners/lora/torchao.py
index 05ffbc3445..5e7240a053 100644
--- a/src/peft/tuners/lora/torchao.py
+++ b/src/peft/tuners/lora/torchao.py
@@ -43,7 +43,14 @@ def _check_dtype_supported(self):
         # TODO: Not required once int4_weight_only is properly supported by torchao
         base_layer = self.get_base_layer()
         weight = base_layer.weight
-        if hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8):
+        # pytest tests/test_gpu_examples.py::PeftTorchaoGPUTests::test_causal_lm_training_single_gpu_torchao_0_int8_weight_only
+        if (
+            # torchao 0.7.0+
+            (hasattr(weight, "tensor_impl") and (weight.tensor_impl.data.dtype != torch.int8))
+            or
+            # torchao < 0.7.0
+            (hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8))
+        ):
             raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")
 
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: