From 1e71585d8af2206c5917b9c7f4a1c922c277053f Mon Sep 17 00:00:00 2001 From: zhangshengdong29 <435878393@qq.com> Date: Thu, 30 Nov 2023 23:24:58 +0800 Subject: [PATCH] Megatron distributed parallel linear LoRA (#1092) Adds option to use Megatron's ColumnParallelLinear and RowParallelLinear for LoRA linear layers, leading to improved performance when using LoRA with Megatron. --- src/peft/tuners/lora/config.py | 26 +++++ src/peft/tuners/lora/layer.py | 3 + src/peft/tuners/lora/model.py | 27 +++++ src/peft/tuners/lora/tp_layer.py | 158 +++++++++++++++++++++++++++++ tests/test_lora_megatron.py | 167 +++++++++++++++++++++++++++++++ 5 files changed, 381 insertions(+) create mode 100644 src/peft/tuners/lora/tp_layer.py create mode 100644 tests/test_lora_megatron.py diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 0dcca5c1e6..53269ebb8d 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -141,6 +141,32 @@ class LoraConfig(PeftConfig): ) }, ) + megatron_config: Optional[dict] = field( + default=None, + metadata={ + "help": ( + "The TransformerConfig from Megatron, it is used to create LoRA's parallel linear layer." + "You can get it like this, `core_transformer_config_from_args(get_args())`, " + "this two functions are from Megatron." + "You need to specify this parameter when you want to loraize the ColumnParallelLinear and " + "RowParallelLinear layers of megatron." + "It should be noted that we may not be able to use the `save_pretrained` and `from_pretrained` " + "functions, because TransformerConfig may not necessarily be serialized." + "But when using megatron, we can use `get_peft_model_state_dict` function and " + "megatron's framework, they can also save and load models and configurations." + ) + }, + ) + megatron_core: Optional[str] = field( + default="megatron.core", + metadata={ + "help": ( + "The core module from Megatron, it is used to judge and create LoRA's parallel linear layer. " + "It only needs to be passed in when you need to use your own modified megatron core module. " + "Otherwise, it will use the default value `megatron.core`. " + ) + }, + ) # dict type is used when loading config.json loftq_config: Union[LoftQConfig, dict] = field( default_factory=dict, diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index cf97108c87..3219ca1e47 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -62,6 +62,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): # QuantLinear in_features, out_features = base_layer.infeatures, base_layer.outfeatures + elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): + # Megatron ColumnParallelLinear,RowParallelLinear + in_features, out_features = base_layer.input_size, base_layer.output_size else: raise ValueError(f"Unsupported layer type {type(base_layer)}") diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 6e0a64187a..4f6538e912 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import math import operator import re @@ -259,6 +260,10 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): else: target_base_layer = target + megatron_core = None + if lora_config.megatron_config: + megatron_core = importlib.import_module(lora_config.megatron_core) + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): eightbit_kwargs = kwargs.copy() eightbit_kwargs.update( @@ -300,6 +305,28 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False kwargs.update(lora_config.loftq_config) new_module = Linear(target, adapter_name, **kwargs) + elif megatron_core and isinstance( + target_base_layer, + (megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear), + ): + from .tp_layer import LoraParallelLinear + + megatron_kwargs = kwargs.copy() + megatron_config = lora_config.megatron_config + if isinstance(megatron_config, dict): + transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig + megatron_config = transformer_config_class(**lora_config.megatron_config) + megatron_kwargs["megatron_config"] = megatron_config + if megatron_kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` " + "or `RowParallelLinear`. " + "Setting fan_in_fan_out to False." + ) + megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + new_module = LoraParallelLinear( + base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs + ) elif isinstance(target_base_layer, Conv1D): if not kwargs["fan_in_fan_out"]: warnings.warn( diff --git a/src/peft/tuners/lora/tp_layer.py b/src/peft/tuners/lora/tp_layer.py new file mode 100644 index 0000000000..676430cf38 --- /dev/null +++ b/src/peft/tuners/lora/tp_layer.py @@ -0,0 +1,158 @@ +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.init as init + +from .layer import LoraLayer + + +class LoraParallelLinear(nn.Module, LoraLayer): + """ + When the target layer parallel_linear is RowParallelLinear, in order to keep the input and output shapes + consistent, we need to split the lora matrix A into rows, and the lora_B at this time should be a complete linear + layer; In the same way, when the target layer is ColumnParallelLinear, we perform column segmentation on lora_B, + while lora_A is still a complete linear layer. + """ + + def __init__( + self, + base_layer, + adapter_name: str, + backend, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_lora_weights: bool = True, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer=base_layer) + + self.backend = backend + self.is_paralle_a = isinstance(base_layer, backend.RowParallelLinear) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + + megatron_config = kwargs["megatron_config"] + parallel_linear_kwargs = {"megatron_config": megatron_config} + init_method = init.xavier_normal_ + if hasattr(megatron_config, "init_method"): + init_method = megatron_config.init_method + input_is_parallel = True + gather_output = False + if isinstance(base_layer, self.backend.RowParallelLinear): + input_is_parallel = base_layer.input_is_parallel + else: + gather_output = base_layer.gather_output + self.update_layer( + adapter_name, + r, + lora_alpha, + lora_dropout, + init_lora_weights, + init_method, + input_is_parallel, + gather_output, + **parallel_linear_kwargs, + ) + + self.is_target_conv_1d_layer = False + + def update_layer( + self, + adapter_name, + r, + lora_alpha, + lora_dropout, + init_lora_weights, + init_method=init.xavier_normal_, + input_is_parallel=True, + gather_output=False, + **parallel_linear_kwargs, + ): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + + megatron_config = parallel_linear_kwargs["megatron_config"] + # lora needs to be forced to upgrade to 32-bit precision, otherwise it will overflow + megatron_config.params_dtype = torch.float32 + if self.is_paralle_a: + lora_a = self.backend.RowParallelLinear( + input_size=self.in_features, + output_size=r, + bias=False, + input_is_parallel=input_is_parallel, + skip_bias_add=True, + init_method=init_method, + config=megatron_config, + ) + lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32) + else: + lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32) + lora_b = self.backend.ColumnParallelLinear( + input_size=r, + output_size=self.out_features, + bias=False, + gather_output=gather_output, + init_method=init_method, + config=megatron_config, + ) + self.lora_A[adapter_name] = lora_a + self.lora_B[adapter_name] = lora_b + self.scaling[adapter_name] = lora_alpha / r + if init_lora_weights: + self.reset_lora_parameters(adapter_name) + + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + self.set_adapter(self.active_adapters) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): + previous_dtype = x.dtype + # If weight is used for matrix multiplication here, the final aggregation operation of the original + # parallel_linear layer will be missing, so we need to directly call its forward function to obtain the + # output of the original parallel_linear layer. + if self.disable_adapters: + if self.merged: + self.unmerge() + result, bias = self.base_layer(x, *args, **kwargs) + elif self.merged: + result, bias = self.base_layer(x, *args, **kwargs) + else: + result, bias = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + + lora_result = lora_A(dropout(x)) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_B(lora_result) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_result * scaling + + result = result + lora_result + + result = result.to(previous_dtype) + return result, bias diff --git a/tests/test_lora_megatron.py b/tests/test_lora_megatron.py new file mode 100644 index 0000000000..80d0f43010 --- /dev/null +++ b/tests/test_lora_megatron.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 + +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import importlib +import os +import unittest + +import torch +import torch.nn.init as init + +from peft import LoraConfig, PeftModel, get_peft_model, get_peft_model_state_dict + + +def is_megatron_available() -> bool: + return importlib.util.find_spec("megatron") is not None + + +if is_megatron_available(): + from megatron.core import parallel_state, tensor_parallel + from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + from megatron.core.transformer.module import MegatronModule + from megatron.core.transformer.transformer_config import TransformerConfig + + world_size = 1 + rank = 0 + + def initialize_distributed(): + print(f"Initializing torch.distributed with rank: {rank}, world_size: {world_size}") + torch.cuda.set_device(0) + init_method = "tcp://" + master_ip = os.getenv("MASTER_ADDR", "localhost") + master_port = os.getenv("MASTER_PORT", "6001") + init_method += master_ip + ":" + master_port + torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank, init_method=init_method) + + def destroy_model_parallel(): + parallel_state.destroy_model_parallel() + torch.distributed.barrier() + + def initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None, + ): + parallel_state.destroy_model_parallel() + if not torch.distributed.is_initialized(): + initialize_distributed() + parallel_state.initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size, + pipeline_model_parallel_split_rank, + ) + + class DummyModule(MegatronModule): + def __init__(self, config: TransformerConfig): + super().__init__(config) + self.linear = tensor_parallel.ColumnParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + gather_output=False, + ) + self.lm_head = tensor_parallel.RowParallelLinear( + input_size=10, + output_size=10, + config=config, + init_method=init.xavier_normal_, + bias=False, + input_is_parallel=True, + ) + + def forward(self, input): + x = self.linear(input)[0] + x = self.lm_head(x)[0] + return x + + class TestMegatronLora(unittest.TestCase): + def setUp(self): + initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + transformer_config = { + "num_layers": 2, + "hidden_size": 12, + "num_attention_heads": 4, + "use_cpu_initialization": True, + } + config = TransformerConfig(**transformer_config) + self.megatron_module = DummyModule(config=config).cuda() + self.dummy_module = copy.deepcopy(self.megatron_module).cuda() + + lora_config = LoraConfig( + lora_alpha=16, + lora_dropout=0.1, + r=64, + bias="none", + target_modules=["linear", "lm_head"], + megatron_config=config, + megatron_core="megatron.core", + ) + self.megatron_module = get_peft_model(self.megatron_module, lora_config) + + def tearDown(self): + destroy_model_parallel() + + def test_megatron_lora_module(self): + megatron_module = self.megatron_module + self.assertTrue(isinstance(megatron_module, PeftModel)) + + for name, module in megatron_module.named_modules(): + if name.endswith("linear"): + self.assertTrue(hasattr(module, "lora_A")) + self.assertTrue(hasattr(module, "lora_B")) + if name.endswith("linear.lora_A.default"): + self.assertTrue(isinstance(module, torch.nn.Linear)) + if name.endswith("linear.lora_B.default"): + self.assertTrue(isinstance(module, tensor_parallel.ColumnParallelLinear)) + + if name.endswith("lm_head.lora_A.default"): + self.assertTrue(isinstance(module, tensor_parallel.RowParallelLinear)) + if name.endswith("lm_head.lora_B.default"): + self.assertTrue(isinstance(module, torch.nn.Linear)) + + def test_forward(self): + x = torch.ones((2, 4, 10)).cuda() + megatron_module_result = self.megatron_module(x) + dummt_module_result = self.dummy_module(x) + + # Because lora_B is initialized with 0, the forward results of two models should be equal before backward. + self.assertTrue(megatron_module_result.equal(dummt_module_result)) + + def test_backward(self): + optimizer = torch.optim.AdamW(self.megatron_module.parameters()) + loss_fn = torch.nn.CrossEntropyLoss() + + x = torch.randn(2, 4, 10, requires_grad=True).cuda() + label = torch.randint(10, (2 * 4,)).cuda() + + output = self.megatron_module(x) + output = output.reshape(2 * 4, 10) + loss = loss_fn(output, label) + + loss.backward() + optimizer.step() + + def test_get_peft_model_state_dict(self): + peft_state_dict = get_peft_model_state_dict(self.megatron_module) + + for key in peft_state_dict.keys(): + self.assertTrue("lora" in key)