split rope params
irexyc committed Dec 2, 2024
1 parent 45f0968 commit 0e4c315
Showing 7 changed files with 144 additions and 118 deletions.
10 changes: 5 additions & 5 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -368,15 +368,15 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)

// compute rope scaling factor
if (r->start_flag) {
seq.rope_theta = model_->attn_param_.rotary_embedding_base;
if (model_->attn_param_.use_dynamic_ntk) {
auto scaling_factor = model_->attn_param_.rope_scaling_factor;
seq.rope_theta = model_->attn_param_.rope.base;
if (model_->attn_param_.rope.type == RotaryScalingType::kDynamic) {
auto scaling_factor = model_->attn_param_.rope.factor;
if (scaling_factor >= 1.f) { // infer by current context length
auto max_seq_len = state.h_context_length[idx];
auto max_pos_emb = model_->attn_param_.max_position_embeddings;
auto max_pos_emb = model_->attn_param_.rope.max_position_embeddings;
if (max_seq_len > max_pos_emb) {
scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
float rope_dim = model_->attn_param_.rotary_embedding_dim;
float rope_dim = model_->attn_param_.rope.dim;
seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
(long)seq.id,
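
For readers skimming the hunk above: the dynamic-NTK branch rescales rope_theta once the current context exceeds the trained position range. The standalone sketch below reproduces that arithmetic; the function name and the example values are hypothetical, only the formula mirrors the diff.

#include <cmath>
#include <cstdio>

// Sketch of the dynamic-NTK theta adjustment from ProcessInferRequests above.
float dynamic_ntk_theta(float base, float dim, float factor, int max_pos_emb, int seq_len)
{
    if (factor < 1.f || seq_len <= max_pos_emb) {
        return base;  // short contexts keep the original base
    }
    float scaling = factor * seq_len / max_pos_emb - (factor - 1.f);
    return base * std::pow(scaling, dim / (dim - 2.f));
}

int main()
{
    // Hypothetical values: base 10000, dim 128, factor 2, trained length 2048, context 4096.
    std::printf("rope_theta = %f\n", dynamic_ntk_theta(10000.f, 128.f, 2.f, 2048, 4096));
}
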
53 changes: 38 additions & 15 deletions src/turbomind/models/llama/llama_params.h
@@ -59,22 +59,45 @@ struct MoeParam {
std::vector<int> expert_num;
};

enum class RotaryScalingType
{
kDefault,
kLinear,
kDynamic,
kYarn,
kLlama3,
};

struct YarnRopeParam {
float attention_factor;
float beta_fast;
float beta_slow;
};

struct Llama3RopeParam {
float low_freq_factor;
float high_freq_factor;
int original_max_position_embeddings;
};

struct AttentionParam {
int rotary_embedding_dim;
float rotary_embedding_base;
int max_position_embeddings;
float softmax_scale;
std::string rope_scaling_type;
int original_max_position_embeddings;
float rope_scaling_factor;
float low_freq_factor;
float high_freq_factor;
float attention_factor;
float beta_fast;
float beta_slow;
bool use_dynamic_ntk;
bool use_logn_attn;
int cache_block_seq_len;
float softmax_scale;
int cache_block_seq_len;
bool use_logn_attn;
// rope
struct {
// common
RotaryScalingType type;
int dim;
float base;
float factor;
int max_position_embeddings;
// special
union {
YarnRopeParam yarn;
Llama3RopeParam llama3;
};
} rope;
};

struct EngineParam {
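
As a reading aid for the new layout: the scheme-specific fields now sit in a union keyed by rope.type. A hypothetical way to populate the struct for a YaRN model is sketched below; the numbers are placeholders and the include path is an assumption about the build setup.

// Hypothetical usage of the reorganized AttentionParam (placeholder values).
#include "src/turbomind/models/llama/llama_params.h"

turbomind::AttentionParam make_yarn_attention_param()
{
    turbomind::AttentionParam p{};
    p.rope.type                    = turbomind::RotaryScalingType::kYarn;
    p.rope.dim                     = 128;
    p.rope.base                    = 10000.f;
    p.rope.factor                  = 4.f;
    p.rope.max_position_embeddings = 32768;
    // Only the `yarn` member of the union is meaningful while type == kYarn.
    p.rope.yarn.attention_factor   = 1.f;
    p.rope.yarn.beta_fast          = 32.f;
    p.rope.yarn.beta_slow          = 1.f;
    return p;
}

Since yarn and llama3 share storage, rope.type acts as the tag that tells consumers which member is currently valid.
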
99 changes: 53 additions & 46 deletions src/turbomind/models/llama/rotary_emb.cu
@@ -114,8 +114,7 @@ RotaryScalingType GetRoPEType(const std::string& type)
{"linear", RotaryScalingType::kLinear},
{"dynamic", RotaryScalingType::kDynamic},
{"yarn", RotaryScalingType::kYarn},
{"llama3", RotaryScalingType::kLlama3},
{"mrope", RotaryScalingType::kMrope}};
{"llama3", RotaryScalingType::kLlama3}};
return lookup.at(type);
}

@@ -132,42 +131,52 @@ void RotaryEmbeddingV2::allocateBuffer(size_t token_num)
RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator):
stream_(stream), allocator_(allocator)
{
type_ = GetRoPEType(param.rope_scaling_type);
dim_ = param.rotary_embedding_dim;
rope_scaling_factor_ = 1.0f;
attention_factor_ = 1.0f;
type_ = param.rope.type;
dim_ = param.rope.dim;

if (type_ == RotaryScalingType::kLinear) {
rope_scaling_factor_ /= param.rope_scaling_factor;
}
else if (type_ == RotaryScalingType::kLlama3) {
const double PI = 3.14159265358979323846;
float inv_diff_freq_factor = 1.0 / (param.high_freq_factor - param.low_freq_factor);
llama3_inv_scaling_factor_ = 1.0 / param.rope_scaling_factor;
llama3_alpha_ = param.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
llama3_beta_ = param.low_freq_factor * inv_diff_freq_factor;
}
else if (type_ == RotaryScalingType::kYarn) {
const double PI = 3.14159265358979323846;
auto find_correction_dim = [&](float num_rotations) {
return (param.rotary_embedding_dim * std::log(param.max_position_embeddings / (num_rotations * 2 * PI)))
/ (2 * std::log(param.rotary_embedding_base));
};
auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
low = std::floor(find_correction_dim(low_rot));
high = std::ceil(find_correction_dim(high_rot));
low = std::max(low, 0.f);
high = std::min(high, param.rotary_embedding_dim - 1.f);
};
float low, high;
find_correction_range(param.beta_fast, param.beta_slow, low, high);
if (low == high) {
high += 0.01f;
switch (type_) {
case RotaryScalingType::kDefault:
break;
case RotaryScalingType::kLinear:
inv_factor_ = 1.0f / param.rope.factor;
break;
case RotaryScalingType::kDynamic:
inv_factor_ = param.rope.factor;
break;
case RotaryScalingType::kYarn: {
const double PI = 3.14159265358979323846;
auto find_correction_dim = [&](float num_rotations) {
return (param.rope.dim * std::log(param.rope.max_position_embeddings / (num_rotations * 2 * PI)))
/ (2 * std::log(param.rope.base));
};
auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
low = std::floor(find_correction_dim(low_rot));
high = std::ceil(find_correction_dim(high_rot));
low = std::max(low, 0.f);
high = std::min(high, param.rope.dim - 1.f);
};
float low, high;
find_correction_range(param.rope.yarn.beta_fast, param.rope.yarn.beta_slow, low, high);
if (low == high) {
high += 0.01f;
}
yarn_.yarn_ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0;
yarn_.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low;
yarn_.yarn_inv_scaling_factor = (1 - 1.0 / param.rope.factor);
yarn_.attention_factor = param.rope.yarn.attention_factor;
break;
}
case RotaryScalingType::kLlama3: {
const double PI = 3.14159265358979323846;
float inv_diff_freq_factor = 1.0 / (param.rope.llama3.high_freq_factor - param.rope.llama3.low_freq_factor);
llama3_.llama3_inv_scaling_factor = 1.0 / param.rope.factor;
llama3_.llama3_alpha = param.rope.llama3.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
llama3_.llama3_beta = param.rope.llama3.low_freq_factor * inv_diff_freq_factor;
break;
}
yarn_ramp_inv_factor_div_2_ = 1.0 / (high - low) / 2.0;
yarn_ramp_inv_factor_mul_min_ = 1.0 / (high - low) * low;
yarn_inv_scaling_factor_ = (1 - 1.0 / param.rope_scaling_factor);
attention_factor_ = param.attention_factor;
default:
FT_CHECK(0);
break;
}
}

@@ -188,7 +197,7 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
params.token_num,
params.batch_size,
dim_,
rope_scaling_factor_,
inv_factor_,
cos_sin_);
break;
case RotaryScalingType::kLlama3:
Expand All @@ -198,9 +207,9 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
params.token_num,
params.batch_size,
dim_,
llama3_inv_scaling_factor_,
llama3_alpha_,
llama3_beta_,
llama3_.llama3_inv_scaling_factor,
llama3_.llama3_alpha,
llama3_.llama3_beta,
cos_sin_);
break;
case RotaryScalingType::kYarn:
Expand All @@ -210,14 +219,12 @@ void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
params.token_num,
params.batch_size,
dim_,
yarn_ramp_inv_factor_div_2_,
yarn_ramp_inv_factor_mul_min_,
yarn_inv_scaling_factor_,
attention_factor_,
yarn_.yarn_ramp_inv_factor_div_2,
yarn_.yarn_ramp_inv_factor_mul_min,
yarn_.yarn_inv_scaling_factor,
yarn_.attention_factor,
cos_sin_);
break;
case RotaryScalingType::kMrope:
FT_CHECK(0);
default:
FT_CHECK(0);
}
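
The YaRN branch of the constructor above boils down to a linear ramp between two cutoff dimensions. The helper below restates that computation on its own; the free-standing names and the example values in the comment are mine, so treat it as a sketch of the same math rather than code from this file.

#include <algorithm>
#include <cmath>

struct YarnRamp {
    float ramp_inv_factor_div_2;
    float ramp_inv_factor_mul_min;
};

// Restates the correction-range setup used for RotaryScalingType::kYarn above.
YarnRamp compute_yarn_ramp(int dim, float base, int max_pos, float beta_fast, float beta_slow)
{
    const double PI = 3.14159265358979323846;
    auto correction_dim = [&](float num_rotations) {
        return (dim * std::log(max_pos / (num_rotations * 2 * PI))) / (2 * std::log(base));
    };
    float low  = std::max(0.f, (float)std::floor(correction_dim(beta_fast)));
    float high = std::min(dim - 1.f, (float)std::ceil(correction_dim(beta_slow)));
    if (low == high) {
        high += 0.01f;  // avoid dividing by a zero-width range
    }
    // e.g. dim = 128, base = 10000.f, max_pos = 4096, beta_fast = 32.f, beta_slow = 1.f
    return {1.0f / (high - low) / 2.0f, 1.0f / (high - low) * low};
}
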
49 changes: 23 additions & 26 deletions src/turbomind/models/llama/rotary_emb.h
@@ -5,15 +5,7 @@

namespace turbomind {

enum class RotaryScalingType
{
kDefault,
kLinear,
kDynamic,
kYarn,
kLlama3,
kMrope
};
RotaryScalingType GetRoPEType(const std::string& type);

struct RotaryEmbeddingV2Params {
float* rope_theta;
@@ -23,6 +15,19 @@ struct RotaryEmbeddingV2Params {
int token_num;
};

struct InnerYarnRopeParam {
float attention_factor;
float yarn_ramp_inv_factor_div_2;
float yarn_ramp_inv_factor_mul_min;
float yarn_inv_scaling_factor;
};

struct InnerLlama3RopeParam {
float llama3_inv_scaling_factor;
float llama3_alpha;
float llama3_beta;
};

struct RotaryEmbeddingV2 {

RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator);
@@ -38,28 +43,20 @@ struct RotaryEmbeddingV2 {

void forward(const RotaryEmbeddingV2Params& params);

RotaryScalingType type_;
cudaStream_t const stream_;
IAllocator* const allocator_;

int dim_;
RotaryScalingType type_;
float inv_factor_{1.0};

union {
InnerYarnRopeParam yarn_;
InnerLlama3RopeParam llama3_;
};

// output
float* cos_sin_; // num_token x dim, (cos, sin, ...)

int dim_;
// default, linear, dynamic
float attention_factor_;
float rope_scaling_factor_;
float inv_scale_factor_;
// llama3
float llama3_inv_scaling_factor_;
float llama3_alpha_;
float llama3_beta_;
// yarn
float yarn_ramp_inv_factor_div_2_;
float yarn_ramp_inv_factor_mul_min_;
float yarn_inv_scaling_factor_;
// mrope
int3 mrope_section_;
};

}; // namespace turbomind
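
For orientation, the cos_sin_ buffer declared above is documented as num_token x dim with interleaved (cos, sin, ...) pairs. A hypothetical accessor, not part of this header, would index it as follows.

#include <cstddef>

// Hypothetical helper reflecting the documented layout of cos_sin_:
// row `token` holds dim floats arranged as cos, sin, cos, sin, ...
inline void read_cos_sin(const float* cos_sin, int token, int dim, int pair, float& c, float& s)
{
    const float* p = cos_sin + (std::size_t)token * dim + 2 * pair;  // pair in [0, dim / 2)
    c = p[0];
    s = p[1];
}
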
4 changes: 2 additions & 2 deletions src/turbomind/models/llama/unified_attention_layer.cc
@@ -313,8 +313,8 @@ inline void UnifiedAttentionLayer<T>::forward(TensorMap* outputs, const TensorMa
}

// rope
params.rotary_embedding_dim = param_.rotary_embedding_dim;
params.max_position_embeddings = param_.max_position_embeddings;
params.rotary_embedding_dim = param_.rope.dim;
params.max_position_embeddings = param_.rope.max_position_embeddings;
params.cos_sin = cos_sin;
params.use_logn_attn = param_.use_logn_attn;

6 changes: 1 addition & 5 deletions src/turbomind/models/llama/unified_decoder.cc
@@ -88,11 +88,7 @@ void UnifiedDecoder<T>::forwardSelfAttn(T* attn_io,
inputs.insert("cu_k_len", {MEMORY_GPU, TYPE_INT32, {batch_size + 1}, cu_k_len_});
inputs.insert("h_cu_q_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_q_len_});
inputs.insert("h_cu_k_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_k_len_});

if (rotary_emb_) {
inputs.insert("cos_sin",
{MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_});
}
inputs.insert("cos_sin", {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_});

TensorMap outputs(*_outputs);
outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io});
41 changes: 22 additions & 19 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -140,10 +140,10 @@ void LlamaTritonModel<T>::handleMissingParams()
(int)model_param_.vocab_size);
}

if (!attn_param_.max_position_embeddings) {
attn_param_.max_position_embeddings = 2048;
if (!attn_param_.rope.max_position_embeddings) {
attn_param_.rope.max_position_embeddings = 2048;
TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to %d.",
(int)attn_param_.max_position_embeddings);
(int)attn_param_.rope.max_position_embeddings);
}

if (!engine_param_.max_batch_size) {
@@ -153,7 +153,7 @@
}

if (!engine_param_.session_len) {
engine_param_.session_len = attn_param_.max_position_embeddings;
engine_param_.session_len = attn_param_.rope.max_position_embeddings;
TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)engine_param_.session_len);
}

@@ -277,22 +277,25 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
model_param_.attn_bias = model_reader["attn_bias"].as<int>(0);
model_param_.group_size = model_reader["group_size"].as<int>(0);

attn_param_.softmax_scale = attention_reader["softmax_scale"].as<float>(0);
attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as<int>(0);
// rotary embedding parameters
attn_param_.rotary_embedding_dim = attention_reader["rotary_embedding"].as<int>();
attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as<float>(10000.0f);
attn_param_.softmax_scale = attention_reader["softmax_scale"].as<float>(0);
attn_param_.attention_factor = attention_reader["attention_factor"].as<float>(-1.f);
attn_param_.beta_fast = attention_reader["beta_fast"].as<float>(32.f);
attn_param_.beta_slow = attention_reader["beta_slow"].as<float>(1.f);
attn_param_.rope_scaling_type = attention_reader["rope_scaling_type"].as<std::string>("");
attn_param_.rope_scaling_factor = attention_reader["rope_scaling_factor"].as<float>(0.f);
attn_param_.low_freq_factor = attention_reader["low_freq_factor"].as<float>(1.0);
attn_param_.high_freq_factor = attention_reader["high_freq_factor"].as<float>(1.0);
attn_param_.max_position_embeddings = attention_reader["max_position_embeddings"].as<int>(0);
attn_param_.use_dynamic_ntk = attention_reader["use_dynamic_ntk"].as<int>(0);
attn_param_.use_logn_attn = attention_reader["use_logn_attn"].as<int>(0);

attn_param_.original_max_position_embeddings = attention_reader["original_max_position_embeddings"].as<int>(0);
attn_param_.rope.type = GetRoPEType(attention_reader["rope_scaling_type"].as<std::string>(""));
attn_param_.rope.dim = attention_reader["rotary_embedding"].as<int>();
attn_param_.rope.base = attention_reader["rope_theta"].as<float>(10000.0f);
attn_param_.rope.max_position_embeddings = attention_reader["max_position_embeddings"].as<int>(0);
attn_param_.rope.factor = attention_reader["rope_scaling_factor"].as<float>(0.f);
if (attn_param_.rope.type == RotaryScalingType::kYarn) {
attn_param_.rope.yarn.attention_factor = attention_reader["attention_factor"].as<float>(-1.f);
attn_param_.rope.yarn.beta_fast = attention_reader["beta_fast"].as<float>(32.f);
attn_param_.rope.yarn.beta_slow = attention_reader["beta_slow"].as<float>(1.f);
}
else if (attn_param_.rope.type == RotaryScalingType::kLlama3) {
attn_param_.rope.llama3.low_freq_factor = attention_reader["low_freq_factor"].as<float>(1.0);
attn_param_.rope.llama3.high_freq_factor = attention_reader["high_freq_factor"].as<float>(1.0);
attn_param_.rope.llama3.original_max_position_embeddings =
attention_reader["original_max_position_embeddings"].as<int>(0);
}

engine_param_.max_batch_size = engine_reader["max_batch_size"].as<int>(0);
engine_param_.max_prefill_token_num = engine_reader["max_prefill_token_num"].as<int>(0);
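
The .as<T>(default) calls above follow the yaml-cpp pattern of falling back to a default when a key is absent; that the readers are yaml-cpp nodes is an assumption on my part. The minimal, hypothetical snippet below uses keys named in the diff with made-up values to show the fallback behavior.

#include <yaml-cpp/yaml.h>
#include <cstdio>

int main()
{
    // Hypothetical attention section; only a subset of keys is present on purpose.
    YAML::Node attention = YAML::Load(
        "rotary_embedding: 128\n"
        "rope_theta: 1000000.0\n"
        "rope_scaling_type: yarn\n"
        "rope_scaling_factor: 4.0\n");

    // Present keys are parsed; absent keys (e.g. beta_fast) take the fallback value.
    std::printf("dim=%d base=%.1f factor=%.1f beta_fast=%.1f\n",
                attention["rotary_embedding"].as<int>(),
                attention["rope_theta"].as<float>(10000.0f),
                attention["rope_scaling_factor"].as<float>(0.f),
                attention["beta_fast"].as<float>(32.f));
}
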
