This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

improve gating algorithm
mikecovlee committed Aug 16, 2024
1 parent 5f12f95 commit de2237d
Showing 6 changed files with 92 additions and 38 deletions.
2 changes: 1 addition & 1 deletion mlora/model.py
@@ -388,7 +388,7 @@ def _prepare_inputs(
)

# prepare mask
-if input_args.batch_masks_ is not None and 1 in input_args.batch_masks_:
+if input_args.batch_masks_ is not None:
# 2d mask is passed through the layers
if isinstance(input_args.batch_masks_, torch.Tensor):
attention_mask = input_args.batch_masks_.to(
6 changes: 2 additions & 4 deletions mlora/modules/config.py
@@ -253,11 +253,9 @@ def check(self) -> "MixLoraConfig":
assert isinstance(self.top_k_, int) and self.top_k_ > 0
elif self.routing_strategy_ == "mixlora-dynamic":
assert (
-isinstance(self.top_p_, float)
-and self.top_p_ > 0.0
-and self.top_p_ <= 1.0
+isinstance(self.top_p_, float) and self.top_p_ > 0 and self.top_p_ <= 1
)
-assert isinstance(self.temperature_, float) and self.temperature_ >= 0.0
+assert isinstance(self.temperature_, float) and self.temperature_ >= 0
elif self.routing_strategy_ == "mixlora-switch":
assert (
isinstance(self.router_z_loss_coef_, float)
76 changes: 47 additions & 29 deletions mlora/modules/mix_lora.py
@@ -42,16 +42,21 @@ def _mixlora_compatible_forward(
return final_expert_states


+def _unpack_router_logits(gate_logits: List[torch.Tensor]):
+compute_device = gate_logits[0].device
+concatenated_gate_logits = torch.cat(
+[layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
+)
+return concatenated_gate_logits


def _mixtral_load_balancing_loss_func(
gate_logits: List[torch.Tensor],
num_experts: int,
top_k: int,
attention_mask: Optional[torch.Tensor] = None,
) -> float:
-compute_device = gate_logits[0].device
-concatenated_gate_logits = torch.cat(
-[layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
-)
+concatenated_gate_logits = _unpack_router_logits(gate_logits)

routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

@@ -67,9 +72,7 @@ def _mixtral_load_balancing_loss_func(
router_prob_per_expert = torch.mean(routing_weights, dim=0)
else:
batch_size, sequence_length = attention_mask.shape
-num_hidden_layers = concatenated_gate_logits.shape[0] // (
-batch_size * sequence_length
-)
+num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)

# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
expert_attention_mask = (
Expand All @@ -78,7 +81,7 @@ def _mixtral_load_balancing_loss_func(
(num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
)
.reshape(-1, top_k, num_experts)
-.to(compute_device)
+.to(routing_weights.device)
)

# Compute the percentage of tokens routed to each experts
Expand All @@ -91,7 +94,7 @@ def _mixtral_load_balancing_loss_func(
attention_mask[None, :, :, None]
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
.reshape(-1, num_experts)
-.to(compute_device)
+.to(routing_weights.device)
)

# Compute the average probability of routing to these experts
@@ -260,33 +263,40 @@ def forward(
return final_hidden_states, router_logits


-def _top_p(router_logits: torch.Tensor, p: float, temperature: float = 0.0):
+def _dynamic_top_p(router_logits: torch.Tensor, top_p: float, temperature: float = 0.0):
if temperature > 0.0:
router_logits = router_logits / temperature
sorted_logits, sorted_indices = torch.sort(router_logits, dim=-1, descending=True)
cumulative_probs = sorted_logits.cumsum(dim=-1)
-expert_index = cumulative_probs > p
-expert_index = expert_index.long().argmax(dim=-1)
-dynamic_top_k = max(expert_index.min(), 1)
-return sorted_logits[..., :dynamic_top_k], sorted_indices[..., :dynamic_top_k]
+expert_mask = cumulative_probs > top_p
+threshold_indices = expert_mask.long().argmax(dim=-1)
+threshold_mask = torch.nn.functional.one_hot(
+threshold_indices, num_classes=sorted_indices.size(-1)
+).bool()
+expert_mask = expert_mask & ~threshold_mask
+sorted_logits = sorted_logits.masked_fill(expert_mask, 0.0)
+sorted_indices = sorted_indices.masked_fill(expert_mask, -1)
+return sorted_logits, sorted_indices
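For a concrete sense of what the new selection does, a minimal standalone sketch follows; the dynamic_top_p_demo name, the toy probabilities, and the omission of the temperature scaling are illustrative assumptions, not part of the commit.

import torch

def dynamic_top_p_demo(router_probs: torch.Tensor, top_p: float):
    # Sort experts by routing probability, highest first, and accumulate the mass.
    sorted_probs, sorted_indices = torch.sort(router_probs, dim=-1, descending=True)
    cumulative_probs = sorted_probs.cumsum(dim=-1)
    # Mark every slot whose cumulative mass already exceeds top_p ...
    expert_mask = cumulative_probs > top_p
    # ... but keep the first slot that crosses the threshold, so the selection
    # always covers at least top_p of the mass and never goes empty.
    threshold_indices = expert_mask.long().argmax(dim=-1)
    threshold_mask = torch.nn.functional.one_hot(
        threshold_indices, num_classes=sorted_indices.size(-1)
    ).bool()
    expert_mask = expert_mask & ~threshold_mask
    # Dropped slots get weight 0.0 and expert id -1.
    return (
        sorted_probs.masked_fill(expert_mask, 0.0),
        sorted_indices.masked_fill(expert_mask, -1),
    )

probs = torch.tensor([[0.45, 0.30, 0.15, 0.10],
                      [0.70, 0.20, 0.06, 0.04]])
weights, experts = dynamic_top_p_demo(probs, top_p=0.8)
print(experts)  # tensor([[0, 1, 2, -1], [0, 1, -1, -1]]): 3 experts kept, then 2

Unlike a fixed top-k, the number of surviving experts varies per token with how peaked its routing distribution is.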


def _dynamic_load_balancing_loss_func(
-gate_logits: List[torch.Tensor],
+routing_weights: torch.Tensor,
num_experts: int,
top_p: float,
temperature: float,
) -> float:
-compute_device = gate_logits[0].device
-concatenated_gate_logits = torch.cat(
-[layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
-)
+_, selected_experts = _dynamic_top_p(routing_weights, top_p, temperature)

-routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+expert_mask = torch.empty(
+(num_experts, num_experts, routing_weights.size(0)),
+dtype=routing_weights.dtype,
+device=routing_weights.device,
+)

-_, selected_experts = _top_p(routing_weights, top_p, temperature)
+for expert_idx in range(num_experts):
+expert_mask[expert_idx] = (selected_experts == expert_idx).transpose(0, 1)

-expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
-expert_mask = expert_mask.permute(2, 1, 0)

# Compute the percentage of tokens routed to each experts
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
@@ -307,8 +317,13 @@ def __init__(self, config: MixLoraConfig) -> None:
self.temperature = config.temperature_

def forward(self, gate_logits, attention_mask) -> torch.Tensor:
+concatenated_gate_logits = _unpack_router_logits(gate_logits)
+routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
return self.aux_loss_coef * _dynamic_load_balancing_loss_func(
-gate_logits, self.experts, self.top_p, self.temperature
+routing_weights,
+self.experts,
+self.top_p,
+self.temperature,
)
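For orientation, here is a simplified Mixtral-style sketch of the balancing term this router loss computes; dynamic_aux_loss_demo and the toy tensors are illustrative assumptions, not a line-for-line copy of _dynamic_load_balancing_loss_func.

import torch

def dynamic_aux_loss_demo(routing_weights: torch.Tensor,
                          selected_experts: torch.Tensor,
                          num_experts: int) -> torch.Tensor:
    # routing_weights: (tokens, num_experts), softmax over the concatenated gate logits.
    # selected_experts: (tokens, num_experts), expert ids per sorted slot, -1 where dropped.
    # Fraction of tokens actually routed to each expert (-1 padding never matches a real id).
    hits = torch.stack(
        [(selected_experts == e).any(dim=-1).float() for e in range(num_experts)], dim=-1
    )
    tokens_per_expert = hits.mean(dim=0)
    # Average probability the router assigns to each expert.
    router_prob_per_expert = routing_weights.mean(dim=0)
    # The product grows when heavily used experts are also the most strongly preferred,
    # so minimizing it pushes the router toward a more even spread.
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

weights = torch.softmax(torch.randn(6, 4), dim=-1)    # 6 toy tokens, 4 experts
selected = torch.tensor([[0, 1, -1, -1]] * 6)          # toy selection: every token uses experts 0 and 1
print(dynamic_aux_loss_demo(weights, selected, num_experts=4))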


@@ -398,7 +413,7 @@ def forward(
router_logits = self.gate_(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=self.dtype_)
-routing_weights, selected_experts = _top_p(
+routing_weights, selected_experts = _dynamic_top_p(
routing_weights, self.top_p_, self.temperature_
)

@@ -410,11 +425,14 @@ def forward(
device=hidden_states.device,
)

-# One hot encode the selected experts to create an expert mask
-# this will be used to easily index which expert is going to be sollicitated
-expert_mask = torch.nn.functional.one_hot(
-selected_experts, num_classes=self.experts_
-).permute(2, 1, 0)
+expert_mask = torch.empty(
+(self.experts_, self.experts_, batch_size * sequence_length),
+dtype=self.dtype_,
+device=hidden_states.device,
+)

+for expert_idx in range(self.experts_):
+expert_mask[expert_idx] = (selected_experts == expert_idx).transpose(0, 1)

# Perform the computation on each expert
if input_args.efficient_operator_ and hasattr(ffn_layer, "_mixlora_forward"):
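The same loop-built mask replaces the one-hot/permute construction here. Below is a small toy sketch of the shapes involved; the concrete values are illustrative, and the stated motivation (one_hot cannot digest the -1 padding that _dynamic_top_p now leaves in selected_experts) is an assumption rather than something the commit spells out.

import torch

num_experts = 4
# Toy routing result for 3 tokens: expert ids per sorted slot, -1 where top-p dropped the slot.
selected_experts = torch.tensor([[2, 0, -1, -1],
                                 [1, -1, -1, -1],
                                 [3, 2, 1, -1]])

# torch.nn.functional.one_hot(selected_experts, num_experts) would raise here,
# since one_hot requires non-negative class ids, so the mask is built expert by expert.
expert_mask = torch.empty(
    (num_experts, num_experts, selected_experts.size(0)), dtype=torch.float32
)
for expert_idx in range(num_experts):
    # Slice expert_idx marks, per sorted slot and per token, whether that expert was chosen.
    expert_mask[expert_idx] = (selected_experts == expert_idx).transpose(0, 1)

print(expert_mask[2])  # expert 2 serves token 0 in slot 0 and token 2 in slot 1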
7 changes: 4 additions & 3 deletions mlora/tokenizer.py
@@ -48,9 +48,10 @@ def encode(
def decode(self, data: Tokens) -> str:
return self.tokenizer.decode(data)

-# get the mask from tokens
+# Get the mask from tokens
# https://huggingface.co/docs/transformers/glossary#attention-mask
# example: tokens is [2, 3, pad, pad, 4, 5]
-# output is [0, 0, 1, 1, 0, 0]
+# output is [1, 1, 0, 0, 1, 1]
def mask_from(self, tokens: Tokens) -> Masks:
mask_tokens = [self.pad_id_]
-return [int(tok in mask_tokens) for tok in tokens]
+return [int(tok not in mask_tokens) for tok in tokens]
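A quick self-contained check of the corrected convention (the pad id below is made up): real tokens now map to 1 and padding to 0, matching the Hugging Face attention-mask convention linked above.

PAD_ID = 0  # illustrative pad token id

def mask_from(tokens, pad_id=PAD_ID):
    # 1 marks a real token, 0 marks padding.
    return [int(tok != pad_id) for tok in tokens]

print(mask_from([2, 3, PAD_ID, PAD_ID, 4, 5]))  # [1, 1, 0, 0, 1, 1]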
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "mlora"
version = "0.5.1"
version = "0.5.2"
description = "An Efficient Factory to Build Multiple LoRA Adapters"
readme = "README.md"
requires-python = ">=3.8"
37 changes: 37 additions & 0 deletions templates/mixlora_dynamic.json
@@ -0,0 +1,37 @@
{
"cutoff_len": 512,
"save_step": 1000,
"train_lora_candidate_num": 2,
"train_lora_simultaneously_num": 2,
"train_strategy": "optim",
"lora": [
{
"name": "mixlora",
"task_name": "",
"optim": "adamw",
"scheduler_type": "constant",
"warmup_steps": 0,
"lr": 2e-4,
"batch_size": 16,
"micro_batch_size": 8,
"evaluate_batch_size": 16,
"num_epochs": 2,
"r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"target_modules": {
"q_proj": true,
"k_proj": true,
"v_proj": true,
"o_proj": true,
"gate_proj": true,
"down_proj": true,
"up_proj": true
},
"routing_strategy": "mixlora-dynamic",
"num_experts": 8,
"top_p": 0.8,
"group_by_length": false
}
]
}
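To tie the new template back to the MixLoraConfig.check() assertions shown above, here is a small sketch that loads it and re-applies the top_p bound; the script itself (and running it from the repository root) is an illustrative assumption, not part of the commit.

import json

# Illustrative: path relative to the repository root.
with open("templates/mixlora_dynamic.json") as fp:
    template = json.load(fp)

for adapter in template["lora"]:
    if adapter["routing_strategy"] == "mixlora-dynamic":
        top_p = adapter["top_p"]
        # Same bound that MixLoraConfig.check() asserts: 0 < top_p <= 1.
        assert isinstance(top_p, float) and 0 < top_p <= 1
        print(f'{adapter["name"]}: dynamic routing over '
              f'{adapter["num_experts"]} experts with top_p={top_p}')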
