From 0e4c315b1f4075835ad0ff23a91ae68173092a42 Mon Sep 17 00:00:00 2001
From: irexyc
Date: Mon, 2 Dec 2024 10:52:58 +0000
Subject: [PATCH] split rope params

---
 src/turbomind/models/llama/LlamaBatch.cc      | 10 +-
 src/turbomind/models/llama/llama_params.h     | 53 +++++++---
 src/turbomind/models/llama/rotary_emb.cu      | 99 ++++++++++---------
 src/turbomind/models/llama/rotary_emb.h       | 49 +++++----
 .../models/llama/unified_attention_layer.cc   |  4 +-
 src/turbomind/models/llama/unified_decoder.cc |  6 +-
 .../triton_backend/llama/LlamaTritonModel.cc  | 41 ++++----
 7 files changed, 144 insertions(+), 118 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index ea321d06a0..9ea9187878 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -368,15 +368,15 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests)
 
         // compute rope scaling factor
         if (r->start_flag) {
-            seq.rope_theta = model_->attn_param_.rotary_embedding_base;
-            if (model_->attn_param_.use_dynamic_ntk) {
-                auto scaling_factor = model_->attn_param_.rope_scaling_factor;
+            seq.rope_theta = model_->attn_param_.rope.base;
+            if (model_->attn_param_.rope.type == RotaryScalingType::kDynamic) {
+                auto scaling_factor = model_->attn_param_.rope.factor;
                 if (scaling_factor >= 1.f) {  // infer by current context length
                     auto max_seq_len = state.h_context_length[idx];
-                    auto max_pos_emb = model_->attn_param_.max_position_embeddings;
+                    auto max_pos_emb = model_->attn_param_.rope.max_position_embeddings;
                     if (max_seq_len > max_pos_emb) {
                         scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
-                        float rope_dim = model_->attn_param_.rotary_embedding_dim;
+                        float rope_dim = model_->attn_param_.rope.dim;
                         seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
                         TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
                                     (long)seq.id,
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index 0a505b11a9..000ef82eff 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -59,22 +59,45 @@ struct MoeParam {
     std::vector<int> expert_num;
 };
 
+enum class RotaryScalingType
+{
+    kDefault,
+    kLinear,
+    kDynamic,
+    kYarn,
+    kLlama3,
+};
+
+struct YarnRopeParam {
+    float attention_factor;
+    float beta_fast;
+    float beta_slow;
+};
+
+struct Llama3RopeParam {
+    float low_freq_factor;
+    float high_freq_factor;
+    int   original_max_position_embeddings;
+};
+
 struct AttentionParam {
-    int         rotary_embedding_dim;
-    float       rotary_embedding_base;
-    int         max_position_embeddings;
-    float       softmax_scale;
-    std::string rope_scaling_type;
-    int         original_max_position_embeddings;
-    float       rope_scaling_factor;
-    float       low_freq_factor;
-    float       high_freq_factor;
-    float       attention_factor;
-    float       beta_fast;
-    float       beta_slow;
-    bool        use_dynamic_ntk;
-    bool        use_logn_attn;
-    int         cache_block_seq_len;
+    float softmax_scale;
+    int   cache_block_seq_len;
+    bool  use_logn_attn;
+    // rope
+    struct {
+        // common
+        RotaryScalingType type;
+        int               dim;
+        float             base;
+        float             factor;
+        int               max_position_embeddings;
+        // special
+        union {
+            YarnRopeParam   yarn;
+            Llama3RopeParam llama3;
+        };
+    } rope;
 };
 
 struct EngineParam {
diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu
index 2ecec40a79..a0e119062d 100644
--- a/src/turbomind/models/llama/rotary_emb.cu
+++ b/src/turbomind/models/llama/rotary_emb.cu
@@ -114,8 +114,7 @@ RotaryScalingType GetRoPEType(const std::string& type)
                                                        {"linear", RotaryScalingType::kLinear},
                                                        {"dynamic", RotaryScalingType::kDynamic},
                                                        {"yarn", RotaryScalingType::kYarn},
-                                                       {"llama3", RotaryScalingType::kLlama3},
-                                                       {"mrope", RotaryScalingType::kMrope}};
+                                                       {"llama3", RotaryScalingType::kLlama3}};
     return lookup.at(type);
 }
 
@@ -132,42 +131,52 @@ void RotaryEmbeddingV2::allocateBuffer(size_t token_num)
 RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator):
     stream_(stream), allocator_(allocator)
 {
-    type_                = GetRoPEType(param.rope_scaling_type);
-    dim_                 = param.rotary_embedding_dim;
-    rope_scaling_factor_ = 1.0f;
-    attention_factor_    = 1.0f;
+    type_ = param.rope.type;
+    dim_  = param.rope.dim;
 
-    if (type_ == RotaryScalingType::kLinear) {
-        rope_scaling_factor_ /= param.rope_scaling_factor;
-    }
-    else if (type_ == RotaryScalingType::kLlama3) {
-        const double PI            = 3.14159265358979323846;
-        float inv_diff_freq_factor = 1.0 / (param.high_freq_factor - param.low_freq_factor);
-        llama3_inv_scaling_factor_ = 1.0 / param.rope_scaling_factor;
-        llama3_alpha_              = param.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
-        llama3_beta_               = param.low_freq_factor * inv_diff_freq_factor;
-    }
-    else if (type_ == RotaryScalingType::kYarn) {
-        const double PI          = 3.14159265358979323846;
-        auto find_correction_dim = [&](float num_rotations) {
-            return (param.rotary_embedding_dim * std::log(param.max_position_embeddings / (num_rotations * 2 * PI)))
-                   / (2 * std::log(param.rotary_embedding_base));
-        };
-        auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
-            low  = std::floor(find_correction_dim(low_rot));
-            high = std::ceil(find_correction_dim(high_rot));
-            low  = std::max(low, 0.f);
-            high = std::min(high, param.rotary_embedding_dim - 1.f);
-        };
-        float low, high;
-        find_correction_range(param.beta_fast, param.beta_slow, low, high);
-        if (low == high) {
-            high += 0.01f;
+    switch (type_) {
+        case RotaryScalingType::kDefault:
+            break;
+        case RotaryScalingType::kLinear:
+            inv_factor_ = 1.0f / param.rope.factor;
+            break;
+        case RotaryScalingType::kDynamic:
+            inv_factor_ = param.rope.factor;
+            break;
+        case RotaryScalingType::kYarn: {
+            const double PI          = 3.14159265358979323846;
+            auto find_correction_dim = [&](float num_rotations) {
+                return (param.rope.dim * std::log(param.rope.max_position_embeddings / (num_rotations * 2 * PI)))
+                       / (2 * std::log(param.rope.base));
+            };
+            auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
+                low  = std::floor(find_correction_dim(low_rot));
+                high = std::ceil(find_correction_dim(high_rot));
+                low  = std::max(low, 0.f);
+                high = std::min(high, param.rope.dim - 1.f);
+            };
+            float low, high;
+            find_correction_range(param.rope.yarn.beta_fast, param.rope.yarn.beta_slow, low, high);
+            if (low == high) {
+                high += 0.01f;
+            }
+            yarn_.yarn_ramp_inv_factor_div_2   = 1.0 / (high - low) / 2.0;
+            yarn_.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low;
+            yarn_.yarn_inv_scaling_factor      = (1 - 1.0 / param.rope.factor);
+            yarn_.attention_factor             = param.rope.yarn.attention_factor;
+            break;
+        }
+        case RotaryScalingType::kLlama3: {
+            const double PI            = 3.14159265358979323846;
+            float inv_diff_freq_factor = 1.0 / (param.rope.llama3.high_freq_factor - param.rope.llama3.low_freq_factor);
+            llama3_.llama3_inv_scaling_factor = 1.0 / param.rope.factor;
+            llama3_.llama3_alpha = param.rope.llama3.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
+            llama3_.llama3_beta = param.rope.llama3.low_freq_factor * inv_diff_freq_factor;
+            break;
         }
-        yarn_ramp_inv_factor_div_2_   = 1.0 / (high - low) / 2.0;
-        yarn_ramp_inv_factor_mul_min_ = 1.0 / (high - low) * low;
-        yarn_inv_scaling_factor_      = (1 - 1.0 / param.rope_scaling_factor);
-        attention_factor_             = param.attention_factor;
+        default:
+            FT_CHECK(0);
+            break;
     }
 }
 
@@ -188,7 +197,7 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
                                           params.token_num,
                                           params.batch_size,
                                           dim_,
-                                          rope_scaling_factor_,
+                                          inv_factor_,
                                           cos_sin_);
             break;
         case RotaryScalingType::kLlama3:
@@ -198,9 +207,9 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
                                           params.token_num,
                                           params.batch_size,
                                           dim_,
-                                          llama3_inv_scaling_factor_,
-                                          llama3_alpha_,
-                                          llama3_beta_,
+                                          llama3_.llama3_inv_scaling_factor,
+                                          llama3_.llama3_alpha,
+                                          llama3_.llama3_beta,
                                           cos_sin_);
             break;
         case RotaryScalingType::kYarn:
