
[feature] sync patches from transformers
mikecovlee committed Aug 14, 2024
1 parent 4893c3a commit 23b8c54
Showing 10 changed files with 238 additions and 129 deletions.
144 changes: 76 additions & 68 deletions mlora/model.py
@@ -254,6 +254,80 @@ def init_lora_layer_weight(
lora_linear.init_lora_weight(lora_config, (lora_a, lora_b))


def get_lora_layer_weight(
transformer_layer: LLMDecoder,
lora_config: LoraConfig,
lora_weights: Dict[str, torch.Tensor],
):
target_modules = lora_config.target_modules_
attn_state_dict, mlp_state_dict = transformer_layer.state_dict()
attn_state_dict: Dict[str, torch.Tensor]
mlp_state_dict: Dict[str, torch.Tensor]
all_state_dict: Dict[str, torch.Tensor] = copy.copy(attn_state_dict)
all_state_dict.update(mlp_state_dict)
if isinstance(lora_config, MixLoraConfig):
model_prefix_name = "mixlora"
gate_layer_name = (
f"mixlora.layers.{transformer_layer.layer_id_}.mlp.moe_gate.weight"
)
moe_layer_name_list = list(mlp_state_dict.keys())
elif isinstance(lora_config, LoraMoeConfig):
model_prefix_name = "loramoe"
moe_layer_name_list = list(mlp_state_dict.keys())
elif isinstance(lora_config, MolaConfig):
model_prefix_name = "mola"
moe_layer_name_list = list(all_state_dict.keys())
else:
model_prefix_name = "base_model.model.model"
moe_layer_name_list = []

# for fused MoEs such as MixLoRA
mlp_moe_layer: LLMMoeBlock = transformer_layer.mlp_.moes_.get(
lora_config.adapter_name, None
)
if mlp_moe_layer is not None:
lora_weights[gate_layer_name] = mlp_moe_layer.gate_.weight

for proj_name, lora_linear in all_state_dict.items():
lora_linear: Linear
if proj_name not in target_modules or not target_modules[proj_name]:
continue
module_name = (
"self_attn"
if proj_name in attn_state_dict
else ("mlp" if proj_name in mlp_state_dict else None)
)
module_name = f"{model_prefix_name}.layers.{transformer_layer.layer_id_}.{module_name}.{proj_name}"
if proj_name in moe_layer_name_list:
moe_layer = (
lora_linear.moes_[lora_config.adapter_name]
if lora_config.adapter_name in lora_linear.moes_
else mlp_moe_layer
)
# for plugged MoEs such as LoRAMoE, MoLA, etc.
if lora_config.adapter_name in lora_linear.moes_:
lora_weights[f"{module_name}.moe_gate.weight"] = lora_linear.moes_[
lora_config.adapter_name
].gate_.weight

for expert_idx in range(moe_layer.experts_):
moe_lora_name = f"moe.{lora_config.adapter_name}.experts.{expert_idx}"
lora_obj = lora_linear.loras_.get(moe_lora_name, None)
if lora_obj is not None:
lora_weights[
f"{module_name}.experts.{expert_idx}.lora_A.weight"
] = lora_obj.lora_a_.weight
lora_weights[
f"{module_name}.experts.{expert_idx}.lora_B.weight"
] = lora_obj.lora_b_.weight

else:
lora_obj = lora_linear.loras_.get(lora_config.adapter_name, None)
if lora_obj is not None:
lora_weights[f"{module_name}.lora_A.weight"] = lora_obj.lora_a_.weight
lora_weights[f"{module_name}.lora_B.weight"] = lora_obj.lora_b_.weight


class LLMModel(torch.nn.Module):
def __init__(self, model: LLMForCausalLM):
super().__init__()
@@ -539,75 +613,9 @@ def init_adapter(
def get_adapter_weight_dict(self, adapter_name: str) -> Dict[str, torch.Tensor]:
# return the LoRA weights keyed by target module name
lora_weight_dict = self.output_.layers_[adapter_name].state_dict()
lora_config = self.adapter_configs_[adapter_name]
for transformer_layer in self.model_.layers_:
attn_state_dict, mlp_state_dict = transformer_layer.state_dict()
attn_state_dict: Dict[str, torch.Tensor]
mlp_state_dict: Dict[str, torch.Tensor]
all_state_dict: Dict[str, torch.Tensor] = copy.copy(attn_state_dict)
all_state_dict.update(mlp_state_dict)
if isinstance(self.adapter_configs_[adapter_name], MixLoraConfig):
model_prefix_name = "mixlora"
gate_layer_name = (
f"mixlora.layers.{transformer_layer.layer_id_}.mlp.moe_gate.weight"
)
moe_layer_name_list = list(mlp_state_dict.keys())
elif isinstance(self.adapter_configs_[adapter_name], LoraMoeConfig):
model_prefix_name = "loramoe"
moe_layer_name_list = list(mlp_state_dict.keys())
elif isinstance(self.adapter_configs_[adapter_name], MolaConfig):
model_prefix_name = "mola"
moe_layer_name_list = list(all_state_dict.keys())
else:
model_prefix_name = "base_model.model.model"
moe_layer_name_list = []

# for fused MoEs such as MixLoRA
mlp_moe_layer: LLMMoeBlock = transformer_layer.mlp_.moes_.get(
adapter_name, None
)
if mlp_moe_layer is not None:
lora_weight_dict[gate_layer_name] = mlp_moe_layer.gate_.weight

for proj_name, lora_linear in all_state_dict.items():
lora_linear: Linear
module_name = (
"self_attn"
if proj_name in attn_state_dict
else ("mlp" if proj_name in mlp_state_dict else None)
)
module_name = f"{model_prefix_name}.layers.{transformer_layer.layer_id_}.{module_name}.{proj_name}"
if proj_name in moe_layer_name_list:
moe_layer = (
lora_linear.moes_[adapter_name]
if adapter_name in lora_linear.moes_
else mlp_moe_layer
)
# for plugged MoEs such as LoRAMoE, MoLA, etc.
if adapter_name in lora_linear.moes_:
lora_weight_dict[f"{module_name}.moe_gate.weight"] = (
lora_linear.moes_[adapter_name].gate_.weight
)

for expert_idx in range(moe_layer.experts_):
moe_lora_name = f"moe.{adapter_name}.experts.{expert_idx}"
lora_obj = lora_linear.loras_.get(moe_lora_name, None)
if lora_obj is not None:
lora_weight_dict[
f"{module_name}.experts.{expert_idx}.lora_A.weight"
] = lora_obj.lora_a_.weight
lora_weight_dict[
f"{module_name}.experts.{expert_idx}.lora_B.weight"
] = lora_obj.lora_b_.weight

else:
lora_obj = lora_linear.loras_.get(adapter_name, None)
if lora_obj is not None:
lora_weight_dict[f"{module_name}.lora_A.weight"] = (
lora_obj.lora_a_.weight
)
lora_weight_dict[f"{module_name}.lora_B.weight"] = (
lora_obj.lora_b_.weight
)
get_lora_layer_weight(transformer_layer, lora_config, lora_weight_dict)

return lora_weight_dict

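Note: the refactor above moves the per-layer export logic out of `get_adapter_weight_dict` into the new module-level helper `get_lora_layer_weight`, which fills the passed `lora_weights` dict in place for a single decoder layer. A sketch of the key layout it produces, assuming LLaMA-style projection names such as `q_proj` and `gate_proj` (only the prefixes and suffixes come from the code above; the projection names and layer index are illustrative):

expected_keys = [
    # plain LoRA adapter (default prefix), layer 0, attention q_proj:
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
    # MixLoRA: fused MoE gate plus per-expert LoRA weights on the MLP projections:
    "mixlora.layers.0.mlp.moe_gate.weight",
    "mixlora.layers.0.mlp.gate_proj.experts.0.lora_A.weight",
    "mixlora.layers.0.mlp.gate_proj.experts.0.lora_B.weight",
    # plugged MoEs (LoRAMoE / MoLA): a per-projection gate next to the expert weights:
    "loramoe.layers.0.mlp.gate_proj.moe_gate.weight",
]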
34 changes: 19 additions & 15 deletions mlora/models/modeling_gemma2.py
@@ -242,6 +242,11 @@ def forward(
key_states, value_states, self.layer_idx_, cache_kwargs
)

if attention_mask is not None:
seq_len = attention_mask.shape[1]
key_states = key_states[:, :, :seq_len]
value_states = value_states[:, :, :seq_len]

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
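Note: the added slice above keeps the cached key/value states and the attention mask in agreement, assuming the mask's second dimension indexes key/value positions; with a sliding-window cache the KV tensors can otherwise be longer than the mask. A toy illustration of the effect (hypothetical sizes, not taken from the diff):

# Hypothetical sizes: the cache returned 8 positions but the mask only covers 5.
# seq_len = attention_mask.shape[1]      -> 5
# key_states, value_states:              (batch, num_kv_heads, 8, head_dim)
# key_states[:, :, :seq_len]             -> (batch, num_kv_heads, 5, head_dim)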
@@ -314,19 +319,23 @@ def forward(
past_key_value: Optional[LLMCache] = None,
):
if (
self.config_.attn_implementation_ != "flash_attn"
and self.config_.use_sliding_window_
self.config_.use_sliding_window_
and self.is_sliding_
and attention_mask is not None
):
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool),
diagonal=-self.sliding_window_,
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window_ :]
if self.config_.attn_implementation_ == "flash_attn":
attention_mask = attention_mask[:, -self.sliding_window_ :]
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool),
diagonal=-self.sliding_window_,
)
attention_mask = torch.where(
sliding_window_mask, min_dtype, attention_mask
)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window_ :]

residual = hidden_states

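Note: the rewritten block above builds the sliding-window mask for every attention backend instead of rejecting non-flash implementations (the corresponding ValueError is removed below and in the Mistral/Phi-3 models): with flash attention the mask is simply cropped to the last `sliding_window_` key positions, while other backends mask out keys older than the window with the dtype minimum. A minimal standalone sketch of that non-flash masking step (toy shapes, not taken from the model):

import torch

sliding_window = 4
# Toy additive causal mask for a 6-token prefill: 0 where attention is allowed,
# -inf above the diagonal.
attention_mask = torch.zeros(1, 1, 6, 6)
attention_mask = attention_mask.masked_fill(
    torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1), float("-inf")
)

# Keys more than `sliding_window` positions behind the query are masked as well.
min_dtype = torch.finfo(attention_mask.dtype).min
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)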
@@ -460,11 +469,6 @@ def from_pretrained(
dtype_=llm_model.dtype,
)

if use_sliding_window and attn_impl != "flash_attn":
raise ValueError(
f"Can not use sliding window attention with {attn_impl} attention."
)

if model_config.pad_token_id_ is None:
model_config.pad_token_id_ = -1

5 changes: 0 additions & 5 deletions mlora/models/modeling_mistral.py
@@ -204,11 +204,6 @@ def from_pretrained(
dtype_=llm_model.dtype,
)

if use_sliding_window and attn_impl != "flash_attn":
raise ValueError(
f"Can not use sliding window attention with {attn_impl} attention."
)

# compatible with qwen2
if isinstance(llm_config, modeling_qwen2.Qwen2Config):
llm_args.max_window_layers_ = llm_config.max_window_layers
5 changes: 0 additions & 5 deletions mlora/models/modeling_phi3.py
@@ -460,11 +460,6 @@ def from_pretrained(
dtype_=llm_model.dtype,
)

if use_sliding_window and attn_impl != "flash_attn":
raise ValueError(
f"Can not use sliding window attention with {attn_impl} attention."
)

if llm_args.pad_token_id_ is None:
llm_args.pad_token_id_ = -1

7 changes: 4 additions & 3 deletions mlora/modules/abstracts.py
@@ -1,14 +1,15 @@
from abc import ABCMeta
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

from .config import LLMModelConfig, LLMModelInput


@dataclass
class LLMCache:
class LLMCache(torch.nn.Module):
def __init__(self):
super().__init__()

def update(
self,
key_states: torch.Tensor,
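Note: `LLMCache` changes from a plain dataclass into a `torch.nn.Module` subclass, presumably to match the cache interface of the transformers version this commit syncs against. A minimal sketch of a concrete cache implementing the `update` contract visible above (the call site passes key/value states, a layer index, and cache kwargs); this is a hypothetical DynamicCache-style example, not the project's implementation:

from typing import Any, Dict, List, Optional, Tuple

import torch


class ToyDynamicCache(torch.nn.Module):
    # Stores per-layer key/value tensors and grows them along the sequence dim.
    def __init__(self):
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )
        return self.key_cache[layer_idx], self.value_cache[layer_idx]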
75 changes: 73 additions & 2 deletions mlora/modules/attention.py
@@ -153,18 +153,50 @@ def _upad_input(
)


def prepare_fa2_from_position_ids(query, key, value, position_ids):
query = query.view(-1, query.size(-2), query.size(-1))
key = key.view(-1, key.size(-2), key.size(-1))
value = value.view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)

cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), device=position_ids.device, dtype=torch.int32
),
)
)

max_length = position_ids.max() + 1

return (
query,
key,
value,
indices_q,
(cu_seq_lens, cu_seq_lens),
(max_length, max_length),
)

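Note: the new `prepare_fa2_from_position_ids` converts packed rows (several sequences concatenated in one batch row, each restarting its positions at 0) into the varlen layout flash attention expects. A small worked example of the `cu_seq_lens` construction (toy values):

import torch

# Two packed sequences of lengths 3 and 4 flattened into one row of 7 tokens.
position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3])
indices = torch.arange(position_ids.size(0), dtype=torch.int32)
# Sequence starts are where the position counter resets to 0; appending the total
# token count gives the cumulative sequence lengths flash_attn_varlen_func needs.
cu_seq_lens = torch.cat(
    (indices[position_ids == 0],
     torch.tensor([position_ids.size(0)], dtype=torch.int32))
)
print(cu_seq_lens)             # tensor([0, 3, 7], dtype=torch.int32)
print(position_ids.max() + 1)  # tensor(4) -> max_length of the packed sequences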

def flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
attention_mask: torch.Tensor,
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: Optional[bool] = None,
):
if not use_top_left_mask:
causal = is_causal
@@ -181,6 +213,9 @@ def flash_attention_forward(
{"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
)

if deterministic is not None:
flash_kwargs["deterministic"] = deterministic

if softcap is not None:
flash_kwargs["softcap"] = softcap

@@ -209,12 +244,48 @@ def flash_attention_forward(
**flash_kwargs,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# If position_ids is provided, not every example (row) contains only a single sequence,
# and we are in the prefill/training stage, use `flash_attn_varlen_func` to prevent
# cross-example attention and enable a padding-free approach.
elif (
position_ids is not None
and not (position_ids[:, -1] == position_ids.size(1) - 1).all()
and query_length != 1
):
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)

attn_output = attn_output.view(
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
)

else:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=dropout,
dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
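Note: with the additions above, `flash_attention_forward` now dispatches between three paths: padded batches are un-padded and routed through `flash_attn_varlen_func`, packed-sequence batches (detected from `position_ids` during prefill/training) take the new varlen path built by `prepare_fa2_from_position_ids`, and everything else falls back to `flash_attn_func`. A small check of the packed-sequence condition used above (toy values):

import torch

# The varlen path is chosen when at least one row's positions do not simply run
# 0..seq_len-1, i.e. the row packs more than one sequence.
position_ids = torch.tensor([
    [0, 1, 2, 3, 4],  # a single ordinary sequence
    [0, 1, 2, 0, 1],  # two sequences packed into one row
])
is_packed = not (position_ids[:, -1] == position_ids.size(1) - 1).all()
print(is_packed)  # True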
