From 03679f33f7f9d9da7e99d5b6fd40b4548019a3c3 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 6 Jun 2024 13:50:13 -0700 Subject: [PATCH 01/12] refactor(lora): reorganize the code in BatchLoraWeights.load This function was a bit hard to understand as there were multiple list comprehensions with almost the same looping logic. So, merged all of them into a single for loop for improved clarity. --- server/lorax_server/adapters/lora.py | 77 +++++++++------------------- 1 file changed, 24 insertions(+), 53 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index eea5301e7..9b1a1c17c 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -251,65 +251,36 @@ def load( first_weights = list(adapter_weights.values())[0] device = first_weights.weights_a.device segment_indices = meta.segment_indices + use_sgmv = prefill or max_rank > BGMV_MAX_RANK - lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights} - lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights} - - max_rank = max(adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights) - - if prefill or max_rank > BGMV_MAX_RANK: - use_sgmv = True - lora_a_ptr = torch.tensor( - [ - (adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - lora_b_ptr = torch.tensor( - [ - (adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - else: - use_sgmv = False - lora_a_ptr = torch.tensor( - [ - (adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - lora_b_ptr = 
torch.tensor( - [ - (adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - - adapter_index_configs = { - idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights - } - - adapter_to_segment = {v: k for k, v in enumerate(segment_indices)} - - rank_indices = defaultdict(list) + lora_a, lora_b, adapter_index_configs, adapter_to_segment = {}, {}, {}, {} + lora_a_ptr, lora_b_ptr = [], [] + max_rank, rank_indices = 0, defaultdict(list) for segment_idx, adapter_idx in enumerate(segment_indices): - if adapter_idx not in adapter_weights: - continue - rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx) + adapter_to_segment[adapter_idx] = segment_idx + if adapter_idx in adapter_weights: + adapter_weight = adapter_weights[adapter_idx] + adapter_index_configs[adapter_idx] = adapter_weight.config + max_rank = max(max_rank, adapter_weight.lora_a_r) + rank_indices[adapter_weight.lora_a_r].append(segment_idx) + lora_a[adapter_idx] = adapter_weight.weights_a + lora_b[adapter_idx] = adapter_weight.weights_b + lora_a_ptr.append( + (adapter_weight.weights_a if use_sgmv else adapter_weight.weigths_a_t).data_ptr() + ) + lora_b_ptr.append( + (adapter_weight.weights_b if use_sgmv else adapter_weight.weigths_b_t).data_ptr() + ) + else: + lora_a_ptr.append(EMPTY_TENSOR.data_ptr()) + lora_b_ptr.append(EMPTY_TENSOR.data_ptr()) + lora_a_ptr = torch.tensor(lora_a_ptr, dtype=torch.int64, device=device) + lora_b_ptr = torch.tensor(lora_b_ptr, dtype=torch.int64, device=device) if prefill_head_indices is not None: j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0] for head_index in prefill_head_indices: - # j cannot go out of bounds as that would mean there are tokens without corresponding adapters + # j cannot go out of bounds as that would mean there are tokens without segments if head_index < 
meta.adapter_segments[j]: prefill_head_segment_ends[-1] += 1 else: From ddc996e17458253d2815c716d102a0a749431e18 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 6 Jun 2024 14:11:41 -0700 Subject: [PATCH 02/12] refactor(lora): fix max_rank issue --- server/lorax_server/adapters/lora.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 9b1a1c17c..54069c3a0 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -251,10 +251,8 @@ def load( first_weights = list(adapter_weights.values())[0] device = first_weights.weights_a.device segment_indices = meta.segment_indices - use_sgmv = prefill or max_rank > BGMV_MAX_RANK lora_a, lora_b, adapter_index_configs, adapter_to_segment = {}, {}, {}, {} - lora_a_ptr, lora_b_ptr = [], [] max_rank, rank_indices = 0, defaultdict(list) for segment_idx, adapter_idx in enumerate(segment_indices): adapter_to_segment[adapter_idx] = segment_idx @@ -265,6 +263,12 @@ def load( rank_indices[adapter_weight.lora_a_r].append(segment_idx) lora_a[adapter_idx] = adapter_weight.weights_a lora_b[adapter_idx] = adapter_weight.weights_b + + use_sgmv = prefill or max_rank > BGMV_MAX_RANK + lora_a_ptr, lora_b_ptr = [], [] + for segment_idx, adapter_idx in enumerate(segment_indices): + if adapter_idx in adapter_weights: + adapter_weight = adapter_weights[adapter_idx] lora_a_ptr.append( (adapter_weight.weights_a if use_sgmv else adapter_weight.weigths_a_t).data_ptr() ) From 16d9ebf381e0f5908972201a27a9e65661c4e346 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 6 Jun 2024 14:29:48 -0700 Subject: [PATCH 03/12] refactor(lora): fix adapter config issue --- server/lorax_server/adapters/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 54069c3a0..4ceed47d8 100644 --- 
a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -258,7 +258,7 @@ def load( adapter_to_segment[adapter_idx] = segment_idx if adapter_idx in adapter_weights: adapter_weight = adapter_weights[adapter_idx] - adapter_index_configs[adapter_idx] = adapter_weight.config + adapter_index_configs[adapter_idx] = adapter_weight.adapter_config max_rank = max(max_rank, adapter_weight.lora_a_r) rank_indices[adapter_weight.lora_a_r].append(segment_idx) lora_a[adapter_idx] = adapter_weight.weights_a From 1541517303ed61aa0740f4cd79c0217b80d874ce Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 6 Jun 2024 17:14:06 -0700 Subject: [PATCH 04/12] refactor(lora): fix weights spelling --- server/lorax_server/adapters/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 4ceed47d8..2ffdfa549 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -270,10 +270,10 @@ def load( if adapter_idx in adapter_weights: adapter_weight = adapter_weights[adapter_idx] lora_a_ptr.append( - (adapter_weight.weights_a if use_sgmv else adapter_weight.weigths_a_t).data_ptr() + (adapter_weight.weights_a if use_sgmv else adapter_weight.weights_a_t).data_ptr() ) lora_b_ptr.append( - (adapter_weight.weights_b if use_sgmv else adapter_weight.weigths_b_t).data_ptr() + (adapter_weight.weights_b if use_sgmv else adapter_weight.weights_b_t).data_ptr() ) else: lora_a_ptr.append(EMPTY_TENSOR.data_ptr()) From 6e3cbc07ce30319fd1dd613dd7f20bbdcff5552c Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 7 Jun 2024 12:08:08 -0700 Subject: [PATCH 05/12] refactor(lora): remove the `rank_indices` variable It has the same name as the one used in the outer for loop, which can cause confusion --- server/lorax_server/adapters/lora.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/server/lorax_server/adapters/lora.py 
b/server/lorax_server/adapters/lora.py index 2ffdfa549..340ef29fe 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -310,9 +310,8 @@ def load( segment_starts[i] = prefill_head_segment_starts[segment_index] segment_ends[i] = prefill_head_segment_ends[segment_index] else: - rank_indices = set(indices) batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()] - batch_indices = [idx if idx in rank_indices else -1 for idx in batch_indices] + batch_indices = [idx if idx in set(indices) else -1 for idx in batch_indices] batch_indices = torch.tensor(batch_indices, dtype=torch.int64, device=device) rank_data[rank] = RankSegments( From a7ee1c50e196aaf448a7539a21a3a70ca44f3f04 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 7 Jun 2024 18:11:28 -0700 Subject: [PATCH 06/12] feat(embed_tokens): Support `embed_tokens` as a target module --- server/lorax_server/adapters/lora.py | 33 ++-- server/lorax_server/adapters/weights.py | 7 +- .../custom_modeling/flash_llama_modeling.py | 11 +- server/lorax_server/models/flash_llama.py | 7 +- server/lorax_server/utils/adapter.py | 10 ++ server/lorax_server/utils/graph.py | 2 + server/lorax_server/utils/layers.py | 146 ++++++++++++++++++ server/lorax_server/utils/lora.py | 1 + 8 files changed, 200 insertions(+), 17 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 340ef29fe..04a856017 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -8,6 +8,7 @@ from lorax_server.adapters.config import AdapterConfig, ModuleMap from lorax_server.adapters.types import LORA +from lorax_server.utils.lora import EMBED_TOKENS from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights from lorax_server.utils.sgmv import ( BGMV_MAX_RANK, @@ -40,14 +41,20 @@ def map_weights_for_model( adapter_weight_names = set() module_map = {} for 
weight_name in weight_names: - lora_a_name = f"base_model.model.{weight_name}.lora_A.weight" - lora_b_name = f"base_model.model.{weight_name}.lora_B.weight" + if EMBED_TOKENS in weight_name: + lora_a_name = f"base_model.model.{weight_name}.lora_embedding_A" + lora_b_name = f"base_model.model.{weight_name}.lora_embedding_B" + else: + lora_a_name = f"base_model.model.{weight_name}.lora_A.weight" + lora_b_name = f"base_model.model.{weight_name}.lora_B.weight" if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights: continue + # note(ajinkya): popping the weights so that we know which weights are + # can be used as lora weights (supported) and which cannot module_map[weight_name] = { - "lora_A": (adapter_weights[lora_a_name], lora_a_name), - "lora_B": (adapter_weights[lora_b_name], lora_b_name), + "lora_A": (adapter_weights.pop(lora_a_name), lora_a_name), + "lora_B": (adapter_weights.pop(lora_b_name), lora_b_name), } adapter_weight_names.add(lora_a_name) adapter_weight_names.add(lora_b_name) @@ -90,6 +97,7 @@ def __init__( weights_a: List[torch.Tensor], weights_b: List[torch.Tensor], adapter_config: LoraConfig, + is_embed: bool = False, ): self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 @@ -98,7 +106,8 @@ def __init__( self._is_transposed = False # [num_layers, hidden_size, r] - weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + if not is_embed: + weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] self._weights_a = torch.stack(weights_a) # [num_layers, r, hidden_size] @@ -158,8 +167,10 @@ def load( key = (layer_id, layer_type) weight_name, layer = model.target_to_layer[key] - base_weight = layer.base_layer.linear.weight - base_device = base_weight.device + if EMBED_TOKENS in layer_type: + base_device = layer.base_layer.weight.device + else: + base_device = layer.base_layer.linear.weight.device if weight_name not in 
module_map: # There is no LoRA weight for this layer type in the adapter @@ -196,13 +207,15 @@ def load( return LoraWeights( *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type), - config, + adapter_config=config, + is_embed=(layer_type == EMBED_TOKENS), ) @dataclass class RankSegments: rank: int + adapter_index_map: int lora_a_ptr: torch.Tensor lora_b_ptr: torch.Tensor @@ -242,6 +255,7 @@ def load( meta: AdapterBatchMetadata, prefill: bool, prefill_head_indices: Optional[torch.Tensor], + is_embed: bool, ) -> Optional["BatchLoraWeights"]: adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()} adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)} @@ -270,7 +284,7 @@ def load( if adapter_idx in adapter_weights: adapter_weight = adapter_weights[adapter_idx] lora_a_ptr.append( - (adapter_weight.weights_a if use_sgmv else adapter_weight.weights_a_t).data_ptr() + (adapter_weight.weights_a if use_sgmv or is_embed else adapter_weight.weights_a_t).data_ptr() ) lora_b_ptr.append( (adapter_weight.weights_b if use_sgmv else adapter_weight.weights_b_t).data_ptr() @@ -316,6 +330,7 @@ def load( rank_data[rank] = RankSegments( rank=rank, + adapter_index_map=indices, tmp_shrink=tmp_shrink, tmp_expand=tmp_expand, lora_a_ptr=lora_a_ptr[indices], diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index 0468baaa8..5d9060f96 100644 --- a/server/lorax_server/adapters/weights.py +++ b/server/lorax_server/adapters/weights.py @@ -6,7 +6,7 @@ import torch from lorax_server.adapters.types import LORA -from lorax_server.utils.lora import LM_HEAD +from lorax_server.utils.lora import LM_HEAD, EMBED_TOKENS @dataclass @@ -82,6 +82,7 @@ def get_data( meta: AdapterBatchMetadata, prefill: bool, prefill_head_indices: Optional[torch.Tensor], + is_embed: bool, ) -> Dict[str, BatchAdapterWeights]: # bucket adapters by batch class adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, 
AdapterWeights]] = defaultdict(dict) @@ -91,7 +92,7 @@ def get_data( batch_data = {} for batch_type, adapter_weights in adapter_batch_types.items(): - batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices) + batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices, is_embed) if batched_weights is not None: batch_data[batch_type.key()] = batched_weights return batch_data @@ -117,7 +118,7 @@ def from_meta( for k, v in weights.items(): if v.is_empty(): continue - data[k] = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None) + data[k] = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None, k == EMBED_TOKENS) return AdapterBatchData(meta=meta, data=data, prefill=prefill) def ranks(self) -> Set[int]: diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index 5f30a4676..899c11acc 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -39,6 +39,7 @@ TensorParallelHead, TensorParallelMultiAdapterLinear, TensorParallelRowLinear, + TensorParallelAdapterRowEmbedding, get_linear, ) from lorax_server.utils.lora import ( @@ -50,6 +51,7 @@ Q_PROJ, UP_PROJ, V_PROJ, + EMBED_TOKENS, ) @@ -457,7 +459,12 @@ def __init__(self, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) + self.embed_tokens = TensorParallelAdapterRowEmbedding( + base_layer=TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights), + layer_id=0, + layer_name=EMBED_TOKENS, + process_group=process_group, + ) self.layers = nn.ModuleList( [ FlashLlamaLayer( @@ -488,7 +495,7 @@ def forward( max_s: int, adapter_data: AdapterBatchData, ) -> 
torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.embed_tokens(input_ids, adapter_data) # Get rotary cos and sin for this forward # Avoid to index in each layer diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index 337169b6b..bd65191f4 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -24,13 +24,13 @@ Q_PROJ, UP_PROJ, V_PROJ, + EMBED_TOKENS, ) tracer = trace.get_tracer(__name__) -# TODO(travis): re-enable LM_HEAD after resolving issues with outputs -ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ] # LM_HEAD +ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD, EMBED_TOKENS] ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} @@ -136,6 +136,7 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj) layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) + layer_weights[(0, EMBED_TOKENS)] = ("model.embed_tokens", self.model.model.embed_tokens) return layer_weights @property @@ -147,7 +148,7 @@ def default_traced_adapter_layers(self) -> List[str]: return [Q_PROJ, V_PROJ] def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == LM_HEAD else len(self.model.model.layers) + return 1 if layer_type == LM_HEAD or layer_type == EMBED_TOKENS else len(self.model.model.layers) def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index 97322154d..8d4a076a6 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -3,6 +3,7 @@ from functools import lru_cache from typing import TYPE_CHECKING, Set, Tuple +from loguru import logger from safetensors.torch import load_file from transformers 
import AutoConfig, AutoTokenizer, PreTrainedTokenizer @@ -158,4 +159,13 @@ def load_module_map( # map the model weights to the relevant adapter weights (LoRA A and B matrices) module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names) + + # note(ajinkya): adapter weights are consumed during above mapping but if some are not then we may not be + # supporting all the weights in the adapter which should be an error but for now just logging it + if len(adapter_weights) > 0: + logger.warning( + f"Adapter {adapter_id} for the model {model_id}" + \ + f" contains unsupported weights: {', '.join(adapter_weights.keys())}" + ) + return module_map, adapter_config, adapter_weight_names, adapter_tokenizer diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index a6e1f3c97..12e6cbd70 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -109,6 +109,7 @@ def get_max_graph_state( rank_data={ MAX_RANK: RankSegments( rank=MAX_RANK, + adapter_index_map=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), lora_a_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), lora_b_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device), @@ -193,6 +194,7 @@ def trace( { max_rank: RankSegments( rank=max_rank, + adapter_index_map=weight_data.rank_data[MAX_RANK].adapter_index_map[:batch_size], lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:segment_size], lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:segment_size], indices=weight_data.rank_data[MAX_RANK].indices[:batch_size], diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index dd4bd6b66..608a1adad 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -349,6 +349,152 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: 
return out +class LoraEmbedding(nn.Module): + def __init__(self, layer_id, process_group): + super().__init__() + self.layer_id = layer_id + self.process_group = process_group + + def forward_layer_type( + self, + result: torch.Tensor, + input: torch.Tensor, + adapter_data: "AdapterBatchData", + layer_type: str, + start_idx: int, + end_idx: int, + ) -> torch.Tensor: + data = adapter_data.data.get(layer_type) + data: Optional["BatchLoraWeights"] = data.get(LORA) if data is not None else None + + if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + if end_idx - start_idx != result.shape[1]: + proj = torch.zeros_like(result[:, start_idx:end_idx]) + else: + proj = result + + for r, rank_segments in data.rank_data.items(): + lora_a_ptr = rank_segments.lora_a_ptr + lora_b_ptr = rank_segments.lora_b_ptr + + if data.use_sgmv: + # Use SGMV for prefill + if lora_a_ptr is not None and lora_b_ptr is not None: + # note(ajinkya): loop through all segments for this rank + # and lookup embeddings in each lora `A` matrix. 
+ v = torch.zeros_like(result[:, :r]) + for i in range(len(rank_segments.segment_starts)): + v[rank_segments.segment_starts[i]:rank_segments.segment_ends[i], :] = ( + torch.nn.functional.embedding( + input[rank_segments.segment_starts[i]:rank_segments.segment_ends[i]], + data.lora_a[rank_segments.adapter_index_map[i]][self.layer_id], + ) + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + lora_b_sgmv_cutlass( + proj, + v, + rank_segments.tmp_expand, + lora_b_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + ) + else: + # Use BGMV for decode + if lora_a_ptr is not None and lora_b_ptr is not None: + # note(ajinkya): there's no segmentation in the batch so just loop + # through each sample in the batch, get the corresponding lora `A` + # matrix, and lookup embeddings + v = torch.zeros_like(result[:, :r]) + for i in range(input.shape[0]): + v[i, :] = torch.nn.functional.embedding( + input[i], + data.lora_a[rank_segments.indices[i].item()][self.layer_id] + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + add_lora_b_bgmv( + proj, + v, + lora_b_ptr, + rank_segments.indices, + self.layer_id, + ) + + if end_idx - start_idx != result.shape[1]: + result[:, start_idx:end_idx] += proj + else: + for adapter_index in adapter_data.meta.adapter_set: + if data is not None and data.has_adapter(adapter_index): + adapter_mask = (adapter_data.meta.adapter_indices == adapter_index).to(input.dtype).view(-1, 1) + layer_result = self.forward_lora(input, data, adapter_index, adapter_mask) + result[:, start_idx:end_idx] += layer_result + + return result + + def forward_lora( + self, + input: torch.Tensor, + data: "BatchLoraWeights", + adapter_index: int, + adapter_mask: torch.Tensor, + ) -> torch.Tensor: + lora_a = data.lora_a[adapter_index][self.layer_id, :, :] + lora_b = data.lora_b[adapter_index][self.layer_id, :, :] + + lora_a = orient_for_rank(lora_a, lora_b.size(0)) + + a_out = 
torch.nn.functional.embedding(input, lora_a) + if self.process_group.size() > 1: + a_out = self.collect_lora_a(a_out) + + result = (a_out @ lora_b) * adapter_mask + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Implemented in subclasses") + + +class TensorParallelAdapterRowEmbedding(LoraEmbedding): + def __init__(self, base_layer, layer_id, layer_name, process_group): + super().__init__(layer_id, process_group) + self.base_layer = base_layer + self.layer_name = layer_name + + @classmethod + def load(cls, base_layer, layer_id, layer_name, process_group): + return cls(base_layer, layer_id, layer_name, process_group) + + def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> torch.Tensor: + result = self.base_layer(input) + + # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285 + stride = result.shape[-1] // self.process_group.size() + start_idx = self.process_group.rank() * stride + end_idx = (self.process_group.rank() + 1) * stride + + self.forward_layer_type(result, input, adapter_data, self.layer_name, start_idx, end_idx) + + return result + + # def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. + # # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. 
+ # # + # # TODO(travis): this is not very efficient as we do an all-reduce for every adapter, + # # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + # torch.distributed.all_reduce(a_out, group=self.process_group) + # return a_out + + try: import dropout_layer_norm diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py index f3d3e6f16..effff93c6 100644 --- a/server/lorax_server/utils/lora.py +++ b/server/lorax_server/utils/lora.py @@ -9,3 +9,4 @@ DOWN_PROJ = "down_proj" LM_HEAD = "lm_head" +EMBED_TOKENS = "embed_tokens" From b72957c8ddd90fbdbe3571074ae63bc29e07eb03 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 7 Jun 2024 18:58:21 -0700 Subject: [PATCH 07/12] bug(ruff): make ruff happy --- server/lorax_server/adapters/lora.py | 2 +- server/lorax_server/adapters/weights.py | 2 +- .../models/custom_modeling/flash_llama_modeling.py | 4 ++-- server/lorax_server/models/flash_llama.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 04a856017..8db0a4f86 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -8,8 +8,8 @@ from lorax_server.adapters.config import AdapterConfig, ModuleMap from lorax_server.adapters.types import LORA -from lorax_server.utils.lora import EMBED_TOKENS from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights +from lorax_server.utils.lora import EMBED_TOKENS from lorax_server.utils.sgmv import ( BGMV_MAX_RANK, MAX_RANK_CUSTOM, diff --git a/server/lorax_server/adapters/weights.py b/server/lorax_server/adapters/weights.py index 5d9060f96..ed4095636 100644 --- a/server/lorax_server/adapters/weights.py +++ 
b/server/lorax_server/adapters/weights.py @@ -6,7 +6,7 @@ import torch from lorax_server.adapters.types import LORA -from lorax_server.utils.lora import LM_HEAD, EMBED_TOKENS +from lorax_server.utils.lora import EMBED_TOKENS, LM_HEAD @dataclass diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index 899c11acc..af9c0755a 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -33,17 +33,18 @@ from lorax_server.utils.layers import ( MultiAdapterHead, PositionRotaryEmbedding, + TensorParallelAdapterRowEmbedding, TensorParallelAdapterRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelHead, TensorParallelMultiAdapterLinear, TensorParallelRowLinear, - TensorParallelAdapterRowEmbedding, get_linear, ) from lorax_server.utils.lora import ( DOWN_PROJ, + EMBED_TOKENS, GATE_PROJ, K_PROJ, LM_HEAD, @@ -51,7 +52,6 @@ Q_PROJ, UP_PROJ, V_PROJ, - EMBED_TOKENS, ) diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index bd65191f4..71609b9c7 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -17,6 +17,7 @@ ) from lorax_server.utils.lora import ( DOWN_PROJ, + EMBED_TOKENS, GATE_PROJ, K_PROJ, LM_HEAD, @@ -24,7 +25,6 @@ Q_PROJ, UP_PROJ, V_PROJ, - EMBED_TOKENS, ) tracer = trace.get_tracer(__name__) From d40c52bad8cecf2518325da828ea4fb363d6b6a9 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 7 Jun 2024 19:47:39 -0700 Subject: [PATCH 08/12] bug(tests): fix lora test case --- server/tests/utils/test_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index aa1a98366..60338d3c6 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -48,7 +48,7 
@@ def test_batched_lora_weights(lora_ranks: List[int]): ) with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))): - data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA) + data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None, is_embed=False).get(LORA) assert len(data.lora_a) == 2 assert data.lora_a.keys() == meta.adapter_set From f73317e533879fb7cd443875b3a500da7c9009dc Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Wed, 26 Jun 2024 11:44:20 -0700 Subject: [PATCH 09/12] refactor : lora.load function for clarity --- server/lorax_server/adapters/lora.py | 54 +++++++++++++++------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 8db0a4f86..9b3dda5ea 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -278,22 +278,10 @@ def load( lora_a[adapter_idx] = adapter_weight.weights_a lora_b[adapter_idx] = adapter_weight.weights_b + if not max_rank: + return None + use_sgmv = prefill or max_rank > BGMV_MAX_RANK - lora_a_ptr, lora_b_ptr = [], [] - for segment_idx, adapter_idx in enumerate(segment_indices): - if adapter_idx in adapter_weights: - adapter_weight = adapter_weights[adapter_idx] - lora_a_ptr.append( - (adapter_weight.weights_a if use_sgmv or is_embed else adapter_weight.weights_a_t).data_ptr() - ) - lora_b_ptr.append( - (adapter_weight.weights_b if use_sgmv else adapter_weight.weights_b_t).data_ptr() - ) - else: - lora_a_ptr.append(EMPTY_TENSOR.data_ptr()) - lora_b_ptr.append(EMPTY_TENSOR.data_ptr()) - lora_a_ptr = torch.tensor(lora_a_ptr, dtype=torch.int64, device=device) - lora_b_ptr = torch.tensor(lora_b_ptr, dtype=torch.int64, device=device) if prefill_head_indices is not None: j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0] @@ -313,28 +301,46 @@ def load( segment_starts = 
None segment_ends = None batch_indices = None + lora_a_ptr_indices = [] + lora_b_ptr_indices = [] if use_sgmv: - lora_a_ptr_indices = lora_a_ptr[indices] + for segment_idx in indices: + adapter_weight = adapter_weights[segment_indices[segment_idx]] + lora_a_ptr_indices.append(adapter_weight.weights_a.data_ptr()) + lora_b_ptr_indices.append(adapter_weight.weights_b.data_ptr()) tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device) segment_starts = meta.adapter_segments[indices] segment_ends = meta.adapter_segments[[i + 1 for i in indices]] if prefill_head_indices is not None: - for i, segment_index in enumerate(indices): - segment_starts[i] = prefill_head_segment_starts[segment_index] - segment_ends[i] = prefill_head_segment_ends[segment_index] + for i, segment_idx in enumerate(indices): + segment_starts[i] = prefill_head_segment_starts[segment_idx] + segment_ends[i] = prefill_head_segment_ends[segment_idx] else: - batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()] - batch_indices = [idx if idx in set(indices) else -1 for idx in batch_indices] - batch_indices = torch.tensor(batch_indices, dtype=torch.int64, device=device) + adapter_indices = {} + for segment_idx in indices: + adapter_idx = segment_indices[segment_idx] + adapter_weight = adapter_weights[adapter_idx] + if adapter_idx not in adapter_indices: + lora_a_ptr_indices.append( + (adapter_weight.weights_a if is_embed else adapter_weight.weights_a_t).data_ptr() + ) + lora_b_ptr_indices.append(adapter_weight.weights_b_t.data_ptr()) + adapter_indices[adapter_idx] = len(lora_a_ptr_indices) - 1 + batch_indices = torch.tensor([ + adapter_indices.get(adapter_idx, -1) for adapter_idx in meta.adapter_indices.tolist() + ], dtype=torch.int64, device=device) + + lora_a_ptr_indices = torch.tensor(lora_a_ptr_indices, dtype=torch.int64, device=device) + lora_b_ptr_indices = torch.tensor(lora_b_ptr_indices, dtype=torch.int64, device=device) rank_data[rank] = 
RankSegments( rank=rank, adapter_index_map=indices, tmp_shrink=tmp_shrink, tmp_expand=tmp_expand, - lora_a_ptr=lora_a_ptr[indices], - lora_b_ptr=lora_b_ptr[indices], + lora_a_ptr=lora_a_ptr_indices, + lora_b_ptr=lora_b_ptr_indices, segment_starts=segment_starts, segment_ends=segment_ends, indices=batch_indices, From 891049fd169e5b5298e3d4de9d92ba26dbcf16d8 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Wed, 26 Jun 2024 12:00:45 -0700 Subject: [PATCH 10/12] test: fix tests and refactor a bit more --- server/lorax_server/adapters/lora.py | 28 +++++++++++++++++++++------- server/tests/utils/test_lora.py | 17 ++++++++++------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 9b3dda5ea..1b3e65921 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -266,10 +266,9 @@ def load( device = first_weights.weights_a.device segment_indices = meta.segment_indices - lora_a, lora_b, adapter_index_configs, adapter_to_segment = {}, {}, {}, {} + lora_a, lora_b, adapter_index_configs = {}, {}, {} max_rank, rank_indices = 0, defaultdict(list) for segment_idx, adapter_idx in enumerate(segment_indices): - adapter_to_segment[adapter_idx] = segment_idx if adapter_idx in adapter_weights: adapter_weight = adapter_weights[adapter_idx] adapter_index_configs[adapter_idx] = adapter_weight.adapter_config @@ -285,11 +284,18 @@ def load( if prefill_head_indices is not None: j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0] + # prefill_head_indices is used to slice the tokens in the batch such + # that we only forward the last token for each request through lm_head + # there can be multiple head_index associated with each adapter segment for head_index in prefill_head_indices: # j cannot go out of bounds as that would mean there are tokens without segments if head_index < meta.adapter_segments[j]: + # head_index is part of the current 
adapter + # so increment the current segment end prefill_head_segment_ends[-1] += 1 else: + # head_index is not part of the current adapter + # close the previous segment and start a new one prefill_head_segment_starts.append(prefill_head_segment_ends[-1]) prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1) j += 1 @@ -309,26 +315,34 @@ def load( adapter_weight = adapter_weights[segment_indices[segment_idx]] lora_a_ptr_indices.append(adapter_weight.weights_a.data_ptr()) lora_b_ptr_indices.append(adapter_weight.weights_b.data_ptr()) - tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device) + tmp_shrink, tmp_expand = get_tmp_tensors(len(lora_a_ptr_indices), rank, device) segment_starts = meta.adapter_segments[indices] segment_ends = meta.adapter_segments[[i + 1 for i in indices]] if prefill_head_indices is not None: + # since prefill_head_indices is present the segment starts and ends + # need to be adjusted according to the number of head tokens in each for i, segment_idx in enumerate(indices): segment_starts[i] = prefill_head_segment_starts[segment_idx] segment_ends[i] = prefill_head_segment_ends[segment_idx] else: - adapter_indices = {} + adapter_idx_to_pointer_idx = {} + # find out which adapters are present in the segments for this rank + # iterate over each segment index and use it to find adapter index and weights for segment_idx in indices: adapter_idx = segment_indices[segment_idx] adapter_weight = adapter_weights[adapter_idx] - if adapter_idx not in adapter_indices: + # if the adapter hasn't been seen before, then append its weight pointers + # and save the index to the just added pointers for later + if adapter_idx not in adapter_idx_to_pointer_idx: lora_a_ptr_indices.append( (adapter_weight.weights_a if is_embed else adapter_weight.weights_a_t).data_ptr() ) lora_b_ptr_indices.append(adapter_weight.weights_b_t.data_ptr()) - adapter_indices[adapter_idx] = len(lora_a_ptr_indices) - 1 +
adapter_idx_to_pointer_idx[adapter_idx] = len(lora_a_ptr_indices) - 1 + # for each token in the batch, see if its adapter is present in the segments for this rank + # if present, then store the index of its weight pointers otherwise store -1 batch_indices = torch.tensor([ - adapter_indices.get(adapter_idx, -1) for adapter_idx in meta.adapter_indices.tolist() + adapter_idx_to_pointer_idx.get(adapter_idx, -1) for adapter_idx in meta.adapter_indices.tolist() ], dtype=torch.int64, device=device) lora_a_ptr_indices = torch.tensor(lora_a_ptr_indices, dtype=torch.int64, device=device) diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index b7e1edc01..116b0b9a1 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -37,6 +37,7 @@ def load( meta: "AdapterBatchMetadata", prefill: bool, prefill_head_indices: torch.Tensor, + is_embed: bool, ) -> Optional["BatchAdapterWeights"]: return None @@ -107,10 +108,11 @@ def test_batched_lora_weights(lora_ranks: List[int]): "lora_ranks,adapter_indices,expected", [ ( - [8, 8, 16], - [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], + [8, 8, 16], # ranks of adapters + [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], # adapter indices for each token in the batch { - 8: (4, [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]), + # rank: (expected pointer tensor size, expected indices tensor in the rank data) + 8: (2, [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]), 16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]) } ), @@ -118,8 +120,8 @@ def test_batched_lora_weights(lora_ranks: List[int]): [4, 8, 16], [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], { - 4: (2, [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]), - 8: (2, [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]), + 4: (1, [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]), + 8: (1, [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]), 16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]), } ), @@ -135,6 +137,7 @@ def test_batched_lora_weights_decode( assert batched_weights.is_empty() h = 1024 + adapter_weights = [] for idx, lora_rank in enumerate(lora_ranks): 
weights = LoraWeights( weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)], @@ -153,7 +156,7 @@ def test_batched_lora_weights_decode( ) with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))): - data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None).get(LORA) + data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None, is_embed=False).get(LORA) for lora_rank, rd in data.rank_data.items(): expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device) @@ -197,6 +200,6 @@ def test_batched_lora_weights_no_segments(): ) with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))): - data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA) + data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None, is_embed=False).get(LORA) print(data) From 63019b95f5b3fd4e3524bb89ab61e7e5e4084268 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Wed, 26 Jun 2024 15:07:28 -0700 Subject: [PATCH 11/12] tests: expand test cases to check for correct adapter pointers --- server/tests/utils/test_lora.py | 120 +++++++++++++++++++++----------- 1 file changed, 81 insertions(+), 39 deletions(-) diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index 116b0b9a1..d30c51c6e 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -9,6 +9,7 @@ from lorax_server.adapters.types import LORA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights from lorax_server.utils.sgmv import MIN_RANK_CUSTOM +from lorax_server.utils.segments import find_segments class FakeAdapterWeights(AdapterWeights): @@ -43,16 +44,37 @@ def load( @pytest.mark.parametrize( - "lora_ranks", + "lora_ranks,adapter_indices,expected", [ - [8, 16], - [32, 64], + ( + 
[8, 8, 16], # ranks of adapters + [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], # adapter indices for each token in the batch + { + 8: ( # rank + [0, 2, 4, 6], # expected segment starts + [2, 4, 6, 8], # expected segment ends + [0, 1, 0, 1], # expected adapter indices + ), + 16: ([8], [10], [2]), + } + ), + ( + [4, 8, 16], + [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], + { + 4: ([0, 4], [2, 6], [0, 0]), + 8: ([2, 6], [4, 8], [1, 1]), + 16: ([8], [10], [2]), + } + ), ], ) -def test_batched_lora_weights(lora_ranks: List[int]): - # batch meta is hardcoded with this assumption below - assert len(lora_ranks) == 2 - +def test_batched_lora_weights( + lora_ranks: List[int], + adapter_indices: List[int], + expected: Dict[int, Tuple[List[int], Tuple[int], Tuple[int]]] +): + num_adapters = len(lora_ranks) batched_weights = LayerAdapterWeights() assert batched_weights.is_empty() @@ -69,39 +91,50 @@ def test_batched_lora_weights(lora_ranks: List[int]): batched_weights.add_adapter(idx, weights) assert not batched_weights.is_empty() - assert len(batched_weights.adapter_weights) == 2 + assert len(batched_weights.adapter_weights) == num_adapters + + segments, segment_indices = find_segments(adapter_indices) meta = AdapterBatchMetadata( - adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64), - adapter_set={0, 1}, - adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64), - segment_indices=[0, 1, 0, 1], + adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64), + adapter_set=set(adapter_indices), + adapter_segments=torch.tensor(segments, dtype=torch.int64), + segment_indices=segment_indices, ) with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))): data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None, is_embed=False).get(LORA) - assert len(data.lora_a) == 2 + assert len(data.lora_a) == num_adapters + assert len(data.lora_b) == num_adapters assert data.lora_a.keys() == meta.adapter_set - assert 
data.lora_a[0].shape == ((1, h, lora_ranks[0]) if lora_ranks[0] < MIN_RANK_CUSTOM else (1, lora_ranks[0], h)) - assert data.lora_a[1].shape == ((1, h, lora_ranks[1]) if lora_ranks[1] < MIN_RANK_CUSTOM else (1, lora_ranks[1], h)) - - assert len(data.lora_b) == 2 assert data.lora_b.keys() == meta.adapter_set - assert data.lora_b[0].shape == (1, lora_ranks[0], h) - assert data.lora_b[1].shape == (1, lora_ranks[1], h) + for i in range(num_adapters): + assert data.lora_a[i].shape == ( + (1, h, lora_ranks[i]) if lora_ranks[i] < MIN_RANK_CUSTOM else (1, lora_ranks[i], h) + ) + assert data.lora_b[i].shape == (1, lora_ranks[i], h) - assert len(data.rank_data) == 2 - assert data.rank_data.keys() == set(lora_ranks) for lora_rank, rd in data.rank_data.items(): assert rd.rank == lora_rank - - # shape in all cases is the number of segments with this rank - assert rd.lora_a_ptr.shape == (2,) - assert rd.lora_b_ptr.shape == (2,) - assert rd.segment_starts.shape == (2,) - assert rd.segment_ends.shape == (2,) - + expected_lora_a_ptr = [] + expected_lora_b_ptr = [] + for adapter_idx in expected[lora_rank][2]: + expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a.data_ptr()) + expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b.data_ptr()) + expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device) + expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device) + assert all(rd.lora_a_ptr == expected_lora_a_ptr) + assert all(rd.lora_b_ptr == expected_lora_b_ptr) + + expected_segment_starts = torch.tensor( + expected[lora_rank][0], dtype=rd.segment_starts.dtype, device=rd.segment_starts.device + ) + expected_segment_ends = torch.tensor( + expected[lora_rank][1], dtype=rd.segment_ends.dtype, device=rd.segment_ends.device + ) + assert all(rd.segment_ends == expected_segment_ends) + assert all(rd.segment_starts == 
expected_segment_starts) @pytest.mark.parametrize( @@ -111,18 +144,20 @@ def test_batched_lora_weights(lora_ranks: List[int]): [8, 8, 16], # ranks of adapters [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], # adapter indices for each token in the batch { - # rank: (expected pointer tensor size, expected indices tensor in the rank data) - 8: (2, [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]), - 16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]) + 8: ( # rank + [0, 1], # expected adapter indices + [0, 0, 1, 1, 0, 0, 1, 1, -1, -1] # expected indices + ), + 16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]), } ), ( [4, 8, 16], [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], { - 4: (1, [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]), - 8: (1, [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]), - 16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]), + 4: ([0], [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]), + 8: ([1], [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]), + 16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]), } ), ], @@ -130,14 +165,12 @@ def test_batched_lora_weights(lora_ranks: List[int]): def test_batched_lora_weights_decode( lora_ranks: List[int], adapter_indices: List[int], - expected: Dict[int, Tuple[int, List[int]]] + expected: Dict[int, Tuple[List[int], List[int]]] ): - from lorax_server.utils.segments import find_segments batched_weights = LayerAdapterWeights() assert batched_weights.is_empty() h = 1024 - adapter_weights = [] for idx, lora_rank in enumerate(lora_ranks): weights = LoraWeights( weights_a=[torch.randn((h, lora_rank), dtype=torch.float16)], @@ -159,10 +192,19 @@ def test_batched_lora_weights_decode( data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None, is_embed=False).get(LORA) for lora_rank, rd in data.rank_data.items(): + expected_lora_a_ptr = [] + expected_lora_b_ptr = [] + for adapter_idx in expected[lora_rank][0]: + expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a_t.data_ptr()) + 
expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b_t.data_ptr()) + expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device) + expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device) + assert all(rd.lora_a_ptr == expected_lora_a_ptr) + assert all(rd.lora_b_ptr == expected_lora_b_ptr) + expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device) - assert rd.lora_a_ptr.shape == (expected[lora_rank][0],) - assert rd.lora_b_ptr.shape == (expected[lora_rank][0],) assert all(rd.indices == expected_indices) + assert rd.segment_starts == None assert rd.segment_ends == None assert rd.tmp_shrink == None From 65a4c6ae17e71f231c7ba0b8d8a06225a154831a Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Wed, 26 Jun 2024 17:45:18 -0700 Subject: [PATCH 12/12] refactor : incorporate suggestions from PR review 1. Make embedding weight name a property of the model 2. Do not pop the adapter weight names 3. 
Uncomment collect_lora method --- server/lorax_server/adapters/config.py | 1 + server/lorax_server/adapters/lora.py | 7 ++++--- server/lorax_server/adapters/medusa.py | 1 + server/lorax_server/adapters/medusa_lora.py | 13 +++++++++++-- server/lorax_server/models/flash_llama.py | 4 ++++ server/lorax_server/models/model.py | 6 ++++++ server/lorax_server/utils/adapter.py | 18 +++++++++++++++--- server/lorax_server/utils/layers.py | 20 ++++++++++---------- server/tests/adapters/test_medusa.py | 2 +- 9 files changed, 53 insertions(+), 19 deletions(-) diff --git a/server/lorax_server/adapters/config.py b/server/lorax_server/adapters/config.py index 6bfcf8645..ec696b420 100644 --- a/server/lorax_server/adapters/config.py +++ b/server/lorax_server/adapters/config.py @@ -22,6 +22,7 @@ def map_weights_for_model( self, adapter_weights: Dict, weight_names: Tuple[str], + embedding_weight_name: str, ) -> Tuple[ModuleMap, Set[str]]: pass diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 1b3e65921..aef75fe24 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -37,11 +37,12 @@ def map_weights_for_model( self, adapter_weights: Dict, weight_names: Tuple[str], + embedding_weight_name: str, ) -> Tuple[ModuleMap, Set[str]]: adapter_weight_names = set() module_map = {} for weight_name in weight_names: - if EMBED_TOKENS in weight_name: + if embedding_weight_name in weight_name: lora_a_name = f"base_model.model.{weight_name}.lora_embedding_A" lora_b_name = f"base_model.model.{weight_name}.lora_embedding_B" else: @@ -53,8 +54,8 @@ def map_weights_for_model( # note(ajinkya): popping the weights so that we know which weights are # can be used as lora weights (supported) and which cannot module_map[weight_name] = { - "lora_A": (adapter_weights.pop(lora_a_name), lora_a_name), - "lora_B": (adapter_weights.pop(lora_b_name), lora_b_name), + "lora_A": (adapter_weights[lora_a_name], lora_a_name), + "lora_B": 
(adapter_weights[lora_b_name], lora_b_name), } adapter_weight_names.add(lora_a_name) adapter_weight_names.add(lora_b_name) diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index 476437128..bd65475a9 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -36,6 +36,7 @@ def map_weights_for_model( self, adapter_weights: Dict, weight_names: Tuple[str], + embedding_weight_name: str, ) -> Tuple[ModuleMap, Set[str]]: # TODO(travis): this isn't technically the ModuleMap structure, make this more generic return adapter_weights, set(weight_names) diff --git a/server/lorax_server/adapters/medusa_lora.py b/server/lorax_server/adapters/medusa_lora.py index 833af0999..e9434f853 100644 --- a/server/lorax_server/adapters/medusa_lora.py +++ b/server/lorax_server/adapters/medusa_lora.py @@ -29,9 +29,18 @@ def map_weights_for_model( self, adapter_weights: Dict, weight_names: Tuple[str], + embedding_weight_name: str, ) -> Tuple[MedusaLoraModuleMap, Set[str]]: - lora_module_map, weight_names = self.lora_config.map_weights_for_model(adapter_weights, weight_names) - medusa_module_map, _ = self.medusa_config.map_weights_for_model(adapter_weights, weight_names) + lora_module_map, weight_names = self.lora_config.map_weights_for_model( + adapter_weights, + weight_names, + embedding_weight_name + ) + medusa_module_map, _ = self.medusa_config.map_weights_for_model( + adapter_weights, + weight_names, + embedding_weight_name + ) return MedusaLoraModuleMap(lora_module_map, medusa_module_map), weight_names def load_batched_adapter_weights( diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index 71609b9c7..1a3d61927 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -143,6 +143,10 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return 
ADAPTER_LAYERS + @property + def embedding_weight_name(self) -> str: + return EMBED_TOKENS + @property def default_traced_adapter_layers(self) -> List[str]: return [Q_PROJ, V_PROJ] diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 3ee5bd55a..d66ed67f5 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -164,6 +164,11 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: def adapter_layers(self) -> List[str]: return [] + @property + def embedding_weight_name(self) -> str: + # setting it to '' will cause matches with any weight name + return 'placeholder value to be initialized by the subclass' + @property def default_traced_adapter_layers(self) -> List[str]: return [] @@ -224,6 +229,7 @@ def load_adapter( adapter_index, weight_names, api_token, + self.embedding_weight_name, self.trust_remote_code, ) diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index 8d4a076a6..ad66449bd 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -41,11 +41,18 @@ def load_and_merge_adapters( adapter_index: int, weight_names: Tuple[str], api_token: str, + embedding_weight_name: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: if len(adapter_parameters.adapter_ids) == 1: return load_module_map( - model_id, adapter_parameters.adapter_ids[0], adapter_source, weight_names, api_token, trust_remote_code + model_id, + adapter_parameters.adapter_ids[0], + adapter_source, + weight_names, + api_token, + embedding_weight_name, + trust_remote_code, ) adapter_params = AdapterParametersContainer(adapter_parameters, adapter_source, adapter_index) @@ -133,6 +140,7 @@ def load_module_map( adapter_source: str, weight_names: Tuple[str], api_token: str, + embedding_weight_name: str, trust_remote_code: bool = False, ) -> Tuple["ModuleMap", "AdapterConfig", 
Set[str], PreTrainedTokenizer]: # TODO(geoffrey): refactor this and merge parts of this function with @@ -158,11 +166,15 @@ def load_module_map( adapter_weights.update(load_file(filename)) # map the model weights to the relevant adapter weights (LoRA A and B matrices) - module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names) + module_map, adapter_weight_names = adapter_config.map_weights_for_model( + adapter_weights, + weight_names, + embedding_weight_name, + ) # note(ajinkya): adapter weights are consumed during above mapping but if some are not then we may not be # supporting all the weights in the adapter which should be an error but for now just logging it - if len(adapter_weights) > 0: + if len(set(adapter_weights.keys()) - set(adapter_weight_names)) > 0: logger.warning( f"Adapter {adapter_id} for the model {model_id}" + \ f" contains unsupported weights: {', '.join(adapter_weights.keys())}" diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 608a1adad..3456d9336 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -483,16 +483,16 @@ def forward(self, input: torch.Tensor, adapter_data: "AdapterBatchData") -> torc return result - # def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: - # # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. - # # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. 
- # # - # # TODO(travis): this is not very efficient as we do an all-reduce for every adapter, - # # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same - # # rank, compute `a_out` on each, and then slice them into the buffer as shown here: - # # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 - # torch.distributed.all_reduce(a_out, group=self.process_group) - # return a_out + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. + # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. + # + # TODO(travis): this is not very efficient as we do an all-reduce for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + torch.distributed.all_reduce(a_out, group=self.process_group) + return a_out try: diff --git a/server/tests/adapters/test_medusa.py b/server/tests/adapters/test_medusa.py index ab0abe822..c6fd06f4d 100644 --- a/server/tests/adapters/test_medusa.py +++ b/server/tests/adapters/test_medusa.py @@ -16,7 +16,7 @@ def test_batched_medusa_weights(default_causal_lm: CausalLM): download_adapter(adapter_id, HUB) module_map, medusa_config, _, _ = load_module_map( - model_id, adapter_id, HUB, tuple(), None + model_id, adapter_id, HUB, tuple(), None, default_causal_lm.embedding_weight_name ) assert isinstance(medusa_config, MedusaConfig)