@@ -210,14 +219,12 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
                                           params.token_num,
                                           params.batch_size,
                                           dim_,
-                                          yarn_ramp_inv_factor_div_2_,
-                                          yarn_ramp_inv_factor_mul_min_,
-                                          yarn_inv_scaling_factor_,
-                                          attention_factor_,
+                                          yarn_.yarn_ramp_inv_factor_div_2,
+                                          yarn_.yarn_ramp_inv_factor_mul_min,
+                                          yarn_.yarn_inv_scaling_factor,
+                                          yarn_.attention_factor,
                                           cos_sin_);
             break;
-        case RotaryScalingType::kMrope:
-            FT_CHECK(0);
         default:
             FT_CHECK(0);
     }
diff --git a/src/turbomind/models/llama/rotary_emb.h b/src/turbomind/models/llama/rotary_emb.h
index ffe81752e4..66a830b01a 100644
--- a/src/turbomind/models/llama/rotary_emb.h
+++ b/src/turbomind/models/llama/rotary_emb.h
@@ -5,15 +5,7 @@
 
 namespace turbomind {
 
-enum class RotaryScalingType
-{
-    kDefault,
-    kLinear,
-    kDynamic,
-    kYarn,
-    kLlama3,
-    kMrope
-};
+RotaryScalingType GetRoPEType(const std::string& type);
 
 struct RotaryEmbeddingV2Params {
     float* rope_theta;
@@ -23,6 +15,19 @@ struct RotaryEmbeddingV2Params {
     int    token_num;
 };
 
+struct InnerYarnRopeParam {
+    float attention_factor;
+    float yarn_ramp_inv_factor_div_2;
+    float yarn_ramp_inv_factor_mul_min;
+    float yarn_inv_scaling_factor;
+};
+
+struct InnerLlama3RopeParam {
+    float llama3_inv_scaling_factor;
+    float llama3_alpha;
+    float llama3_beta;
+};
+
 struct RotaryEmbeddingV2 {
 
     RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator);
@@ -38,28 +43,20 @@ struct RotaryEmbeddingV2 {
 
     void forward(const RotaryEmbeddingV2Params& params);
 
-    RotaryScalingType type_;
     cudaStream_t const stream_;
     IAllocator* const  allocator_;
 
+    int               dim_;
+    RotaryScalingType type_;
+    float             inv_factor_{1.0};
+
+    union {
+        InnerYarnRopeParam   yarn_;
+        InnerLlama3RopeParam llama3_;
+    };
+
     // output
     float* cos_sin_;  // num_token x dim, (cos, sin, ...)
-
-    int dim_;
-    // default, linear, dynamic
-    float attention_factor_;
-    float rope_scaling_factor_;
-    float inv_scale_factor_;
-    // llama3
-    float llama3_inv_scaling_factor_;
-    float llama3_alpha_;
-    float llama3_beta_;
-    // yarn
-    float yarn_ramp_inv_factor_div_2_;
-    float yarn_ramp_inv_factor_mul_min_;
-    float yarn_inv_scaling_factor_;
-    // mrope
-    int3 mrope_section_;
 };
 
 };  // namespace turbomind
diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc
index 89224e853f..77d53afd5e 100644
--- a/src/turbomind/models/llama/unified_attention_layer.cc
+++ b/src/turbomind/models/llama/unified_attention_layer.cc
@@ -313,8 +313,8 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa
     }
 
     // rope
