Refactor the lora load function for clarity and simplicity #529

Open · wants to merge 3 commits into base: main
135 changes: 55 additions & 80 deletions server/lorax_server/adapters/lora.py
@@ -252,68 +252,36 @@ def load(
device = first_weights.weights_a.device
segment_indices = meta.segment_indices

lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}

segment_ranks = [adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights]
if not segment_ranks:
lora_a, lora_b, adapter_index_configs = {}, {}, {}
max_rank, rank_indices = 0, defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx in adapter_weights:
adapter_weight = adapter_weights[adapter_idx]
adapter_index_configs[adapter_idx] = adapter_weight.adapter_config
max_rank = max(max_rank, adapter_weight.lora_a_r)
rank_indices[adapter_weight.lora_a_r].append(segment_idx)
lora_a[adapter_idx] = adapter_weight.weights_a
lora_b[adapter_idx] = adapter_weight.weights_b

if not max_rank:
return None

max_rank = max(segment_ranks)
if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)

adapter_index_configs = {
idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
}

rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
use_sgmv = prefill or max_rank > BGMV_MAX_RANK

if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
# prefill_head_indices is used to slice the tokens in the batch such
# that we only forward the last token for each request through lm_head
# there can be multiple head_index associated with each adapter segment
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
# j cannot go out of bounds as that would mean there are tokens without segments
if head_index < meta.adapter_segments[j]:
# head_index is part of the current adapter
# so increment the current segment end
prefill_head_segment_ends[-1] += 1
else:
# head_index is not part of the current adapter
# close the previous segment and start a new one
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1
@@ -325,44 +293,51 @@ def load(
segment_starts = None
segment_ends = None
batch_indices = None
lora_a_ptr_indices = []
lora_b_ptr_indices = []

if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device)
for segment_idx in indices:
adapter_weight = adapter_weights[segment_indices[segment_idx]]
lora_a_ptr_indices.append(adapter_weight.weights_a.data_ptr())
lora_b_ptr_indices.append(adapter_weight.weights_b.data_ptr())
tmp_shrink, tmp_expand = get_tmp_tensors(len(lora_a_ptr_indices), rank, device)
segment_starts = meta.adapter_segments[indices]
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
if prefill_head_indices is not None:
for i, segment_index in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_index]
segment_ends[i] = prefill_head_segment_ends[segment_index]
# since prefill_head_indices is present, the segment starts and ends
# need to be adjusted according to the number of head tokens in each segment
for i, segment_idx in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_idx]
segment_ends[i] = prefill_head_segment_ends[segment_idx]
else:
# `indices` indexes the `segment_indices` which contains segment wise adapter index
# `lora_a_ptr` contains segment wise pointers to lora weights
# lengths of `lora_a_ptr` and `segment_indices` must be the same
# `indices` will be used to slice the `lora_a_ptr` tensor
# first, find the mapping between adapter index and its location in the `indices` array
idx_locs = {}
for loc, idx in enumerate(indices):
# use the idx to find the adapter index
if segment_indices[idx] not in idx_locs:
# save the first location of encountering a particular adapter index
idx_locs[segment_indices[idx]] = loc
# second, iterate over the adapter index for each token and find its location in the `indices` array
batch_indices = torch.tensor(
[
idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1
for idx in meta.adapter_indices.tolist()
],
dtype=torch.int64,
device=device,
)
adapter_idx_to_pointer_idx = {}
# find out which adapters are present in the segments for this rank
# iterate over each segment index and use it to find adapter index and weights
for segment_idx in indices:
adapter_idx = segment_indices[segment_idx]
adapter_weight = adapter_weights[adapter_idx]
# if the adapter hasn't been seen before, then append its weight pointers
# and save the index to the just added pointers for later
if adapter_idx not in adapter_idx_to_pointer_idx:
lora_a_ptr_indices.append(adapter_weight.weights_a_t.data_ptr())
lora_b_ptr_indices.append(adapter_weight.weights_b_t.data_ptr())
adapter_idx_to_pointer_idx[adapter_idx] = len(lora_a_ptr_indices) - 1
# for each token in the batch, see if its adapter is present in the segments for this rank
# if present, then store the index of its weight pointers otherwise store -1
batch_indices = torch.tensor([
adapter_idx_to_pointer_idx.get(adapter_idx, -1) for adapter_idx in meta.adapter_indices.tolist()
], dtype=torch.int64, device=device)

lora_a_ptr_indices = torch.tensor(lora_a_ptr_indices, dtype=torch.int64, device=device)
lora_b_ptr_indices = torch.tensor(lora_b_ptr_indices, dtype=torch.int64, device=device)

rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
lora_a_ptr=lora_a_ptr_indices,
lora_b_ptr=lora_b_ptr_indices,
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=batch_indices,
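For readers skimming the diff, here is a minimal, self-contained sketch of the grouping pass the refactor introduces: one walk over `segment_indices` that collects per-adapter weights, tracks the maximum rank, and buckets segment positions by rank. `AdapterStub`, `group_segments_by_rank`, and the sample data are hypothetical stand-ins for the real `LoraWeights` objects; only the control flow mirrors the new code.

# Minimal sketch of the single-pass grouping in the refactored load().
# AdapterStub and the sample data below are illustrative stand-ins, not library types.
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class AdapterStub:
    rank: int          # stands in for lora_a_r
    weights_a: str     # stands in for the real weight tensors
    weights_b: str


def group_segments_by_rank(segment_indices: List[int], adapter_weights: Dict[int, AdapterStub]):
    """Walk the segments once, collecting per-adapter weights and per-rank segment positions."""
    lora_a, lora_b = {}, {}
    max_rank, rank_indices = 0, defaultdict(list)
    for segment_idx, adapter_idx in enumerate(segment_indices):
        if adapter_idx not in adapter_weights:
            continue  # segment has no LoRA adapter loaded
        adapter = adapter_weights[adapter_idx]
        max_rank = max(max_rank, adapter.rank)
        rank_indices[adapter.rank].append(segment_idx)
        lora_a[adapter_idx] = adapter.weights_a
        lora_b[adapter_idx] = adapter.weights_b
    if not max_rank:
        return None  # no active adapters in this batch
    return lora_a, lora_b, max_rank, dict(rank_indices)


# Example: four segments for adapters 0, 1, 0, 2 with ranks 8, 8, 16.
adapters = {0: AdapterStub(8, "a0", "b0"), 1: AdapterStub(8, "a1", "b1"), 2: AdapterStub(16, "a2", "b2")}
print(group_segments_by_rank([0, 1, 0, 2], adapters))
# rank_indices -> {8: [0, 1, 2], 16: [3]}

The resulting per-rank segment buckets are what drive the later per-rank RankSegments construction in the diff.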
122 changes: 83 additions & 39 deletions server/tests/utils/test_lora.py
@@ -8,6 +8,7 @@
from lorax_server.adapters.lora import LoraWeights
from lorax_server.adapters.types import LORA
from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights
from lorax_server.utils.segments import find_segments
from lorax_server.utils.sgmv import MIN_RANK_CUSTOM


@@ -42,16 +43,37 @@ def load(


@pytest.mark.parametrize(
"lora_ranks",
"lora_ranks,adapter_indices,expected",
[
[8, 16],
[32, 64],
(
[8, 8, 16], # ranks of adapters
[0, 0, 1, 1, 0, 0, 1, 1, 2, 2], # adapter indices for each token in the batch
{
8: ( # rank
[0, 2, 4, 6], # expected segment starts
[2, 4, 6, 8], # expected segment ends
[0, 1, 0, 1], # expected adapter indices
),
16: ([8], [10], [2]),
}
),
(
[4, 8, 16],
[0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
{
4: ([0, 4], [2, 6], [0, 0]),
8: ([2, 6], [4, 8], [1, 1]),
16: ([8], [10], [2]),
}
),
],
)
def test_batched_lora_weights(lora_ranks: List[int]):
# batch meta is hardcoded with this assumption below
assert len(lora_ranks) == 2

def test_batched_lora_weights(
lora_ranks: List[int],
adapter_indices: List[int],
expected: Dict[int, Tuple[List[int], List[int], List[int]]]
):
num_adapters = len(lora_ranks)
batched_weights = LayerAdapterWeights()
assert batched_weights.is_empty()

@@ -68,69 +90,82 @@ def test_batched_lora_weights(lora_ranks: List[int]):
batched_weights.add_adapter(idx, weights)

assert not batched_weights.is_empty()
assert len(batched_weights.adapter_weights) == 2
assert len(batched_weights.adapter_weights) == num_adapters

segments, segment_indices = find_segments(adapter_indices)

meta = AdapterBatchMetadata(
adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64),
adapter_set={0, 1},
adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64),
segment_indices=[0, 1, 0, 1],
adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64),
adapter_set=set(adapter_indices),
adapter_segments=torch.tensor(segments, dtype=torch.int64),
segment_indices=segment_indices,
)

with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA)

assert len(data.lora_a) == 2
assert len(data.lora_a) == num_adapters
assert len(data.lora_b) == num_adapters
assert data.lora_a.keys() == meta.adapter_set
assert data.lora_a[0].shape == ((1, h, lora_ranks[0]) if lora_ranks[0] < MIN_RANK_CUSTOM else (1, lora_ranks[0], h))
assert data.lora_a[1].shape == ((1, h, lora_ranks[1]) if lora_ranks[1] < MIN_RANK_CUSTOM else (1, lora_ranks[1], h))

assert len(data.lora_b) == 2
assert data.lora_b.keys() == meta.adapter_set
assert data.lora_b[0].shape == (1, lora_ranks[0], h)
assert data.lora_b[1].shape == (1, lora_ranks[1], h)
for i in range(num_adapters):
assert data.lora_a[i].shape == (
(1, h, lora_ranks[i]) if lora_ranks[i] < MIN_RANK_CUSTOM else (1, lora_ranks[i], h)
)
assert data.lora_b[i].shape == (1, lora_ranks[i], h)

assert len(data.rank_data) == 2
assert data.rank_data.keys() == set(lora_ranks)
for lora_rank, rd in data.rank_data.items():
assert rd.rank == lora_rank

# shape in all cases is the number of segments with this rank
assert rd.lora_a_ptr.shape == (2,)
assert rd.lora_b_ptr.shape == (2,)
assert rd.segment_starts.shape == (2,)
assert rd.segment_ends.shape == (2,)

expected_lora_a_ptr = []
expected_lora_b_ptr = []
for adapter_idx in expected[lora_rank][2]:
expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a.data_ptr())
expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b.data_ptr())
expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device)
expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device)
assert all(rd.lora_a_ptr == expected_lora_a_ptr)
assert all(rd.lora_b_ptr == expected_lora_b_ptr)

expected_segment_starts = torch.tensor(
expected[lora_rank][0], dtype=rd.segment_starts.dtype, device=rd.segment_starts.device
)
expected_segment_ends = torch.tensor(
expected[lora_rank][1], dtype=rd.segment_ends.dtype, device=rd.segment_ends.device
)
assert all(rd.segment_ends == expected_segment_ends)
assert all(rd.segment_starts == expected_segment_starts)


@pytest.mark.parametrize(
"lora_ranks,adapter_indices,expected",
[
(
[8, 8, 16],
[0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
[8, 8, 16], # ranks of adapters
[0, 0, 1, 1, 0, 0, 1, 1, 2, 2], # adapter indices for each token in the batch
{
8: (4, [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]),
16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0])
8: ( # rank
[0, 1], # expected adapter indices
[0, 0, 1, 1, 0, 0, 1, 1, -1, -1] # expected indices
),
16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
}
),
(
[4, 8, 16],
[0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
{
4: (2, [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]),
8: (2, [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]),
16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
4: ([0], [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]),
8: ([1], [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]),
16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
}
),
],
)
def test_batched_lora_weights_decode(
lora_ranks: List[int],
adapter_indices: List[int],
expected: Dict[int, Tuple[int, List[int]]]
expected: Dict[int, Tuple[List[int], List[int]]]
):
from lorax_server.utils.segments import find_segments
batched_weights = LayerAdapterWeights()
assert batched_weights.is_empty()

@@ -156,10 +191,19 @@ def test_batched_lora_weights_decode(
data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None).get(LORA)

for lora_rank, rd in data.rank_data.items():
expected_lora_a_ptr = []
expected_lora_b_ptr = []
for adapter_idx in expected[lora_rank][0]:
expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a_t.data_ptr())
expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b_t.data_ptr())
expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device)
expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device)
assert all(rd.lora_a_ptr == expected_lora_a_ptr)
assert all(rd.lora_b_ptr == expected_lora_b_ptr)

expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device)
assert rd.lora_a_ptr.shape == (expected[lora_rank][0],)
assert rd.lora_b_ptr.shape == (expected[lora_rank][0],)
assert all(rd.indices == expected_indices)

assert rd.segment_starts is None
assert rd.segment_ends is None
assert rd.tmp_shrink is None
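The decode-path assertions above check a token-to-pointer-slot mapping; the sketch below reproduces that mapping in isolation. The function name, argument names, and the plain-int stand-ins for data_ptr() values are illustrative only, not part of the lorax_server API.

# Minimal sketch of the decode-path indices construction: within a rank group,
# each adapter gets one pointer slot, and every token maps to its adapter's slot
# or to -1 if its adapter uses a different rank.
from typing import Dict, List


def build_batch_indices(
    rank_segment_ids: List[int],       # segment positions that use this rank
    segment_indices: List[int],        # adapter index for each segment
    token_adapter_indices: List[int],  # adapter index for each token in the batch
) -> List[int]:
    adapter_to_slot: Dict[int, int] = {}
    for segment_idx in rank_segment_ids:
        adapter_idx = segment_indices[segment_idx]
        # first time this adapter appears for this rank, give it the next pointer slot
        adapter_to_slot.setdefault(adapter_idx, len(adapter_to_slot))
    # tokens whose adapter is absent from this rank group get -1
    return [adapter_to_slot.get(idx, -1) for idx in token_adapter_indices]


# Mirrors the first decode case: ranks [8, 8, 16], tokens [0, 0, 1, 1, 0, 0, 1, 1, 2, 2].
print(build_batch_indices([0, 1, 2, 3], [0, 1, 0, 1, 2], [0, 0, 1, 1, 0, 0, 1, 1, 2, 2]))
# -> [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]
print(build_batch_indices([4], [0, 1, 0, 1, 2], [0, 0, 1, 1, 0, 0, 1, 1, 2, 2]))
# -> [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]

Both printed results match the expected `indices` rows asserted for ranks 8 and 16 in the decode test's first parametrized case.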