From bffb5a00d1e2d183d7f5d139db149f6beb5c717d Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Tue, 17 Dec 2024 11:34:35 -0700 Subject: [PATCH] [fix] fix issue with BloomBlock due to transformers upgrade (#2640) --- .../djl_python/seq_scheduler/lm_block.py | 28 ++++--------------- engines/python/setup/setup.py | 4 +-- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/engines/python/setup/djl_python/seq_scheduler/lm_block.py b/engines/python/setup/djl_python/seq_scheduler/lm_block.py index 46aa1fe4e..5aafcb589 100644 --- a/engines/python/setup/djl_python/seq_scheduler/lm_block.py +++ b/engines/python/setup/djl_python/seq_scheduler/lm_block.py @@ -15,6 +15,7 @@ from typing import Tuple, Union import torch +from transformers import DynamicCache class LMBlock(ABC): @@ -107,34 +108,17 @@ def forward(self, input_ids: torch.tensor, position_ids: torch.tensor, # Pre-process if past_key_values is not None: - _, num_head, seq_len, kv_dim = past_key_values[0][0].shape - new_kv_list = [] - for k, v in past_key_values: - k_new = torch.permute( - k.view(batch_size * num_head, seq_len, kv_dim), (0, 2, 1)) - v_new = v.view(batch_size * num_head, seq_len, kv_dim) - new_kv_list.append((k_new, v_new)) - past_key_values = tuple(new_kv_list) + cache = DynamicCache.from_legacy_cache(past_key_values) + else: + cache = DynamicCache() # Forward output = self.model.forward(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, - past_key_values=past_key_values, + past_key_values=cache, **self.config) - past_key_values = output.past_key_values - - # Post-process - _, kv_dim, seq_len = past_key_values[0][0].shape - new_kv_list = [] - for k, v in past_key_values: - k_new = torch.permute(k, (0, 2, 1)).view(batch_size, -1, seq_len, - kv_dim) - v_new = v.view(batch_size, -1, seq_len, kv_dim) - new_kv_list.append((k_new, v_new)) - past_key_values = tuple(new_kv_list) - output.past_key_values = past_key_values - + output.past_key_values = output.past_key_values.to_legacy_cache() return output diff --git a/engines/python/setup/setup.py b/engines/python/setup/setup.py index 2daa463d1..6883bbf89 100644 --- a/engines/python/setup/setup.py +++ b/engines/python/setup/setup.py @@ -56,8 +56,8 @@ def run(self): requirements = ['psutil', 'packaging', 'wheel'] test_requirements = [ - 'numpy<2', 'requests', 'Pillow', 'transformers==4.43.4', 'torch', - 'einops', 'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf', + 'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops', + 'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf', 'pydantic>=2.0', "objgraph" ]