-    params.rotary_embedding_dim    = param_.rotary_embedding_dim;
-    params.max_position_embeddings = param_.max_position_embeddings;
+    params.rotary_embedding_dim    = param_.rope.dim;
+    params.max_position_embeddings = param_.rope.max_position_embeddings;
     params.cos_sin       = cos_sin;
     params.use_logn_attn = param_.use_logn_attn;
 
diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc
index e40d7af22b..c37658fd37 100644
--- a/src/turbomind/models/llama/unified_decoder.cc
+++ b/src/turbomind/models/llama/unified_decoder.cc
@@ -88,11 +88,7 @@ void UnifiedDecoder::forwardSelfAttn(T* attn_io,
     inputs.insert("cu_k_len", {MEMORY_GPU, TYPE_INT32, {batch_size + 1}, cu_k_len_});
     inputs.insert("h_cu_q_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_q_len_});
     inputs.insert("h_cu_k_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_k_len_});
-
-    if (rotary_emb_) {
-        inputs.insert("cos_sin",
-                      {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_});
-    }
+    inputs.insert("cos_sin", {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_});
 
     TensorMap outputs(*_outputs);
     outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io});
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 40c5ac8907..f2fa583c90 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -140,10 +140,10 @@ void LlamaTritonModel::handleMissingParams()
                        (int)model_param_.vocab_size);
     }
 
