
Commit

fix bugs
pacman100 committed Sep 5, 2023
1 parent 4eb412f commit 19364f4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/peft/tuners/lora/model.py
@@ -671,7 +671,7 @@ def unload(self):
     def forward(self, *args: Any, **kwargs: Any):
         adapter_names = kwargs.pop("adapter_names", None)
         if adapter_names is None:
-            return super().forward(*args, **kwargs)
+            return self.model.forward(*args, **kwargs)
 
         if self.training:
             raise ValueError("Multiple LoRAs in the same batch isn't supported during training")
@@ -694,7 +694,7 @@ def forward(self, *args: Any, **kwargs: Any):
     def generate(self, **kwargs: Any):
         adapter_names = kwargs.pop("adapter_names", None)
         if adapter_names is None:
-            return super().forward(**kwargs)
+            return self.model.generate(**kwargs)
 
         if self.training:
             raise ValueError("Multiple LoRAs in the same batch isn't supported during training")
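For context: the first change makes LoraModel.forward delegate to the wrapped base model (self.model) rather than super() when no adapter_names are passed, and the second fixes a copy-paste bug where generate called forward instead of generate. Below is a minimal sketch (not part of the commit) of how both patched code paths get exercised; the "gpt2" base model, the adapter id, and the adapter names are illustrative assumptions, and it presumes a PEFT release that ships the mixed-adapter-batch (adapter_names) feature this file implements.

# A minimal sketch of exercising both patched code paths. The base model,
# adapter id, and adapter names are illustrative assumptions, not values
# from the commit.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

# Assumes the adapter was saved with task_type CAUSAL_LM so generate() exists.
model = PeftModel.from_pretrained(base, "some-user/some-lora")  # hypothetical id
model.eval()  # both patched methods raise ValueError when self.training is True

batch = tokenizer(["first prompt", "second prompt"], return_tensors="pt", padding=True)

with torch.no_grad():
    # adapter_names is None: generate now delegates to self.model.generate()
    # (before the fix it wrongly called super().forward()).
    out = model.generate(**batch, max_new_tokens=8)

    # One adapter name per sample in the same batch, the feature this code
    # implements; "default" is the name PEFT gives an adapter loaded unnamed.
    out = model.generate(**batch, adapter_names=["default", "default"], max_new_tokens=8)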
