From 54097b98ef2ff1f5a0ab34fe2bb31de9063312f4 Mon Sep 17 00:00:00 2001 From: irexyc Date: Wed, 11 Dec 2024 21:22:15 +0000 Subject: [PATCH] support turbomind ep --- lmdeploy/messages.py | 1 + lmdeploy/turbomind/deploy/config.py | 1 + lmdeploy/turbomind/deploy/module.py | 12 ++++---- src/turbomind/kernels/gemm/moe_utils_v2.cu | 28 ++++++++++++++++++- src/turbomind/kernels/gemm/moe_utils_v2.h | 3 ++ .../kernels/gemm/test/test_moe_utils.cu | 1 + .../models/llama/LlamaDecoderLayerWeight.cc | 19 ++++++++++--- src/turbomind/models/llama/LlamaDenseWeight.h | 8 ++++-- src/turbomind/models/llama/llama_params.h | 1 + src/turbomind/models/llama/moe_ffn_layer.cc | 24 +++++++++++----- src/turbomind/models/llama/moe_ffn_layer.h | 4 ++- .../triton_backend/llama/LlamaTritonModel.cc | 1 + 12 files changed, 83 insertions(+), 20 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 90823598ea..803d600cc3 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -196,6 +196,7 @@ class TurbomindEngineConfig: cache_chunk_size: int = -1 cache_block_seq_len: int = 64 enable_prefix_caching: bool = False + enable_ep: bool = False quant_policy: int = 0 rope_scaling_factor: float = 0.0 use_logn_attn: bool = False diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py index e483500e96..6fc7796fa4 100644 --- a/lmdeploy/turbomind/deploy/config.py +++ b/lmdeploy/turbomind/deploy/config.py @@ -55,6 +55,7 @@ class ModelConfig: session_len: int = None tp: int = 1 model_format: str = 'hf' + enable_ep: bool = False expert_num: List[int] = () expert_inter_size: int = 0 experts_per_token: int = 0 diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py index 52497175ef..8628ba7497 100644 --- a/lmdeploy/turbomind/deploy/module.py +++ b/lmdeploy/turbomind/deploy/module.py @@ -108,7 +108,8 @@ def _export(self, w123, kind: str, pack_fn, - apply_gs=False): + apply_gs=False, + enable_ep=False): is_lora_a, is_lora_b = get_lora_flags(kind) w1, w2, w3 = map(transpose, w123) @@ -122,15 +123,15 @@ def _export(self, w1, w2, w3 = map(pack_fn, (w1, w2, w3)) self.model.save_split(w1, fmt.format(idx, 'w1', kind), - split_dim=-1, + split_dim=-1 if not enable_ep else None, copy=is_lora_a) self.model.save_split(w3, fmt.format(idx, 'w3', kind), - split_dim=-1, + split_dim=-1 if not enable_ep else None, copy=is_lora_a) self.model.save_split(w2, fmt.format(idx, 'w2', kind), - split_dim=0, + split_dim=0 if not enable_ep else None, copy=is_lora_b) def apply(self, i: int, r: BaseReader): @@ -163,7 +164,8 @@ def apply(self, i: int, r: BaseReader): for p in get_params(r.moe_ffn_expert()): for e in range(self.expert_num[i]): fmt = self._moe_ffn_expert.replace('E', str(e)) - p(partial(self._export, self.inter_size, fmt), + # TODO: pass enable_ep + p(partial(self._export, self.inter_size, fmt, enable_ep=True), partial(r.moe_ffn_expert, e, i), i) gate = transpose(r.moe_ffn_gate(i)) diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.cu b/src/turbomind/kernels/gemm/moe_utils_v2.cu index a9e4f7da51..6b158c9cd4 100644 --- a/src/turbomind/kernels/gemm/moe_utils_v2.cu +++ b/src/turbomind/kernels/gemm/moe_utils_v2.cu @@ -265,6 +265,7 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n] int expert_num, int top_k, bool norm_topk, + int2 expert_range, float routed_scale) { constexpr int max_tiles = kMoeGateMaxTiles; @@ -537,8 +538,13 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n] const float scale = smem.shared_scales[idx][bti2]; if (ti2 < token_num && idx < 
top_k) { + if (expert_id >= expert_range.x && expert_id < expert_range.y) { + scales[idx * token_num + ti2] = scale * routed_scale; + } + else { + scales[idx * token_num + ti2] = 0; + } masks[expert_id * token_num_padded + ti2] = idx; - scales[idx * token_num + ti2] = scale * routed_scale; atomicAdd(&smem.shared_accum[ti2 >> log_tile][expert_id], 1); } } @@ -570,6 +576,7 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n int experts_per_token, bool norm_topk, float routed_scale, + int2 expert_range, cudaStream_t st) { constexpr int base_log_tile = 9; @@ -602,6 +609,7 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n experts, experts_per_token, norm_topk, + expert_range, routed_scale); }; @@ -962,4 +970,22 @@ void invokeMaskMoeTopKGroups(float* logits, int token_num, int expert_num, int g std::abort(); } +__global__ void moveOffsets(int* offsets, int expert, int2 expert_range) +{ + int thread_id = threadIdx.x; + int local_expert_num = expert_range.y - expert_range.x; + for (int i = threadIdx.x; i <= local_expert_num; i += blockDim.x) { + offsets[i] = offsets[expert_range.x + i]; + } + __syncthreads(); + for (int i = threadIdx.x + local_expert_num; i < expert; i += blockDim.x) { + offsets[i + 1] = offsets[local_expert_num]; + } +} + +void invokeMoveOffsets(int* offsets, int expert, int2 expert_range, cudaStream_t st) +{ + moveOffsets<<<1, 32, 0, st>>>(offsets, expert, expert_range); +} + } // namespace turbomind diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.h b/src/turbomind/kernels/gemm/moe_utils_v2.h index d53de1354e..706b3920bf 100644 --- a/src/turbomind/kernels/gemm/moe_utils_v2.h +++ b/src/turbomind/kernels/gemm/moe_utils_v2.h @@ -23,6 +23,7 @@ void invokeMoeGate_V2(int* f2n, int exp_per_tok, bool norm_topk, float routed_scale, + int2 expert_range, cudaStream_t st); template @@ -65,4 +66,6 @@ std::vector SampleUniform(int token_num, int expert_num, int exp_per_tok, s std::vector SampleBalanced(int token_num, int expert_num, int exp_per_tok, std::mt19937& g); +void invokeMoveOffsets(int* offsets, int expert, int2 expert_range, cudaStream_t st); + } // namespace turbomind diff --git a/src/turbomind/kernels/gemm/test/test_moe_utils.cu b/src/turbomind/kernels/gemm/test/test_moe_utils.cu index 4b2ea6a83a..873d17211d 100644 --- a/src/turbomind/kernels/gemm/test/test_moe_utils.cu +++ b/src/turbomind/kernels/gemm/test/test_moe_utils.cu @@ -224,6 +224,7 @@ bool test_moe_gate(int tokens, // experts_per_token, false, 1.f, + {0, expert_num}, // expert_offset nullptr); } diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index 0a2a3be175..f2989aa132 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -397,11 +397,22 @@ TensorMap LlamaDecoderLayerWeight::getParams(std::string prefix) concat(prefix, "moe_ffn.gate.weight"), Tensor{MEMORY_GPU, getTensorType(), {moe_weights.gate.kernel_size()}, moe_weights.gate.kernel}); auto& experts = moe_weights.experts; + + auto moe_prefix = [=](const std::string name, int moe_tp) { + return moe_tp == -1 ? name : concat(name, moe_tp); + }; + + int moe_start_id = 0; + int moe_tp = tensor_para_rank_; + if (moe_weights.enable_ep) { + moe_start_id = experts.size() * tensor_para_rank_; + moe_tp = -1; + } for (size_t i = 0; i < experts.size(); ++i) { - const std::string name = "moe_ffn.experts." 
+ std::to_string(i); - getWeightTensor(experts[i].gating, false, get_prefix(concat(name, "w1")), output); - getWeightTensor(experts[i].intermediate, false, get_prefix(concat(name, "w3")), output); - getWeightTensor(experts[i].output, false, get_prefix(concat(name, "w2")), output); + const std::string name = "moe_ffn.experts." + std::to_string(moe_start_id + i); + getWeightTensor(experts[i].gating, false, moe_prefix(concat(prefix, name, "w1"), moe_tp), output); + getWeightTensor(experts[i].intermediate, false, moe_prefix(concat(prefix, name, "w3"), moe_tp), output); + getWeightTensor(experts[i].output, false, moe_prefix(concat(prefix, name, "w2"), moe_tp), output); } if (moe_weights.shared_gate.kernel) { output.insert(concat(prefix, "moe_ffn.shared_gate.weight"), diff --git a/src/turbomind/models/llama/LlamaDenseWeight.h b/src/turbomind/models/llama/LlamaDenseWeight.h index 944781bf5d..9ef9bee257 100644 --- a/src/turbomind/models/llama/LlamaDenseWeight.h +++ b/src/turbomind/models/llama/LlamaDenseWeight.h @@ -290,6 +290,8 @@ struct MoeFfnWeight { return; } + enable_ep = param.enable_ep; + // printf("%d %d %d\n", (int)hidden_dim, (int)param.inter_size, (int)expert_num); gate.input_dims = hidden_dim; @@ -297,14 +299,15 @@ struct MoeFfnWeight { gate.type = get_default_weight_type(); gate.group_size = group_size; - experts.resize(expert_num); + experts.resize(enable_ep ? expert_num / tp : expert_num); method = param.method; fuse_silu_act = fuse_silu_act && method == MoeParam::kFused; for (auto& e : experts) { // inter size is divided by tp in `FfnWeight` - e = LlamaFfnWeight{hidden_dim, (size_t)param.inter_size, tp, weight_type, group_size, fuse_silu_act}; + size_t divide = enable_ep ? 1 : tp; + e = LlamaFfnWeight{hidden_dim, (size_t)param.inter_size, divide, weight_type, group_size, fuse_silu_act}; } if (param.shared_gate) { @@ -339,6 +342,7 @@ struct MoeFfnWeight { block.free(st); } + bool enable_ep; LlamaDenseWeight gate; std::vector> experts; diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 0a505b11a9..5b536467b4 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -46,6 +46,7 @@ struct MoeParam { kFused } method; + bool enable_ep; int experts_per_token; int inter_size; bool norm_topk_prob; diff --git a/src/turbomind/models/llama/moe_ffn_layer.cc b/src/turbomind/models/llama/moe_ffn_layer.cc index 390d147540..76762268cc 100644 --- a/src/turbomind/models/llama/moe_ffn_layer.cc +++ b/src/turbomind/models/llama/moe_ffn_layer.cc @@ -80,7 +80,8 @@ template void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id, const MoeFfnWeight& moe) { const size_t padded = (tokens + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize; - const int expert_num = moe.experts.size(); + const int expert_num = param_.enable_ep ? 
moe.experts.size() * tensor_para_.world_size_ : moe.experts.size(); + const int local_expert_num = moe.experts.size(); FT_CHECK(expert_num); @@ -115,6 +116,11 @@ void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id sync_check_cuda_error(); } + expert_range_ = {0, local_expert_num}; + if (param_.enable_ep) { + expert_range_ = {tensor_para_.rank_ * local_expert_num, (1 + tensor_para_.rank_) * local_expert_num}; + } + /// TODO: fix illegal memory access even if NaN are present in logits invokeMoeGate_V2(f2n_, en2f_, @@ -129,6 +135,7 @@ void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id param_.experts_per_token, param_.norm_topk_prob, param_.routed_scale, + expert_range_, stream_); sync_check_cuda_error(); @@ -147,22 +154,25 @@ void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id cudaMemcpyAsync(offsets_, h_offsets_, sizeof(int) * (expert_num + 1), cudaMemcpyDefault, stream_)); } + if (param_.enable_ep) { + invokeMoveOffsets(offsets_, expert_num, expert_range_, stream_); + } + if (param_.method == MoeParam::kNaive) { dispatchMoeGather(inout_buf_, input, f2n_, tokens, param_.experts_per_token, hidden_dim_, stream_); sync_check_cuda_error(); check_cuda_error( - cudaMemcpyAsync(h_offsets_, offsets_, sizeof(int) * (expert_num + 1), cudaMemcpyDefault, stream_)); + cudaMemcpyAsync(h_offsets_, offsets_, sizeof(int) * (local_expert_num + 1), cudaMemcpyDefault, stream_)); check_cuda_error(cudaStreamSynchronize(stream_)); - if (h_offsets_[expert_num] != tokens * param_.experts_per_token) { - FT_CHECK_WITH_INFO(0, fmtstr("%d vs %d", h_offsets_[expert_num], tokens * param_.experts_per_token)); + if (!param_.enable_ep && h_offsets_[local_expert_num] != tokens * param_.experts_per_token) { + FT_CHECK_WITH_INFO(0, fmtstr("%d vs %d", h_offsets_[local_expert_num], tokens * param_.experts_per_token)); } - for (int i = 0; i < expert_num; ++i) { - + for (int i = 0; i < local_expert_num; ++i) { FT_CHECK(moe.experts[i].is_fused_silu == false); if (size_t count = h_offsets_[i + 1] - h_offsets_[i]) { @@ -177,7 +187,7 @@ void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id } } else { - context_->update(expert_num, param_.experts_per_token, offsets_); + context_->update(local_expert_num, param_.experts_per_token, offsets_); auto& block = moe.block; diff --git a/src/turbomind/models/llama/moe_ffn_layer.h b/src/turbomind/models/llama/moe_ffn_layer.h index 74c62d004b..9db1758592 100644 --- a/src/turbomind/models/llama/moe_ffn_layer.h +++ b/src/turbomind/models/llama/moe_ffn_layer.h @@ -17,7 +17,7 @@ template class MoeFfnLayer { public: MoeFfnLayer(ModelParam model, const MoeParam& param, const NcclParam& tp, const Context& ctx): - inter_size_(param.inter_size / tp.world_size_), + inter_size_(param.enable_ep ? 
param.inter_size : param.inter_size / tp.world_size_), hidden_dim_(model.hidden_units), param_(param), dtype_(getTensorType()), @@ -93,6 +93,8 @@ class MoeFfnLayer { int* accum_{}; int* offsets_{}; + + int2 expert_range_{}; }; } // namespace turbomind diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 1c7c5eb468..0c7c8d4e6c 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -314,6 +314,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, lora_param_.scale_pattern = getLoraPattern(lora_reader["lora_scale_pattern"].as(""), [](const std::string& s) { return std::stof(s); }); + moe_param_.enable_ep = model_reader["enable_ep"].as(false); moe_param_.experts_per_token = model_reader["experts_per_token"].as(0); moe_param_.inter_size = model_reader["expert_inter_size"].as(0); moe_param_.shared_gate = model_reader["moe_shared_gate"].as();
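A minimal usage sketch of the new flag, assuming `enable_ep` is forwarded from
`TurbomindEngineConfig` through the converter and engine like the other fields added in
this patch (the model path below is a placeholder):

    from lmdeploy import pipeline, TurbomindEngineConfig

    # With enable_ep=True each rank keeps expert_num / tp whole experts
    # (expert weights are no longer split along inter_size), and the MoE
    # gate only routes to experts inside this rank's expert_range.
    backend_config = TurbomindEngineConfig(tp=2, enable_ep=True)
    pipe = pipeline('path/to/moe-model', backend_config=backend_config)
    print(pipe(['hello']))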