From ebbff4023ad276cbcb2466fd7e99be7d3ae0ae11 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu, 18 Jan 2024 20:09:09 +0530 Subject: [PATCH] account for the new merged/unmerged weight to perform the quantization again (#1370) --- src/peft/tuners/lora/bnb.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py index f4f7194f11..3bea46ee50 100644 --- a/src/peft/tuners/lora/bnb.py +++ b/src/peft/tuners/lora/bnb.py @@ -267,7 +267,8 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - + if "bnb_quantized" in kwargs: + kwargs["bnb_quantized"] = False self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( weight.device ) @@ -292,6 +293,8 @@ def unmerge(self) -> None: kwargs = weight.__dict__ lora_data = self.get_delta_weight(active_adapter) w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data + if "bnb_quantized" in kwargs: + kwargs["bnb_quantized"] = False self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to( weight.device )