Skip to content

Commit

Permalink
FIX Int8 check for torchao v0.7.0 (#2284)
Browse files Browse the repository at this point in the history
At one point, the code needs to check the quantization dtype. This check
used to rely on the layout_tensor attribute, which was renamed to
tensor_impl in torchao v0.7.0. The code now checks both attributes so
that both older and newer torchao versions are supported.
  • Loading branch information
BenjaminBossan authored Dec 18, 2024
1 parent ae55fdc commit c1fe810
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/peft/tuners/lora/torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ def _check_dtype_supported(self):
# TODO: Not required once int4_weight_only is properly supported by torchao
base_layer = self.get_base_layer()
weight = base_layer.weight
if hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8):
# pytest tests/test_gpu_examples.py::PeftTorchaoGPUTests::test_causal_lm_training_single_gpu_torchao_0_int8_weight_only
if (
# torchao 0.7.0+
(hasattr(weight, "tensor_impl") and (weight.tensor_impl.data.dtype != torch.int8))
or
# torchao < 0.7.0
(hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8))
):
raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
Expand Down

0 comments on commit c1fe810

Please sign in to comment.