-    if (!attn_param_.max_position_embeddings) {
-        attn_param_.max_position_embeddings = 2048;
+    if (!attn_param_.rope.max_position_embeddings) {
+        attn_param_.rope.max_position_embeddings = 2048;
         TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to %d.",
-                       (int)attn_param_.max_position_embeddings);
+                       (int)attn_param_.rope.max_position_embeddings);
     }
 
     if (!engine_param_.max_batch_size) {
@@ -153,7 +153,7 @@ void LlamaTritonModel::handleMissingParams()
     }
 
     if (!engine_param_.session_len) {
-        engine_param_.session_len = attn_param_.max_position_embeddings;
+        engine_param_.session_len = attn_param_.rope.max_position_embeddings;
         TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)engine_param_.session_len);
     }
 
@@ -277,22 +277,25 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size,
     model_param_.attn_bias  = model_reader["attn_bias"].as<int>(0);
     model_param_.group_size = model_reader["group_size"].as<int>(0);
 
+    attn_param_.softmax_scale = attention_reader["softmax_scale"].as<float>(0);
+    attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as<int>(0);
attention_reader["use_logn_attn"].as(0); // rotary embedding parameters - attn_param_.rotary_embedding_dim = attention_reader["rotary_embedding"].as(); - attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as(10000.0f); - attn_param_.softmax_scale = attention_reader["softmax_scale"].as(0); - attn_param_.attention_factor = attention_reader["attention_factor"].as(-1.f); - attn_param_.beta_fast = attention_reader["beta_fast"].as(32.f); - attn_param_.beta_slow = attention_reader["beta_slow"].as(1.f); - attn_param_.rope_scaling_type = attention_reader["rope_scaling_type"].as(""); - attn_param_.rope_scaling_factor = attention_reader["rope_scaling_factor"].as(0.f); - attn_param_.low_freq_factor = attention_reader["low_freq_factor"].as(1.0); - attn_param_.high_freq_factor = attention_reader["high_freq_factor"].as(1.0); - attn_param_.max_position_embeddings = attention_reader["max_position_embeddings"].as(0); - attn_param_.use_dynamic_ntk = attention_reader["use_dynamic_ntk"].as(0); - attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as(0); - - attn_param_.original_max_position_embeddings = attention_reader["original_max_position_embeddings"].as(0); + attn_param_.rope.type = GetRoPEType(attention_reader["rope_scaling_type"].as("")); + attn_param_.rope.dim = attention_reader["rotary_embedding"].as(); + attn_param_.rope.base = attention_reader["rope_theta"].as(10000.0f); + attn_param_.rope.max_position_embeddings = attention_reader["max_position_embeddings"].as(0); + attn_param_.rope.factor = attention_reader["rope_scaling_factor"].as(0.f); + if (attn_param_.rope.type == RotaryScalingType::kYarn) { + attn_param_.rope.yarn.attention_factor = attention_reader["attention_factor"].as(-1.f); + attn_param_.rope.yarn.beta_fast = attention_reader["beta_fast"].as(32.f); + attn_param_.rope.yarn.beta_slow = attention_reader["beta_slow"].as(1.f); + } + else if (attn_param_.rope.type == RotaryScalingType::kLlama3) { + attn_param_.rope.llama3.low_freq_factor = attention_reader["low_freq_factor"].as(1.0); + attn_param_.rope.llama3.high_freq_factor = attention_reader["high_freq_factor"].as(1.0); + attn_param_.rope.llama3.original_max_position_embeddings = + attention_reader["original_max_position_embeddings"].as(0); + } engine_param_.max_batch_size = engine_reader["max_batch_size"].as(0); engine_param_.max_prefill_token_num = engine_reader["max_prefill_token_num"].as(0);