support turbomind ep
irexyc committed Dec 11, 2024
1 parent 01f82e0 commit 54097b9
Showing 12 changed files with 83 additions and 20 deletions.
1 change: 1 addition & 0 deletions lmdeploy/messages.py
@@ -196,6 +196,7 @@ class TurbomindEngineConfig:
cache_chunk_size: int = -1
cache_block_seq_len: int = 64
enable_prefix_caching: bool = False
+ enable_ep: bool = False
quant_policy: int = 0
rope_scaling_factor: float = 0.0
use_logn_attn: bool = False
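For context, the new flag is surfaced on TurbomindEngineConfig, so expert parallelism would be switched on from the Python API roughly as in the sketch below (a minimal sketch; the pipeline call and model path are illustrative and not part of this commit):

    from lmdeploy import pipeline, TurbomindEngineConfig

    # Hypothetical usage: with tp=2 and enable_ep=True, each rank is expected to
    # hold expert_num / tp whole experts instead of a slice of every expert.
    backend_config = TurbomindEngineConfig(tp=2, enable_ep=True)
    pipe = pipeline('Qwen/Qwen1.5-MoE-A2.7B-Chat', backend_config=backend_config)
    print(pipe(['Hello, world']))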
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/config.py
@@ -55,6 +55,7 @@ class ModelConfig:
session_len: int = None
tp: int = 1
model_format: str = 'hf'
+ enable_ep: bool = False
expert_num: List[int] = ()
expert_inter_size: int = 0
experts_per_token: int = 0
12 changes: 7 additions & 5 deletions lmdeploy/turbomind/deploy/module.py
@@ -108,7 +108,8 @@ def _export(self,
w123,
kind: str,
pack_fn,
- apply_gs=False):
+ apply_gs=False,
+ enable_ep=False):
is_lora_a, is_lora_b = get_lora_flags(kind)
w1, w2, w3 = map(transpose, w123)

@@ -122,15 +123,15 @@ def _export(self,
w1, w2, w3 = map(pack_fn, (w1, w2, w3))
self.model.save_split(w1,
fmt.format(idx, 'w1', kind),
- split_dim=-1,
+ split_dim=-1 if not enable_ep else None,
copy=is_lora_a)
self.model.save_split(w3,
fmt.format(idx, 'w3', kind),
- split_dim=-1,
+ split_dim=-1 if not enable_ep else None,
copy=is_lora_a)
self.model.save_split(w2,
fmt.format(idx, 'w2', kind),
- split_dim=0,
+ split_dim=0 if not enable_ep else None,
copy=is_lora_b)

def apply(self, i: int, r: BaseReader):
@@ -163,7 +164,8 @@ def apply(self, i: int, r: BaseReader):
for p in get_params(r.moe_ffn_expert()):
for e in range(self.expert_num[i]):
fmt = self._moe_ffn_expert.replace('E', str(e))
- p(partial(self._export, self.inter_size, fmt),
+ # TODO: pass enable_ep
+ p(partial(self._export, self.inter_size, fmt, enable_ep=True),
partial(r.moe_ffn_expert, e, i), i)

gate = transpose(r.moe_ffn_gate(i))
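The split_dim change captures the export-time difference between the two parallel modes: with tensor parallelism w1/w3 are sharded along the output dimension and w2 along the input dimension, while with expert parallelism each rank owns whole experts, so save_split is told split_dim=None and writes the tensor unsharded. A standalone sketch of that decision, assuming NumPy arrays and an illustrative helper name:

    import numpy as np

    def shard_expert_weight(w: np.ndarray, kind: str, enable_ep: bool, tp: int):
        # Illustrative stand-in for the split performed by save_split in _export.
        if enable_ep:
            # Expert parallelism: the expert stays whole on its owning rank.
            return [w]
        # Tensor parallelism: w1/w3 split by columns, w2 split by rows.
        axis = 0 if kind == 'w2' else -1
        return np.split(w, tp, axis=axis)

    w1 = np.zeros((1024, 2816))  # [hidden_dim, expert_inter_size], made-up shapes
    print(len(shard_expert_weight(w1, 'w1', enable_ep=False, tp=2)))  # 2 shards
    print(len(shard_expert_weight(w1, 'w1', enable_ep=True, tp=2)))   # 1 whole tensor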
28 changes: 27 additions & 1 deletion src/turbomind/kernels/gemm/moe_utils_v2.cu
@@ -265,6 +265,7 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n]
int expert_num,
int top_k,
bool norm_topk,
+ int2 expert_range,
float routed_scale)
{
constexpr int max_tiles = kMoeGateMaxTiles;
@@ -537,8 +538,13 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n]
const float scale = smem.shared_scales[idx][bti2];

if (ti2 < token_num && idx < top_k) {
+ if (expert_id >= expert_range.x && expert_id < expert_range.y) {
+ scales[idx * token_num + ti2] = scale * routed_scale;
+ }
+ else {
+ scales[idx * token_num + ti2] = 0;
+ }
masks[expert_id * token_num_padded + ti2] = idx;
- scales[idx * token_num + ti2] = scale * routed_scale;
atomicAdd(&smem.shared_accum[ti2 >> log_tile][expert_id], 1);
}
}
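The gate kernel now zeroes the routing scale of any expert that falls outside this rank's half-open window [expert_range.x, expert_range.y), while the mask and the per-expert accumulator are still updated for every expert, so the global token-to-expert assignment stays intact but remote experts contribute nothing on this rank. A CPU reference of that masking, as a rough sketch in plain Python:

    def gate_with_expert_range(top_ids, top_scales, expert_range, routed_scale=1.0):
        # top_ids / top_scales: per-token top-k expert ids and gate scales.
        # Mirrors the new branch in MoeGateKernel_v8: non-local experts get scale 0.
        lo, hi = expert_range
        return [[s * routed_scale if lo <= e < hi else 0.0
                 for e, s in zip(ids, scales)]
                for ids, scales in zip(top_ids, top_scales)]

    # Two tokens, top-2 routing, 8 experts; rank 1 of 2 owns experts [4, 8).
    print(gate_with_expert_range([[1, 5], [4, 7]], [[0.6, 0.4], [0.7, 0.3]], (4, 8)))
    # -> [[0.0, 0.4], [0.7, 0.3]]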
@@ -570,6 +576,7 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n
int experts_per_token,
bool norm_topk,
float routed_scale,
+ int2 expert_range,
cudaStream_t st)
{
constexpr int base_log_tile = 9;
@@ -602,6 +609,7 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n
experts,
experts_per_token,
norm_topk,
+ expert_range,
routed_scale);
};

@@ -962,4 +970,22 @@ void invokeMaskMoeTopKGroups(float* logits, int token_num, int expert_num, int g
std::abort();
}

+ __global__ void moveOffsets(int* offsets, int expert, int2 expert_range)
+ {
+ int thread_id = threadIdx.x;
+ int local_expert_num = expert_range.y - expert_range.x;
+ for (int i = threadIdx.x; i <= local_expert_num; i += blockDim.x) {
+ offsets[i] = offsets[expert_range.x + i];
+ }
+ __syncthreads();
+ for (int i = threadIdx.x + local_expert_num; i < expert; i += blockDim.x) {
+ offsets[i + 1] = offsets[local_expert_num];
+ }
+ }

+ void invokeMoveOffsets(int* offsets, int expert, int2 expert_range, cudaStream_t st)
+ {
+ moveOffsets<<<1, 32, 0, st>>>(offsets, expert, expert_range);
+ }

} // namespace turbomind
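moveOffsets rewrites the offsets array (an exclusive scan over all expert_num experts) so that its first local_expert_num + 1 entries describe this rank's expert window and the remaining entries are padded with the window's last value; downstream code can then index offsets by local expert id. A CPU reference of the same remapping, as a sketch (reading from a copy where the kernel relies on its launch shape):

    def move_offsets(offsets, expert_num, expert_range):
        # offsets has expert_num + 1 entries; expert_range is [lo, hi) for this rank.
        lo, hi = expert_range
        local = hi - lo
        src = list(offsets)
        for i in range(local + 1):            # shift the local window to the front
            offsets[i] = src[lo + i]
        for i in range(local, expert_num):    # pad the tail with the window's last value
            offsets[i + 1] = offsets[local]
        return offsets

    # 8 experts, rank 1 of 2 owns experts [4, 8).
    print(move_offsets([0, 2, 2, 3, 5, 5, 8, 9, 12], 8, (4, 8)))
    # -> [5, 5, 8, 9, 12, 12, 12, 12, 12]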
3 changes: 3 additions & 0 deletions src/turbomind/kernels/gemm/moe_utils_v2.h
@@ -23,6 +23,7 @@ void invokeMoeGate_V2(int* f2n,
int exp_per_tok,
bool norm_topk,
float routed_scale,
+ int2 expert_range,
cudaStream_t st);

template<class T>
@@ -65,4 +66,6 @@ std::vector<int> SampleUniform(int token_num, int expert_num, int exp_per_tok, s

std::vector<int> SampleBalanced(int token_num, int expert_num, int exp_per_tok, std::mt19937& g);

+ void invokeMoveOffsets(int* offsets, int expert, int2 expert_range, cudaStream_t st);

} // namespace turbomind
1 change: 1 addition & 0 deletions src/turbomind/kernels/gemm/test/test_moe_utils.cu
@@ -224,6 +224,7 @@ bool test_moe_gate(int tokens, //
experts_per_token,
false,
1.f,
+ {0, expert_num}, // expert_offset
nullptr);
}

19 changes: 15 additions & 4 deletions src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -397,11 +397,22 @@ TensorMap LlamaDecoderLayerWeight<T>::getParams(std::string prefix)
concat(prefix, "moe_ffn.gate.weight"),
Tensor{MEMORY_GPU, getTensorType<T>(), {moe_weights.gate.kernel_size()}, moe_weights.gate.kernel});
auto& experts = moe_weights.experts;

+ auto moe_prefix = [=](const std::string name, int moe_tp) {
+ return moe_tp == -1 ? name : concat(name, moe_tp);
+ };

+ int moe_start_id = 0;
+ int moe_tp = tensor_para_rank_;
+ if (moe_weights.enable_ep) {
+ moe_start_id = experts.size() * tensor_para_rank_;
+ moe_tp = -1;
+ }
for (size_t i = 0; i < experts.size(); ++i) {
- const std::string name = "moe_ffn.experts." + std::to_string(i);
- getWeightTensor(experts[i].gating, false, get_prefix(concat(name, "w1")), output);
- getWeightTensor(experts[i].intermediate, false, get_prefix(concat(name, "w3")), output);
- getWeightTensor(experts[i].output, false, get_prefix(concat(name, "w2")), output);
+ const std::string name = "moe_ffn.experts." + std::to_string(moe_start_id + i);
+ getWeightTensor(experts[i].gating, false, moe_prefix(concat(prefix, name, "w1"), moe_tp), output);
+ getWeightTensor(experts[i].intermediate, false, moe_prefix(concat(prefix, name, "w3"), moe_tp), output);
+ getWeightTensor(experts[i].output, false, moe_prefix(concat(prefix, name, "w2"), moe_tp), output);
}
if (moe_weights.shared_gate.kernel) {
output.insert(concat(prefix, "moe_ffn.shared_gate.weight"),
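Under expert parallelism the exported weight map changes in two ways: the expert index in a tensor name becomes a global id offset by rank * local_expert_num, and the per-tensor TP rank suffix is dropped because each expert is stored whole. A sketch of the resulting names; the exact name pattern and helper are illustrative, only the offset/suffix logic mirrors getParams():

    def expert_weight_names(prefix, expert_num, tp, rank, enable_ep):
        # Hypothetical name builder: global expert ids and no rank suffix under EP.
        local = expert_num // tp if enable_ep else expert_num
        start = local * rank if enable_ep else 0
        names = []
        for i in range(local):
            for w in ('w1', 'w3', 'w2'):
                name = f'{prefix}.moe_ffn.experts.{start + i}.{w}'
                names.append(name if enable_ep else f'{name}.{rank}')
        return names

    # 4 experts, tp=2, rank 1: EP exports experts 2..3 without a rank suffix.
    print(expert_weight_names('layers.0', 4, tp=2, rank=1, enable_ep=True)[:3])
    # -> ['layers.0.moe_ffn.experts.2.w1', 'layers.0.moe_ffn.experts.2.w3',
    #     'layers.0.moe_ffn.experts.2.w2']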
8 changes: 6 additions & 2 deletions src/turbomind/models/llama/LlamaDenseWeight.h
@@ -290,21 +290,24 @@ struct MoeFfnWeight {
return;
}

+ enable_ep = param.enable_ep;

// printf("%d %d %d\n", (int)hidden_dim, (int)param.inter_size, (int)expert_num);

gate.input_dims = hidden_dim;
gate.output_dims = expert_num;
gate.type = get_default_weight_type<T>();
gate.group_size = group_size;

- experts.resize(expert_num);
+ experts.resize(enable_ep ? expert_num / tp : expert_num);

method = param.method;
fuse_silu_act = fuse_silu_act && method == MoeParam::kFused;

for (auto& e : experts) {
// inter size is divided by tp in `FfnWeight`
- e = LlamaFfnWeight<T>{hidden_dim, (size_t)param.inter_size, tp, weight_type, group_size, fuse_silu_act};
+ size_t divide = enable_ep ? 1 : tp;
+ e = LlamaFfnWeight<T>{hidden_dim, (size_t)param.inter_size, divide, weight_type, group_size, fuse_silu_act};
}

if (param.shared_gate) {
@@ -339,6 +342,7 @@ struct MoeFfnWeight {
block.free(st);
}

+ bool enable_ep;
LlamaDenseWeight<T> gate;
std::vector<LlamaFfnWeight<T>> experts;

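This sizing change is the core memory trade-off of EP: with tensor parallelism every rank allocates all expert_num experts, each sliced to inter_size / tp, while with expert parallelism every rank allocates expert_num / tp whole experts at the full inter_size, so per-rank expert memory stays roughly the same but the partitioning differs. A quick sketch of the two layouts:

    def expert_layout(expert_num, inter_size, tp, enable_ep):
        # Per-rank expert count and per-expert intermediate size, mirroring MoeFfnWeight.
        if enable_ep:
            return expert_num // tp, inter_size   # fewer experts, each kept whole
        return expert_num, inter_size // tp       # all experts, each sliced by tp

    for ep in (False, True):
        n, size = expert_layout(expert_num=64, inter_size=1408, tp=2, enable_ep=ep)
        print(f'enable_ep={ep}: {n} experts per rank, inter_size {size} each')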
1 change: 1 addition & 0 deletions src/turbomind/models/llama/llama_params.h
@@ -46,6 +46,7 @@ struct MoeParam {
kFused
} method;

+ bool enable_ep;
int experts_per_token;
int inter_size;
bool norm_topk_prob;
24 changes: 17 additions & 7 deletions src/turbomind/models/llama/moe_ffn_layer.cc
@@ -80,7 +80,8 @@ template<class T>
void MoeFfnLayer<T>::forward(T* output, const T* input, int tokens, int layer_id, const MoeFfnWeight<T>& moe)
{
const size_t padded = (tokens + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize;
- const int expert_num = moe.experts.size();
+ const int expert_num = param_.enable_ep ? moe.experts.size() * tensor_para_.world_size_ : moe.experts.size();
+ const int local_expert_num = moe.experts.size();

FT_CHECK(expert_num);

@@ -115,6 +116,11 @@ void MoeFfnLayer<T>::forward(T* output, const T* input, int tokens, int layer_id
sync_check_cuda_error();
}

+ expert_range_ = {0, local_expert_num};
+ if (param_.enable_ep) {
+ expert_range_ = {tensor_para_.rank_ * local_expert_num, (1 + tensor_para_.rank_) * local_expert_num};
+ }

/// TODO: fix illegal memory access even if NaN are present in logits
invokeMoeGate_V2(f2n_,
en2f_,
@@ -129,6 +135,7 @@ void MoeFfnLayer<T>::forward(T* output, const T* input, int tokens, int layer_id
param_.experts_per_token,
param_.norm_topk_prob,
param_.routed_scale,
+ expert_range_,
stream_);
sync_check_cuda_error();

@@ -147,22 +154,25 @@ void MoeFfnLayer<T>::forward(T* output, const T* input, int tokens, int layer_id
cudaMemcpyAsync(offsets_, h_offsets_, sizeof(int) * (expert_num + 1), cudaMemcpyDefault, stream_));
}

+ if (param_.enable_ep) {
+ invokeMoveOffsets(offsets_, expert_num, expert_range_, stream_);
+ }

if (param_.method == MoeParam::kNaive) {

dispatchMoeGather(inout_buf_, input, f2n_, tokens, param_.experts_per_token, hidden_dim_, stream_);
sync_check_cuda_error();

check_cuda_error(
- cudaMemcpyAsync(h_offsets_, offsets_, sizeof(int) * (expert_num + 1), cudaMemcpyDefault, stream_));
+ cudaMemcpyAsync(h_offsets_, offsets_, sizeof(int) * (local_expert_num + 1), cudaMemcpyDefault, stream_));

check_cuda_error(cudaStreamSynchronize(stream_));

- if (h_offsets_[expert_num] != tokens * param_.experts_per_token) {
- FT_CHECK_WITH_INFO(0, fmtstr("%d vs %d", h_offsets_[expert_num], tokens * param_.experts_per_token));
+ if (!param_.enable_ep && h_offsets_[local_expert_num] != tokens * param_.experts_per_token) {
+ FT_CHECK_WITH_INFO(0, fmtstr("%d vs %d", h_offsets_[local_expert_num], tokens * param_.experts_per_token));
}

- for (int i = 0; i < expert_num; ++i) {

+ for (int i = 0; i < local_expert_num; ++i) {
FT_CHECK(moe.experts[i].is_fused_silu == false);

if (size_t count = h_offsets_[i + 1] - h_offsets_[i]) {
@@ -177,7 +187,7 @@ void MoeFfnLayer<T>::forward(T* output, const T* input, int tokens, int layer_id
}
}
else {
- context_->update(expert_num, param_.experts_per_token, offsets_);
+ context_->update(local_expert_num, param_.experts_per_token, offsets_);

auto& block = moe.block;

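Taken together, the forward pass still gates over the global expert_num, but each rank derives its own expert_range_, zeroes the scales of remote experts, remaps offsets_ to its local window, and then runs only local_expert_num experts; since remote scales are zero, the layer's existing cross-rank reduction can sum the partial expert outputs. A condensed, rank-local sketch of that bookkeeping in plain Python (names echo the member variables, the flow is simplified):

    def moe_forward_plan(top_ids, rank, world_size, expert_num):
        # How many routed tokens does this rank actually process per local expert?
        local = expert_num // world_size
        lo, hi = rank * local, (rank + 1) * local      # expert_range_
        counts = [0] * expert_num
        for ids in top_ids:                            # token -> top-k expert ids
            for e in ids:
                counts[e] += 1
        offsets = [0]                                  # exclusive scan over all experts
        for c in counts:
            offsets.append(offsets[-1] + c)
        local_offsets = offsets[lo:hi + 1]             # what invokeMoveOffsets exposes
        return {e: local_offsets[i + 1] - local_offsets[i]
                for i, e in enumerate(range(lo, hi))}

    # 8 experts over 2 ranks; rank 1 only runs experts 4..7.
    print(moe_forward_plan([[1, 5], [4, 7], [4, 1]], rank=1, world_size=2, expert_num=8))
    # -> {4: 2, 5: 1, 6: 0, 7: 1}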
4 changes: 3 additions & 1 deletion src/turbomind/models/llama/moe_ffn_layer.h
@@ -17,7 +17,7 @@ template<class T>
class MoeFfnLayer {
public:
MoeFfnLayer(ModelParam model, const MoeParam& param, const NcclParam& tp, const Context<T>& ctx):
- inter_size_(param.inter_size / tp.world_size_),
+ inter_size_(param.enable_ep ? param.inter_size : param.inter_size / tp.world_size_),
hidden_dim_(model.hidden_units),
param_(param),
dtype_(getTensorType<T>()),
@@ -93,6 +93,8 @@ class MoeFfnLayer {

int* accum_{};
int* offsets_{};

+ int2 expert_range_{};
};

} // namespace turbomind
1 change: 1 addition & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -314,6 +314,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
lora_param_.scale_pattern = getLoraPattern<float>(lora_reader["lora_scale_pattern"].as<std::string>(""),
[](const std::string& s) { return std::stof(s); });

+ moe_param_.enable_ep = model_reader["enable_ep"].as<bool>(false);
moe_param_.experts_per_token = model_reader["experts_per_token"].as<int>(0);
moe_param_.inter_size = model_reader["expert_inter_size"].as<int>(0);
moe_param_.shared_gate = model_reader["moe_shared_gate"].as<bool>();
