From 65d450a835d0b7f74e9b1d50ba9a66984c32a8f8 Mon Sep 17 00:00:00 2001 From: Timothy Wang Date: Tue, 29 Oct 2024 16:15:29 -0400 Subject: [PATCH 1/3] Allow adapter loading for VLMs --- server/lorax_server/models/vlm_causal_lm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py index 7c4a3b543..3fd344007 100644 --- a/server/lorax_server/models/vlm_causal_lm.py +++ b/server/lorax_server/models/vlm_causal_lm.py @@ -290,6 +290,10 @@ def batch_type(self) -> Type[VlmCausalLMBatch]: def max_past(self) -> Optional[int]: return getattr(self.model.text_model, "max_past", None) + @property + def supports_adapter_loading(self) -> bool: + return True + def forward( self, batch: VlmCausalLMBatch, From 220224daf953d12c05c5e9c676c368628e93589b Mon Sep 17 00:00:00 2001 From: Timothy Wang Date: Tue, 29 Oct 2024 19:13:54 -0400 Subject: [PATCH 2/3] Add adapter_target_to_layer --- server/lorax_server/models/vlm_causal_lm.py | 55 ++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py index 3fd344007..d6f658721 100644 --- a/server/lorax_server/models/vlm_causal_lm.py +++ b/server/lorax_server/models/vlm_causal_lm.py @@ -1,5 +1,5 @@ from io import BytesIO -from typing import Iterable, List, Optional, Tuple, Type +from typing import Dict, Iterable, List, Optional, Tuple, Type import torch import torch.distributed @@ -19,12 +19,22 @@ from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.state import PREFIX_CACHING from lorax_server.utils.tokenizer import TokenizerManager +from lorax_server.utils.lora import LM_HEAD tracer = trace.get_tracer(__name__) IDEFICS2_FAKE_TOKEN = "" IDEFICS2_IMAGE_TOKEN = "" +LANGUAGE_ATTN_Q_PROJ = "self_attn.language.q_proj" +LANGUAGE_ATTN_K_PROJ = "self_attn.language.k_proj" +LANGUAGE_ATTN_V_PROJ = "self_attn.language.v_proj" +LANGUAGE_ATTN_O_PROJ = "self_attn.language.out_proj" +VISION_ATTN_Q_PROJ = "self_attn.vision.q_proj" +VISION_ATTN_K_PROJ = "self_attn.vision.k_proj" +VISION_ATTN_V_PROJ = "self_attn.vision.v_proj" +VISION_ATTN_O_PROJ = "self_attn.vision.out_proj" + def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -293,6 +303,49 @@ def max_past(self) -> Optional[int]: @property def supports_adapter_loading(self) -> bool: return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + language_prefix = "text_model.model.layers" + vision_prefix = "vision_tower.encoder.layers" + for i, layer in enumerate(self.model.text_model.model.layers): + layer_weights[(i, LANGUAGE_ATTN_K_PROJ)] = ( + f"{language_prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, LANGUAGE_ATTN_V_PROJ)] = ( + f"{language_prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, LANGUAGE_ATTN_O_PROJ)] = ( + f"{language_prefix}.{i}.self_attn.out_proj", + layer.self_attn.o_proj, + ) + layer_weights[(i, LANGUAGE_ATTN_Q_PROJ)] = ( + f"{language_prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + for i, layer in enumerate(self.model.vision_tower.encoder.layers): + layer_weights[(i, VISION_ATTN_K_PROJ)] = ( + f"{vision_prefix}.{i}.self_attn.k_proj", + layer.self_attn.k_proj, + ) + layer_weights[(i, VISION_ATTN_V_PROJ)] = ( + f"{vision_prefix}.{i}.self_attn.v_proj", + layer.self_attn.v_proj, + ) + layer_weights[(i, VISION_ATTN_O_PROJ)] = ( + f"{vision_prefix}.{i}.self_attn.out_proj", + layer.self_attn.out_proj, + ) + layer_weights[(i, VISION_ATTN_Q_PROJ)] = ( + f"{vision_prefix}.{i}.self_attn.q_proj", + layer.self_attn.q_proj, + ) + + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.text_model.lm_head) + return layer_weights + def forward( self, From b5b5d1de950115043eae546a360007df1f960711 Mon Sep 17 00:00:00 2001 From: Timothy Wang Date: Tue, 29 Oct 2024 20:26:23 -0400 Subject: [PATCH 3/3] Fixes for mistral --- .../models/custom_modeling/flash_mistral_modeling.py | 1 + server/lorax_server/models/vlm_causal_lm.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 88b69dbb5..83a43e848 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -514,6 +514,7 @@ def forward( max_s: int, adapter_data: AdapterBatchData, prefill_cache_indices: Optional[torch.Tensor], + cross_attention_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = inputs_embeds diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py index d6f658721..bc7bbfd05 100644 --- a/server/lorax_server/models/vlm_causal_lm.py +++ b/server/lorax_server/models/vlm_causal_lm.py @@ -306,8 +306,8 @@ def supports_adapter_loading(self) -> bool: def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: layer_weights = {} - language_prefix = "text_model.model.layers" - vision_prefix = "vision_tower.encoder.layers" + language_prefix = "language_model.model.layers" + vision_prefix = "vision_tower.vision_model.encoder.layers" for i, layer in enumerate(self.model.text_model.model.layers): layer_weights[(i, LANGUAGE_ATTN_K_PROJ)] = ( f"{language_prefix}.{i}.self_attn.k_proj", @@ -328,11 +328,11 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: for i, layer in enumerate(self.model.vision_tower.encoder.layers): layer_weights[(i, VISION_ATTN_K_PROJ)] = ( f"{vision_prefix}.{i}.self_attn.k_proj", - layer.self_attn.k_proj, + layer.self_attn.qkv, ) layer_weights[(i, VISION_ATTN_V_PROJ)] = ( f"{vision_prefix}.{i}.self_attn.v_proj", - layer.self_attn.v_proj, + layer.self_attn.qkv, ) layer_weights[(i, VISION_ATTN_O_PROJ)] = ( f"{vision_prefix}.{i}.self_attn.out_proj", @@ -340,7 +340,7 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: ) layer_weights[(i, VISION_ATTN_Q_PROJ)] = ( f"{vision_prefix}.{i}.self_attn.q_proj", - layer.self_attn.q_proj, + layer.self_attn.qkv, ) layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.text_model.lm_head)