account for the new merged/unmerged weight to perform the quantization again (#1370)
pacman100 authored Jan 18, 2024
1 parent 62237dc commit ebbff40
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/peft/tuners/lora/bnb.py
@@ -267,7 +267,8 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
                 raise ValueError(
                     f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                 )
-
+            if "bnb_quantized" in kwargs:
+                kwargs["bnb_quantized"] = False
             self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
                 weight.device
             )
@@ -292,6 +293,8 @@ def unmerge(self) -> None:
             kwargs = weight.__dict__
             lora_data = self.get_delta_weight(active_adapter)
             w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
+            if "bnb_quantized" in kwargs:
+                kwargs["bnb_quantized"] = False
             self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
                 weight.device
             )
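
For context, here is a minimal, hypothetical sketch of the code path this commit fixes (the model name, LoRA hyperparameters, and 4-bit config below are illustrative assumptions, not part of the commit). Because `kwargs` is taken from the old parameter's `__dict__`, it carries `bnb_quantized=True`; left unchanged, the freshly built `Params4bit` would treat the merged/unmerged weights as already quantized and `.to(weight.device)` would skip quantizing them again. Resetting the flag to `False` forces re-quantization of the new data.

```python
# Hypothetical repro sketch, assuming transformers, peft, and bitsandbytes are
# installed. Model name and LoRA settings are illustrative only.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))

# merge_adapter() dequantizes each 4-bit base weight, adds the LoRA delta, and
# rebuilds a bnb.nn.Params4bit from the old parameter's __dict__. Before this
# commit, the stale bnb_quantized=True entry in those kwargs made the .to(...)
# call skip re-quantizing the merged data; unmerge_adapter() had the same issue.
model.merge_adapter()
model.unmerge_adapter()
```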