(WIP) Support targeting the embedding layer for LoRA #501

Open · wants to merge 13 commits into base: main
1 change: 1 addition & 0 deletions server/lorax_server/adapters/config.py
@@ -22,6 +22,7 @@ def map_weights_for_model(
self,
adapter_weights: Dict,
weight_names: Tuple[str],
embedding_weight_name: str,
) -> Tuple[ModuleMap, Set[str]]:
pass

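The new `embedding_weight_name` argument lets `map_weights_for_model` tell embedding targets apart from linear targets, because the two are saved under different checkpoint keys (PEFT-style adapters use `lora_embedding_A`/`lora_embedding_B` for embeddings and `lora_A.weight`/`lora_B.weight` for linear layers, as the `lora.py` hunk below reflects). A minimal sketch of that key selection; the helper name and example module names are illustrative, not from the PR:

```python
# Sketch: pick the checkpoint key names for a target module, assuming
# PEFT-style naming (embedding factors carry no ".weight" suffix).
def lora_key_names(weight_name: str, embedding_weight_name: str) -> tuple:
    prefix = f"base_model.model.{weight_name}"
    if embedding_weight_name in weight_name:
        return f"{prefix}.lora_embedding_A", f"{prefix}.lora_embedding_B"
    return f"{prefix}.lora_A.weight", f"{prefix}.lora_B.weight"


# Example with hypothetical module names:
print(lora_key_names("model.embed_tokens", "embed_tokens")[0])
# base_model.model.model.embed_tokens.lora_embedding_A
print(lora_key_names("model.layers.0.self_attn.q_proj", "embed_tokens")[0])
# base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
```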
157 changes: 77 additions & 80 deletions server/lorax_server/adapters/lora.py
@@ -9,6 +9,7 @@
from lorax_server.adapters.config import AdapterConfig, ModuleMap
from lorax_server.adapters.types import LORA
from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
from lorax_server.utils.lora import EMBED_TOKENS
from lorax_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
@@ -36,15 +37,22 @@ def map_weights_for_model(
self,
adapter_weights: Dict,
weight_names: Tuple[str],
embedding_weight_name: str,
) -> Tuple[ModuleMap, Set[str]]:
adapter_weight_names = set()
module_map = {}
for weight_name in weight_names:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if embedding_weight_name in weight_name:
lora_a_name = f"base_model.model.{weight_name}.lora_embedding_A"
lora_b_name = f"base_model.model.{weight_name}.lora_embedding_B"
else:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
continue

# note(ajinkya): popping the weights so that we know which weights
# can be used as lora weights (supported) and which cannot
module_map[weight_name] = {
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
@@ -90,6 +98,7 @@ def __init__(
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
is_embed: bool = False,
):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
@@ -98,7 +107,8 @@ def __init__(
self._is_transposed = False

# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
if not is_embed:
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)

# [num_layers, r, hidden_size]
@@ -158,8 +168,10 @@ def load(
key = (layer_id, layer_type)
weight_name, layer = model.target_to_layer[key]

base_weight = layer.base_layer.linear.weight
base_device = base_weight.device
if EMBED_TOKENS in layer_type:
base_device = layer.base_layer.weight.device
else:
base_device = layer.base_layer.linear.weight.device

if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
@@ -196,13 +208,15 @@ def load(

return LoraWeights(
*model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
config,
adapter_config=config,
is_embed=(layer_type == EMBED_TOKENS),
)


@dataclass
class RankSegments:
rank: int
adapter_index_map: List[int]

lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor
@@ -242,6 +256,7 @@ def load(
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
is_embed: bool,
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}
@@ -252,68 +267,36 @@ def load(
device = first_weights.weights_a.device
segment_indices = meta.segment_indices

lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}

segment_ranks = [adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights]
if not segment_ranks:
lora_a, lora_b, adapter_index_configs = {}, {}, {}
max_rank, rank_indices = 0, defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx in adapter_weights:
adapter_weight = adapter_weights[adapter_idx]
adapter_index_configs[adapter_idx] = adapter_weight.adapter_config
max_rank = max(max_rank, adapter_weight.lora_a_r)
rank_indices[adapter_weight.lora_a_r].append(segment_idx)
lora_a[adapter_idx] = adapter_weight.weights_a
lora_b[adapter_idx] = adapter_weight.weights_b

if not max_rank:
return None

max_rank = max(segment_ranks)
if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)

adapter_index_configs = {
idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
}

rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
use_sgmv = prefill or max_rank > BGMV_MAX_RANK

if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
# prefill_head_indices is used to slice the tokens in the batch such
# that we only forward the last token for each request through lm_head
# there can be multiple head_index values associated with each adapter segment
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
# j cannot go out of bounds as that would mean there are tokens without segments
if head_index < meta.adapter_segments[j]:
# head_index is part of the current adapter
# so increment the current segment end
prefill_head_segment_ends[-1] += 1
else:
# head_index is not part of the current adapter
# close the previous segment and start a new one
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1
@@ -325,40 +308,54 @@ def load(
segment_starts = None
segment_ends = None
batch_indices = None
lora_a_ptr_indices = []
lora_b_ptr_indices = []

if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device)
for segment_idx in indices:
adapter_weight = adapter_weights[segment_indices[segment_idx]]
lora_a_ptr_indices.append(adapter_weight.weights_a.data_ptr())
lora_b_ptr_indices.append(adapter_weight.weights_b.data_ptr())
tmp_shrink, tmp_expand = get_tmp_tensors(len(lora_a_ptr_indices), rank, device)
segment_starts = meta.adapter_segments[indices]
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
if prefill_head_indices is not None:
for i, segment_index in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_index]
segment_ends[i] = prefill_head_segment_ends[segment_index]
# since prefill_head_indices is present, the segment starts and ends
# need to be adjusted according to the number of head tokens in each segment
for i, segment_idx in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_idx]
segment_ends[i] = prefill_head_segment_ends[segment_idx]
else:
# `indices` indexes the `segment_indices` which contains segment wise adapter index
# `lora_a_ptr` contains segment wise pointers to lora weights
# lengths of `lora_a_ptr` and `segment_indices` must be same
# `indices` will be used to slice the `lora_a_ptr` tensor
# first, find the mapping between adapter index and its location in the `indices` array
idx_locs = {}
for loc, idx in enumerate(indices):
# use the idx to find the adapter index
if segment_indices[idx] not in idx_locs:
# save the first location of encountering a particular adapter index
idx_locs[segment_indices[idx]] = loc
# second, iterate over the adapter index for each token and find its location in the `indices` array
adapter_idx_to_pointer_idx = {}
# find out which adapters are present in the segments for this rank:
# iterate over each segment index and use it to find the adapter index and weights
for segment_idx in indices:
adapter_idx = segment_indices[segment_idx]
adapter_weight = adapter_weights[adapter_idx]
# if the adapter hasn't been seen before, append its weight pointers
# and record the index of the newly added pointers for later
if adapter_idx not in adapter_idx_to_pointer_idx:
lora_a_ptr_indices.append(
(adapter_weight.weights_a if is_embed else adapter_weight.weights_a_t).data_ptr()
)
lora_b_ptr_indices.append(adapter_weight.weights_b_t.data_ptr())
adapter_idx_to_pointer_idx[adapter_idx] = len(lora_a_ptr_indices) - 1
# for each token in the batch, see if its adapter is present in the segments for this rank;
# if present, store the index of its weight pointers, otherwise store -1
batch_indices = torch.tensor([
idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1
for idx in meta.adapter_indices.tolist()
adapter_idx_to_pointer_idx.get(adapter_idx, -1) for adapter_idx in meta.adapter_indices.tolist()
], dtype=torch.int64, device=device)

lora_a_ptr_indices = torch.tensor(lora_a_ptr_indices, dtype=torch.int64, device=device)
lora_b_ptr_indices = torch.tensor(lora_b_ptr_indices, dtype=torch.int64, device=device)

rank_data[rank] = RankSegments(
rank=rank,
adapter_index_map=indices,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
lora_a_ptr=lora_a_ptr_indices,
lora_b_ptr=lora_b_ptr_indices,
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=batch_indices,
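In the rewritten decode (BGMV) path above, the per-segment pointer tensors are replaced by per-rank pointer lists plus a per-token index map (`batch_indices`). A standalone, simplified sketch of that bookkeeping using plain Python lists; the helper and its arguments are hypothetical, but the mapping mirrors the diff:

```python
# Sketch (not the PR's code): for each rank bucket, collect one pointer slot per
# distinct adapter, then map every token's adapter index to the slot of its
# pointers, or -1 when the token's adapter is not in this rank bucket.
from collections import defaultdict


def build_rank_buckets(segment_indices, adapter_ranks, token_adapter_indices):
    # bucket segments by the rank of their adapter
    rank_indices = defaultdict(list)
    for segment_idx, adapter_idx in enumerate(segment_indices):
        if adapter_idx in adapter_ranks:
            rank_indices[adapter_ranks[adapter_idx]].append(segment_idx)

    out = {}
    for rank, indices in rank_indices.items():
        pointer_idx = {}  # adapter index -> slot in the pointer arrays
        for segment_idx in indices:
            adapter_idx = segment_indices[segment_idx]
            if adapter_idx not in pointer_idx:
                pointer_idx[adapter_idx] = len(pointer_idx)
        # one entry per token: slot of its adapter's pointers, or -1
        batch_indices = [pointer_idx.get(a, -1) for a in token_adapter_indices]
        out[rank] = (pointer_idx, batch_indices)
    return out


# Example: two adapters with ranks 8 and 16, three segments, five tokens.
print(build_rank_buckets(
    segment_indices=[0, 1, 0],
    adapter_ranks={0: 8, 1: 16},
    token_adapter_indices=[0, 0, 1, 1, 0],
))
```

With this toy input, the rank-8 bucket gets one pointer slot (adapter 0) and `batch_indices == [0, 0, -1, -1, 0]`, while the rank-16 bucket yields `[-1, -1, 0, 0, -1]`.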
1 change: 1 addition & 0 deletions server/lorax_server/adapters/medusa.py
@@ -36,6 +36,7 @@ def map_weights_for_model(
self,
adapter_weights: Dict,
weight_names: Tuple[str],
embedding_weight_name: str,
) -> Tuple[ModuleMap, Set[str]]:
# TODO(travis): this isn't technically the ModuleMap structure, make this more generic
return adapter_weights, set(weight_names)
13 changes: 11 additions & 2 deletions server/lorax_server/adapters/medusa_lora.py
@@ -29,9 +29,18 @@ def map_weights_for_model(
self,
adapter_weights: Dict,
weight_names: Tuple[str],
embedding_weight_name: str,
) -> Tuple[MedusaLoraModuleMap, Set[str]]:
lora_module_map, weight_names = self.lora_config.map_weights_for_model(adapter_weights, weight_names)
medusa_module_map, _ = self.medusa_config.map_weights_for_model(adapter_weights, weight_names)
lora_module_map, weight_names = self.lora_config.map_weights_for_model(
adapter_weights,
weight_names,
embedding_weight_name
)
medusa_module_map, _ = self.medusa_config.map_weights_for_model(
adapter_weights,
weight_names,
embedding_weight_name
)
return MedusaLoraModuleMap(lora_module_map, medusa_module_map), weight_names

def load_batched_adapter_weights(
7 changes: 4 additions & 3 deletions server/lorax_server/adapters/weights.py
@@ -6,7 +6,7 @@
import torch

from lorax_server.adapters.types import LORA
from lorax_server.utils.lora import LM_HEAD
from lorax_server.utils.lora import EMBED_TOKENS, LM_HEAD


@dataclass
@@ -82,6 +82,7 @@ def get_data(
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
is_embed: bool,
) -> Dict[str, BatchAdapterWeights]:
# bucket adapters by batch class
adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict)
@@ -91,7 +92,7 @@

batch_data = {}
for batch_type, adapter_weights in adapter_batch_types.items():
batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices)
batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices, is_embed)
if batched_weights is not None:
batch_data[batch_type.key()] = batched_weights
return batch_data
@@ -117,7 +118,7 @@ def from_meta(
for k, v in weights.items():
if v.is_empty():
continue
data[k] = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None)
data[k] = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None, k == EMBED_TOKENS)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)

def ranks(self) -> Set[int]:
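For orientation, the per-layer dispatch that `from_meta`/`get_data` now perform is: `prefill_head_indices` is forwarded only for the `lm_head` layer type, and `is_embed` is set only for `embed_tokens`. A condensed sketch; the wrapper function and the constant string values are assumptions, not the PR's code:

```python
# Sketch of the per-layer-type dispatch in AdapterBatchData.from_meta, assuming
# `weights_by_layer` maps a layer-type name to a LayerAdapterWeights-like object.
LM_HEAD = "lm_head"          # assumed constant value
EMBED_TOKENS = "embed_tokens"  # assumed constant value


def dispatch_layer_data(weights_by_layer, meta, prefill, prefill_head_indices):
    data = {}
    for layer_name, layer_weights in weights_by_layer.items():
        if layer_weights.is_empty():
            continue
        data[layer_name] = layer_weights.get_data(
            meta,
            prefill,
            prefill_head_indices if layer_name == LM_HEAD else None,
            layer_name == EMBED_TOKENS,  # is_embed
        )
    return data
```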
@@ -33,6 +33,7 @@
from lorax_server.utils.layers import (
MultiAdapterHead,
PositionRotaryEmbedding,
TensorParallelAdapterRowEmbedding,
TensorParallelAdapterRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
@@ -43,6 +44,7 @@
)
from lorax_server.utils.lora import (
DOWN_PROJ,
EMBED_TOKENS,
GATE_PROJ,
K_PROJ,
LM_HEAD,
@@ -457,7 +459,12 @@ def __init__(self, config, weights):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
self.embed_tokens = TensorParallelAdapterRowEmbedding(
base_layer=TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights),
layer_id=0,
layer_name=EMBED_TOKENS,
process_group=process_group,
)
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
@@ -488,7 +495,7 @@ def forward(
max_s: int,
adapter_data: AdapterBatchData,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states = self.embed_tokens(input_ids, adapter_data)

# Get rotary cos and sin for this forward
# Avoid to index in each layer
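`TensorParallelAdapterRowEmbedding` itself is defined elsewhere in this PR, so its forward is not shown in these hunks. As a rough illustration of the math an embedding LoRA has to apply (assuming PEFT's layout, where `lora_embedding_A` is `[r, vocab_size]` and `lora_embedding_B` is `[hidden_size, r]`), the A factor acts as a lookup table rather than a matmul operand, which is consistent with the diff skipping `orient_for_rank` when `is_embed` is set:

```python
# Illustrative only (assumes PEFT's embedding-LoRA layout, not this PR's kernels):
# lora_a: [r, vocab_size], lora_b: [hidden_size, r]. The adapted embedding is
# base lookup + (lookup into A^T) @ B^T, scaled.
import torch
import torch.nn.functional as F


def lora_embedding_forward(input_ids, base_embed_weight, lora_a, lora_b, scaling):
    base = F.embedding(input_ids, base_embed_weight)   # [tokens, hidden]
    after_a = F.embedding(input_ids, lora_a.T)         # [tokens, r]
    return base + (after_a @ lora_b.T) * scaling       # [tokens, hidden]


# Tiny example with made-up sizes.
vocab, hidden, r = 10, 4, 2
ids = torch.tensor([1, 3, 7])
out = lora_embedding_forward(
    ids,
    torch.randn(vocab, hidden),
    torch.randn(r, vocab),
    torch.randn(hidden, r),
    scaling=1.0,
)
print(out.shape)  # torch.Size([3, 4])
```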