This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit 8fd223f
[feature] support dynamic MixLoRA (#94)
mikecovlee authored Aug 16, 2024
1 parent 74e4bc1 commit 8fd223f
Showing 12 changed files with 433 additions and 47 deletions.
23 changes: 13 additions & 10 deletions README.md
@@ -51,21 +51,24 @@ You can use the `MLORA_BACKEND_TYPE` environment variable to force m-LoRA to use

## Supported PEFT Methods

-|         | PEFT Methods                                             | Arguments*                                          |
-|---------|----------------------------------------------------------|-----------------------------------------------------|
-| ✓       | [QLoRA](https://arxiv.org/abs/2402.12354)                | See *Quantize Methods*                              |
-| ✓       | [LoRA+](https://arxiv.org/abs/2402.12354)                | `"loraplus_lr_ratio": 20.0`                         |
-| ✓       | [DoRA](https://arxiv.org/abs/2402.09353)                 | `"use_dora": true`                                  |
-| ✓       | [rsLoRA](https://arxiv.org/abs/2312.03732)               | `"use_rslora": true`                                |
-| ✓       | [MoLA](https://arxiv.org/abs/2402.08562)                 | `"routing_strategy": "mola", "num_experts": 8`      |
-| ✓       | [LoRAMoE](https://arxiv.org/abs/2312.09979)              | `"routing_strategy": "loramoe", "num_experts": 8`   |
-| ✓       | [MixLoRA](https://arxiv.org/abs/2404.15159)              | See [MixLoRA](https://github.com/TUDB-Labs/MixLoRA) |
+|         | PEFT Methods                                             | Arguments*                                                 |
+|---------|----------------------------------------------------------|------------------------------------------------------------|
+| ✓       | [QLoRA](https://arxiv.org/abs/2402.12354)                | See *Quantize Methods*                                     |
+| ✓       | [LoRA+](https://arxiv.org/abs/2402.12354)                | `"loraplus_lr_ratio": 20.0`                                |
+| ✓       | [DoRA](https://arxiv.org/abs/2402.09353)                 | `"use_dora": true`                                         |
+| ✓       | [rsLoRA](https://arxiv.org/abs/2312.03732)               | `"use_rslora": true`                                       |
+| ✓       | [MoLA](https://arxiv.org/abs/2402.08562)                 | `"routing_strategy": "mola", "num_experts": 8`             |
+| ✓       | [LoRAMoE](https://arxiv.org/abs/2312.09979)              | `"routing_strategy": "loramoe", "num_experts": 8`          |
+| ✓       | [MixLoRA](https://arxiv.org/abs/2404.15159)              | `"routing_strategy": "mixlora", "num_experts": 8`          |
+| ✓       | MixLoRA-Dynamic                                          | `"routing_strategy": "mixlora-dynamic", "num_experts": 8`  |
+| ✓       | MixLoRA-Switch                                           | `"routing_strategy": "mixlora-switch", "num_experts": 8`   |

*: Arguments of configuration file

### Notes on PEFT support
1. m-LoRA supports specific optimized operators for these PEFT methods, which can effectively improve computing performance during training, evaluation, and inference. However, these operators may cause a small accuracy loss (less than 5%). You can disable the optimized operators by defining the `MLORA_EVALUATE_MODE` environment variable in advance.
-2. Auxiliary Loss is not currently supported for Mo-LoRA (Mixture of LoRAs) methods other than MixLoRA
+2. Auxiliary Loss is not currently supported for Mo-LoRA (Mixture of LoRAs) methods other than MixLoRA.
+3. You can check detailed arguments of MixLoRA in [TUDB-Labs/MixLoRA](https://github.com/TUDB-Labs/MixLoRA).
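For reference, a minimal sketch of the routing arguments for the new dynamic strategy, combining the table above with the defaults this commit adds in `mlora/modules/config.py` (see the diff below). Only the keys shown come from this commit; the surrounding adapter/config structure and the usual LoRA fields are omitted, and the variable name is purely illustrative.

```python
# Minimal sketch (not a complete adapter config): only these keys are taken
# from the PEFT table above and from MixLoraConfig in this commit.
dynamic_mixlora_args = {
    "routing_strategy": "mixlora-dynamic",
    "num_experts": 8,
    "top_p": 0.8,        # default in from_config; check() requires 0 < top_p <= 1
    "temperature": 0.0,  # default in from_config; check() requires temperature >= 0
}
```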

## Supported Attention Methods

2 changes: 1 addition & 1 deletion mlora/model.py
@@ -388,7 +388,7 @@ def _prepare_inputs(
)

# prepare mask
-if input_args.batch_masks_ is not None and 1 in input_args.batch_masks_:
+if input_args.batch_masks_ is not None:
# 2d mask is passed through the layers
if isinstance(input_args.batch_masks_, torch.Tensor):
attention_mask = input_args.batch_masks_.to(
4 changes: 4 additions & 0 deletions mlora/modules/__init__.py
@@ -52,6 +52,8 @@

# MixLoRA MoEs
from .lora_moes import (
+DynamicRouterLoss,
+DynamicSparseMoe,
LoraMoe,
MixtralRouterLoss,
MixtralSparseMoe,
@@ -85,6 +87,8 @@
"Linear",
"MixtralRouterLoss",
"MixtralSparseMoe",
"DynamicRouterLoss",
"DynamicSparseMoe",
"SwitchRouterLoss",
"SwitchSparseMoe",
"LoraMoe",
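With the re-exports above, the new classes are importable directly from `mlora.modules`. A brief usage sketch, assuming an m-LoRA installation that already includes this commit:

```python
# Import the dynamic-routing classes re-exported by mlora/modules/__init__.py
# (assumes an m-LoRA installation that includes this commit).
from mlora.modules import DynamicRouterLoss, DynamicSparseMoe

print(DynamicRouterLoss.__name__, DynamicSparseMoe.__name__)
```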
19 changes: 18 additions & 1 deletion mlora/modules/config.py
@@ -202,7 +202,7 @@ def export(self) -> Dict[str, any]:
return config


-available_routing_strategies = ["mixlora", "mixlora-switch"]
+available_routing_strategies = ["mixlora", "mixlora-dynamic", "mixlora-switch"]


@dataclass
@@ -219,6 +219,9 @@ class MixLoraConfig(LoraConfig):
act_fn_: Optional[Union[str, torch.nn.Module]] = None
# mixtral config
top_k_: int = None
+# dynamic config
+top_p_: float = None
+temperature_: float = None
# switch transformers config
router_z_loss_coef_: float = None
expert_capacity_: int = None
@@ -248,6 +251,11 @@ def check(self) -> "MixLoraConfig":
)
if self.routing_strategy_ == "mixlora":
assert isinstance(self.top_k_, int) and self.top_k_ > 0
+elif self.routing_strategy_ == "mixlora-dynamic":
+assert (
+isinstance(self.top_p_, float) and self.top_p_ > 0 and self.top_p_ <= 1
+)
+assert isinstance(self.temperature_, float) and self.temperature_ >= 0
elif self.routing_strategy_ == "mixlora-switch":
assert (
isinstance(self.router_z_loss_coef_, float)
@@ -280,6 +288,11 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig":
lora_config.router_init_range_ = config.get("router_init_range", 0.02)
lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
lora_config.top_k_ = config.get("top_k", 2)
+elif lora_config.routing_strategy_ == "mixlora-dynamic":
+lora_config.router_init_range_ = config.get("router_init_range", 0.02)
+lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
+lora_config.top_p_ = config.get("top_p", 0.8)
+lora_config.temperature_ = config.get("temperature", 0.0)
elif lora_config.routing_strategy_ == "mixlora-switch":
lora_config.router_init_range_ = config.get("router_init_range", 1.0)
lora_config.jitter_noise_ = config.get("jitter_noise", 0.01)
@@ -308,6 +321,9 @@ def export(self) -> Dict[str, any]:
config["act_fn"] = self.act_fn_
if self.routing_strategy_ == "mixlora":
config["top_k"] = self.top_k_
+elif self.routing_strategy_ == "mixlora-dynamic":
+config["top_p"] = self.top_p_
+config["temperature"] = self.temperature_
elif self.routing_strategy_ == "mixlora-switch":
config["expert_capacity"] = self.expert_capacity_
config["sparse_step"] = self.sparse_step_
@@ -408,6 +424,7 @@ def expert_config(self, expert_idx: int) -> LoraConfig:

routing_strategy_dict = {
"mixlora": MixLoraConfig,
"mixlora-dynamic": MixLoraConfig,
"mixlora-switch": MixLoraConfig,
"loramoe": LoraMoeConfig,
"mola": MolaConfig,
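A hedged sketch of the constraints that `MixLoraConfig.check` now enforces for `"mixlora-dynamic"`, restated outside the class for clarity; the helper name below is hypothetical and is not part of the m-LoRA API.

```python
# Hypothetical helper mirroring the "mixlora-dynamic" asserts added to
# MixLoraConfig.check above; not part of the m-LoRA API.
def validate_dynamic_routing(top_p: float, temperature: float) -> None:
    # Per the check above: top_p must be a float in (0, 1].
    assert isinstance(top_p, float) and 0 < top_p <= 1
    # Per the check above: temperature must be a non-negative float.
    assert isinstance(temperature, float) and temperature >= 0


# The defaults that from_config falls back to when the keys are omitted.
validate_dynamic_routing(top_p=0.8, temperature=0.0)
```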
9 changes: 8 additions & 1 deletion mlora/modules/lora_moes.py
@@ -8,6 +8,8 @@
from .config import LoraMoeConfig, MixLoraConfig, MolaConfig
from .lora_linear import Linear
from .mix_lora import (
+DynamicRouterLoss,
+DynamicSparseMoe,
MixtralRouterLoss,
MixtralSparseMoe,
SwitchRouterLoss,
@@ -149,7 +151,11 @@ def forward(
return residual + final_hidden_states


-router_loss_dict = {"mixlora": MixtralRouterLoss, "mixlora-switch": SwitchRouterLoss}
+router_loss_dict = {
+"mixlora": MixtralRouterLoss,
+"mixlora-dynamic": DynamicRouterLoss,
+"mixlora-switch": SwitchRouterLoss,
+}


def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module:
@@ -163,6 +169,7 @@ def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module:

moe_layer_dict = {
"mixlora": MixtralSparseMoe,
"mixlora-dynamic": DynamicSparseMoe,
"mixlora-switch": SwitchSparseMoe,
"loramoe": LoraMoe,
"mola": MolaSparseMoe,
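The two factory dictionaries above are what map the new `"mixlora-dynamic"` strategy name to its implementation classes. A hedged lookup sketch, assuming the dictionaries stay module-level as shown and the installation includes this commit:

```python
# Resolve the classes registered for "mixlora-dynamic" via the module-level
# factory dictionaries shown above (assumes this commit is installed).
from mlora.modules.lora_moes import moe_layer_dict, router_loss_dict

strategy = "mixlora-dynamic"
print(router_loss_dict[strategy].__name__)  # expected: DynamicRouterLoss
print(moe_layer_dict[strategy].__name__)    # expected: DynamicSparseMoe
```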
(Diffs for the remaining 7 changed files are not shown here.)
