diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 88aadf4f01..1a2d8297af 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -360,6 +360,26 @@ def fsdp_auto_wrap_policy(model):
 
     from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
 
+    # Class names to wrap come from FSDP_TRANSFORMER_CLS_TO_WRAP (comma-separated);
+    # when the variable is unset, fall back to the model's `_no_split_modules`, if any.
+    default_transformer_cls_names_to_wrap = (
+        ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else ""
+    )
+    transformer_cls_names_to_wrap = os.environ.get(
+        "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap
+    ).split(",")
+    transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding}
+    for layer_class in transformer_cls_names_to_wrap:
+        layer_class = layer_class.strip()
+        # "".split(",") yields [""], so an unset variable on a model without `_no_split_modules`
+        # (or a trailing comma in the variable) produces an empty entry; skip it instead of failing.
+        if not layer_class:
+            continue
+        transformer_cls = FullyShardedDataParallelPlugin.get_module_class_from_name(model, layer_class)
+        if transformer_cls is None:
+            raise Exception("Could not find the transformer layer class to wrap in the model.")
+        transformer_cls_to_wrap.add(transformer_cls)
+
     def lambda_policy_fn(module):
         if (
             len(list(module.named_children())) == 0
@@ -372,14 +392,7 @@ def lambda_policy_fn(module):
     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
     transformer_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls=(
-            PrefixEncoder,
-            PromptEncoder,
-            PromptEmbedding,
-            FullyShardedDataParallelPlugin.get_module_class_from_name(
-                model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
-            ),
-        ),
+        transformer_layer_cls=transformer_cls_to_wrap,
     )
 
     auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])