diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h
index f0f2e1af46..7bc770ab64 100644
--- a/src/turbomind/kernels/attention/attention_params.h
+++ b/src/turbomind/kernels/attention/attention_params.h
@@ -56,23 +56,10 @@ struct AttentionParams {
     int   size_per_head;
     float inv_sqrt_dh;
 
-    float* cos_sin;  // rotary embedding
-    int    rotary_embedding_dim;
-    float  rotary_embedding_base;
-    float  rope_scaling_factor;
-    float  attention_scaling;
-    int    max_position_embeddings;
-    float  rope_ti_scale;  // used for linear RoPE scaling
-    // the following 3 parameters are used by llama3
-    float llama3_inv_scaling_factor;
-    float llama3_alpha;
-    float llama3_beta;
-    // the following are use by yarn
-    float yarn_ramp_inv_factor_div_2;
-    float yarn_ramp_inv_factor_mul_min;
-    float yarn_inv_scaling_factor;
-
+    float* cos_sin;
+    int    rotary_embedding_dim;
+    int    max_position_embeddings;
 
     // log(n) attention
     bool use_logn_attn;
 
diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h
index 5fbeb75819..67b72f2004 100644
--- a/src/turbomind/kernels/attention/attention_universal.h
+++ b/src/turbomind/kernels/attention/attention_universal.h
@@ -221,7 +221,7 @@ struct AttentionUniversal {
                 }
                 if (params.cos_sin) {
                     float* cos_sin = params.cos_sin;
-                    const int64_t index = qi * kHeadDim + di;
+                    const int64_t index = qi * params.rotary_embedding_dim + di;
                     PRAGMA_UNROLL
                     for (int k = 0; k < kVecSize; k += 4) {
                         (float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]);
@@ -236,51 +236,22 @@ struct AttentionUniversal {
         if (params.cos_sin) {
             PrecomputeFastRoPE rope{};
             PRAGMA_UNROLL
-            for (int c = 0; c < ITER_C; ++c) {
-                const int di = offset.x + c * Map::kDeltaC;
+            for (int s = 0; s < ITER_S; ++s) {
                 PRAGMA_UNROLL
-                for (int s = 0; s < ITER_S; ++s) {
-                    rope.apply(vec_Q[s][c], vec_cs[s][c]);
-                    if constexpr (kProcessKV) {
-                        if (s == 0) {
-                            rope.apply(vec_K[0][c], vec_cs[s][c]);
+                for (int c = 0; c < ITER_C; ++c) {
+                    const int di = offset.x + c * Map::kDeltaC;
+                    if (di < params.rotary_embedding_dim) {
+                        rope.apply(vec_Q[s][c], vec_cs[s][c]);
+                        if constexpr (kProcessKV) {
+                            if (s == 0) {
+                                rope.apply(vec_K[0][c], vec_cs[s][c]);
+                            }
                         }
                     }
                 }
             }
         }
-#if 0
-        const float rope_base = params.rope_theta ? params.rope_theta[batch_idx] : params.rotary_embedding_base;
-        PRAGMA_UNROLL
-        for (int c = 0; c < ITER_C; ++c) {
-            const int di = offset.x + c * Map::kDeltaC;
-            FastRoPE rope(di,
-                          params.rotary_embedding_dim,
-                          rope_base,
-                          params.rope_ti_scale,
-                          params.rope_scaling_factor,
-                          params.llama3_inv_scaling_factor,
-                          params.llama3_alpha,
-                          params.llama3_beta,
-                          params.yarn_ramp_inv_factor_div_2,
-                          params.yarn_ramp_inv_factor_mul_min,
-                          params.yarn_inv_scaling_factor,
-                          params.attention_scaling,
-                          std::integral_constant{});
-            PRAGMA_UNROLL
-            for (int s = 0; s < ITER_S; ++s) {
-                const int ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len;
-                rope.apply(vec_Q[s][c], ti);
-                if constexpr (kProcessKV) {
-                    if (s == 0) {
-                        rope.apply(vec_K[0][c], ti);
-                    }
-                }
-            }
-        }
-#endif
-
         if (params.use_logn_attn) {
             PRAGMA_UNROLL
             for (int s = 0; s < ITER_S; ++s) {
diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
index d552a63801..5f2a5ea6e6 100644
--- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
+++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
@@ -21,17 +21,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
                                                     const int*   cu_k_len,
                                                     const int*   cu_block_num,
                                                     const float* cos_sin,
-                                                    const float* rope_base,
                                                     int          rope_dim,
-                                                    float        rope_ti_scale,
-                                                    float        rope_scaling_factor,
-                                                    float        llama3_inv_scaling_factor,
-                                                    float        llama3_alpha,
-                                                    float        llama3_beta,
-                                                    float        yarn_ramp_inv_factor_div_2,
-                                                    float        yarn_ramp_inv_factor_mul_min,
-                                                    float        yarn_inv_scaling_factor,
-                                                    float        attention_scaling,
                                                     int64_t      stride_b,
                                                     int64_t      stride_c,
                                                     int64_t      stride_h,
@@ -133,7 +123,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
             PRAGMA_UNROLL
             for (int c = 0; c < ITER_C; ++c) {
                 const int di = offset.x + c * Map::kDeltaC;
-                const int64_t index = (qi_beg + qi) * HeadDim + di;
+                const int64_t index = (qi_beg + qi) * rope_dim + di;
                 PRAGMA_UNROLL
                 for (int k = 0; k < kVecSize; k += 4) {
                     if (qi < q_len) {
@@ -148,38 +138,13 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
         for (int s = 0; s < ITER_S; ++s) {
             PRAGMA_UNROLL
             for (int c = 0; c < ITER_C; ++c) {
-                rope.apply(vec_K[s][c], vec_cs[s][c]);
-            }
-        }
-    }
-
-#if 0
-    if (rope_base) {
-        float base = rope_base[batch_idx];
-        PRAGMA_UNROLL
-        for (int c = 0; c < ITER_C; ++c) {
-            const int di = offset.x + c * Map::kDeltaC;
-            FastRoPE rope(di,
-                          rope_dim,
-                          base,
-                          rope_ti_scale,
-                          rope_scaling_factor,
-                          llama3_inv_scaling_factor,
-                          llama3_alpha,
-                          llama3_beta,
-                          yarn_ramp_inv_factor_div_2,
-                          yarn_ramp_inv_factor_mul_min,
-                          yarn_inv_scaling_factor,
-                          attention_scaling,
-                          std::integral_constant{});
-            PRAGMA_UNROLL
-            for (int s = 0; s < ITER_S; ++s) {
-                const int ti = history_len + offset.y + s * Map::kDeltaS + token_idx;  // sequence local
-                rope.apply(vec_K[s][c], ti);
+                const int di = offset.x + c * Map::kDeltaC;
+                if (di < rope_dim) {
+                    rope.apply(vec_K[s][c], vec_cs[s][c]);
+                }
             }
         }
     }
-#endif
 
     Array param_K[ITER_S];
     Array param_V[ITER_S];
@@ -243,17 +208,7 @@ void invokeProcessKV_v2(char** blocks,
                        const int*   cu_k_len,
                        const int*   cu_block_num,
                        const float* cos_sin,
-                       const float* rope_base,
                        int          rope_dim,
-                       float        rope_ti_scale,
-                       float        rope_scaling_factor,
-                       float        llama3_inv_scaling_factor,
-                       float        llama3_1_alpha,
-                       float        llama3_1_beta,
-                       float        yarn_ramp_inv_factor_div_2,
-                       float        yarn_ramp_inv_factor_mul_min,
-                       float        yarn_inv_scaling_factor,
-                       float        attention_scaling,
                        int64_t      stride_b,
                        int64_t      stride_c,
                        int64_t      stride_h,
@@ -290,17 +245,7 @@ void invokeProcessKV_v2(char** blocks,
                                                     cu_k_len,
                                                     cu_block_num,
                                                     cos_sin,
-                                                    rope_base,
                                                     rope_dim,
-                                                    rope_ti_scale,
-                                                    rope_scaling_factor,
-                                                    llama3_inv_scaling_factor,
-                                                    llama3_1_alpha,
-                                                    llama3_1_beta,
-                                                    yarn_ramp_inv_factor_div_2,
-                                                    yarn_ramp_inv_factor_mul_min,
-                                                    yarn_inv_scaling_factor,
-                                                    attention_scaling,
                                                     stride_b,
                                                     stride_c,
                                                     stride_h,
@@ -340,17 +285,7 @@ void invokeProcessKV_v2(char** blocks,
                                     const int*   cu_k_len,  \
                                     const int*   cu_block_num,  \
                                     const float* cos_sin,  \
-                                    const float* rope_base,  \
                                     int          rope_dim,  \
-                                    float        rope_ti_scale,  \
-                                    float        rope_scaling_factor,  \
-                                    float        llama3_inv_scaling_factor,  \
-                                    float        llama3_1_alpha,  \
-                                    float        llama3_1_beta,  \
-                                    float        yarn_ramp_inv_factor_div_2,  \
-                                    float        yarn_ramp_inv_factor_mul_min,  \
-                                    float        yarn_inv_scaling_factor,  \
-                                    float        attention_scaling,  \
                                     int64_t      stride_b,  \
                                     int64_t      stride_c,  \
                                     int64_t      stride_h,  \
@@ -370,28 +305,17 @@ INSTANTIATE_invokeProcessKV_v2(nv_bfloat16);
 #endif
 
 template
-__global__ void __launch_bounds__(128) flattenKV_v2(T*           k,
-                                                    T*           v,
-                                                    const Tkv**  blocks,
-                                                    const int*   cu_k_len,
-                                                    const int*   cu_block_num,
-                                                    const float* rope_base,
-                                                    int          rope_dim,
-                                                    float        rope_ti_scale,
-                                                    float        rope_scaling_factor,
-                                                    float        llama3_inv_scaling_factor,
-                                                    float        llama3_alpha,
-                                                    float        llama3_beta,
-                                                    float        yarn_ramp_inv_factor_div_2,
-                                                    float        yarn_ramp_inv_factor_mul_min,
-                                                    float        yarn_inv_scaling_factor,
-                                                    float        attention_scaling,
-                                                    int64_t      stride_b,
-                                                    int64_t      stride_c,
-                                                    int64_t      stride_h,
-                                                    int64_t      stride_s,
-                                                    int          layer_id,
-                                                    BlockLayout  block_layout)
+__global__ void __launch_bounds__(128) flattenKV_v2(T*          k,
+                                                    T*          v,
+                                                    const Tkv** blocks,
+                                                    const int*  cu_k_len,
+                                                    const int*  cu_block_num,
+                                                    int64_t     stride_b,
+                                                    int64_t     stride_c,
+                                                    int64_t     stride_h,
+                                                    int64_t     stride_s,
+                                                    int         layer_id,
+                                                    BlockLayout block_layout)
 {
 
     constexpr int kVecSize = sizeof(uint4) / sizeof(T);
@@ -462,32 +386,6 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k,
         }
     }
 
-    if (rope_base) {
-        float base = rope_base[batch_idx];
-        PRAGMA_UNROLL
-        for (int c = 0; c < ITER_C; ++c) {
-            const int di = offset.x + c * Map::kDeltaC;
-            FastRoPE rope(di,
-                          rope_dim,
-                          base,
-                          rope_ti_scale,
-                          rope_scaling_factor,
-                          llama3_inv_scaling_factor,
-                          llama3_alpha,
-                          llama3_beta,
-                          yarn_ramp_inv_factor_div_2,
-                          yarn_ramp_inv_factor_mul_min,
-                          yarn_inv_scaling_factor,
-                          attention_scaling,
-                          std::integral_constant{});
-            PRAGMA_UNROLL
-            for (int s = 0; s < ITER_S; ++s) {
-                const int ti = offset.y + s * Map::kDeltaS + token_idx;  // sequence local
-                rope.apply(out_K[s][c], ti);
-            }
-        }
-    }
-
     PRAGMA_UNROLL
     for (int s = 0; s < ITER_S; ++s) {
         PRAGMA_UNROLL
@@ -510,17 +408,6 @@ void invokeFlattenKV_v2(T* k,
                        char**       blocks,
                        const int*   cu_k_len,
                        const int*   cu_block_num,
-                       const float* rope_base,
-                       int          rope_dim,
-                       float        rope_ti_scale,
-                       float        rope_scaling_factor,
-                       float        llama3_inv_scaling_factor,
-                       float        llama3_alpha,
-                       float        llama3_beta,
-                       float        yarn_ramp_inv_factor_div_2,
-                       float        yarn_ramp_inv_factor_mul_min,
-                       float        yarn_inv_scaling_factor,
-                       float        attention_scaling,
                        int64_t      stride_b,
                        int64_t      stride_c,
                        int64_t      stride_h,
@@ -553,17 +440,6 @@ void invokeFlattenKV_v2(T* k,
                                                     (const Tkv**)blocks,
                                                     cu_k_len,
                                                     cu_block_num,
-                                                    rope_base,
-                                                    rope_dim,
-                                                    rope_ti_scale,
-                                                    rope_scaling_factor,
-                                                    llama3_inv_scaling_factor,
-                                                    llama3_alpha,
-                                                    llama3_beta,
-                                                    yarn_ramp_inv_factor_div_2,
-                                                    yarn_ramp_inv_factor_mul_min,
-                                                    yarn_inv_scaling_factor,
-                                                    attention_scaling,
                                                     stride_b,
                                                     stride_c,
                                                     stride_h,
@@ -599,17 +475,6 @@ void invokeFlattenKV_v2(T* k,
                                     char**       blocks,  \
                                     const int*   cu_k_len,  \
                                     const int*   cu_block_num,  \
-                                    const float* rope_base,  \
-                                    int          rope_dim,  \
-                                    float        rope_ti_scale,  \
-                                    float        rope_scaling_factor,  \
-                                    float        llama3_inv_scaling_factor,  \
-                                    float        llama3_alpha,  \
-                                    float        llama3_beta,  \
-                                    float        yarn_ramp_inv_factor_div_2,  \
-                                    float        yarn_ramp_inv_factor_mul_min,  \
-                                    float        yarn_inv_scaling_factor,  \
-                                    float        attention_scaling,  \
                                     int64_t      stride_b,  \
                                     int64_t      stride_c,  \
                                     int64_t      stride_h,  \
diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h
index 408310ba95..0f789fb58a 100644
--- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h
+++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h
@@ -17,17 +17,7 @@ void invokeProcessKV_v2(char** blocks,
                        const int*   cu_k_len,
                        const int*   cu_block_num,
                        const float* cos_sin,
-                       const float* rope_base,
                        int          rope_dim,
-                       float        rope_ti_scale,
-                       float        rope_scaling_factor,
-                       float        llama3_inv_scaling_factor,
-                       float        llama3_alpha,
-                       float        llama3_beta,
-                       float        yarn_ramp_inv_factor_div_2,
-                       float        yarn_ramp_inv_factor_mul_min,
-                       float        yarn_inv_scaling_factor,
-                       float        attention_scaling,
                        int64_t      stride_b,
                        int64_t      stride_c,
                        int64_t      stride_h,
@@ -53,17 +43,7 @@ void invokeProcessKV_v2_(const AttentionParams& params)
                        params.cu_k_len,
                        params.block_iter_params.cu_block_nums,
                        params.cos_sin,
-                       params.rope_theta,
                        params.rotary_embedding_dim,
-                       params.rope_ti_scale,
-                       params.rope_scaling_factor,
-                       params.llama3_inv_scaling_factor,
-                       params.llama3_alpha,
-                       params.llama3_beta,
-                       params.yarn_ramp_inv_factor_div_2,
-                       params.yarn_ramp_inv_factor_mul_min,
-                       params.yarn_inv_scaling_factor,
-                       params.attention_scaling,
                        0,                                      // stride b
                        params.stride / params.size_per_head,   // stride c
                        1,                                      // stride h
@@ -84,17 +64,6 @@ void invokeFlattenKV_v2(T* k,
                        char**       blocks,
                        const int*   cu_k_len,
                        const int*   cu_block_num,
-                       const float* rope_base,
-                       int          rope_dim,
-                       float        rope_ti_scale,
-                       float        rope_scaling_factor,
-                       float        llama3_inv_scaling_factor,
-                       float        llama3_alpha,
-                       float        llama3_beta,
-                       float        yarn_ramp_inv_factor_div_2,
-                       float        yarn_ramp_inv_factor_mul_min,
-                       float        yarn_inv_scaling_factor,
-                       float        attention_scaling,
                        int64_t      stride_b,
                        int64_t      stride_c,
                        int64_t      stride_h,
@@ -118,17 +87,6 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len)
                        (char**)params.block_iter_params.block_ptrs,
                        params.cu_k_len,
                        params.block_iter_params.cu_block_nums,
-                       nullptr,  // params.rope_theta,
-                       params.rotary_embedding_dim,
-                       params.rope_ti_scale,
-                       params.rope_scaling_factor,
-                       params.llama3_inv_scaling_factor,
-                       params.llama3_alpha,
-                       params.llama3_beta,
-                       params.yarn_ramp_inv_factor_div_2,
-                       params.yarn_ramp_inv_factor_mul_min,
-                       params.yarn_inv_scaling_factor,
-                       params.attention_scaling,
                        0,
                        1,
                        2 * sum_k_len,
diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu
index 1d8e511030..48012a4c09 100644
--- a/src/turbomind/kernels/attention/test_attention.cu
+++ b/src/turbomind/kernels/attention/test_attention.cu
@@ -147,17 +147,7 @@ void TestBlocks(const thrust::universal_vector& k_cache,  // [B, H, S,
                        cu_seq_lens.data().get(),
                        cu_block_cnts.data().get(),
                        nullptr,
-                       nullptr,
                        rope_dim,
-                       1.,
-                       0.,
-                       0.,
-                       1.0,
-                       1.0,
-                       0.0,
-                       0.0,
-                       0.0,
-                       1.0,
                        2 * head_num * seq_len,
                        0,
                        seq_len,
@@ -181,17 +171,6 @@ void TestBlocks(const thrust::universal_vector& k_cache,  // [B, H, S,
                        k_ptrs.data().get(),
                        cu_seq_lens.data().get(),
                        cu_block_cnts.data().get(),
-                       nullptr,
-                       rope_dim,
-                       1.,
-                       0.,
-                       0.,
-                       1.0,
-                       1.0,
-                       0.0,
-                       0.0,
-                       0.0,
-                       1.0,
                        2 * head_num * seq_len,
                        0,
                        seq_len,
@@ -436,9 +415,7 @@ int test_attention()
     params.size_per_head = kHeadDim;
     params.inv_sqrt_dh   = (float)std::log2(expf(1.)) / std::sqrt((float)params.size_per_head);
 
-    params.rotary_embedding_dim  = kRoPEDim;
-    params.rotary_embedding_base = kRoPEBase;
-    params.rope_ti_scale         = 1.;
+    params.rotary_embedding_dim = kRoPEDim;
 
     params.split_cnt = split_cnt.data().get();
     params.partial_L = partial_L.data().get();
@@ -545,17 +522,6 @@ int test_attention()
                            k_ptrs.data().get(),
                            cu_kv_lens.data().get(),
                            cu_block_cnts.data().get(),
-                           nullptr,  // DECODING ? nullptr : params.rope_theta,
-                           kRoPEDim,
-                           1.,
-                           0.,
-                           0.,
-                           1.0,
-                           1.0,
-                           0.0,
-                           0.0,
-                           0.0,
-                           1.0,
                            KvHeadNum * kContextLen,
                            0,
                            kContextLen,
diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc
index c80c1cb6b9..9e3b704132 100644
--- a/src/turbomind/models/llama/unified_attention_layer.cc
+++ b/src/turbomind/models/llama/unified_attention_layer.cc
@@ -295,54 +295,11 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa
     // MSVC does not have M_LOG2E
     params.inv_sqrt_dh = (float)std::log2(expf(1.)) / std::sqrt((float)params.size_per_head);
 
+    // rope
     params.rotary_embedding_dim    = param_.rotary_embedding_dim;
-    params.rotary_embedding_base   = param_.rotary_embedding_base;
     params.max_position_embeddings = param_.max_position_embeddings;
-    params.rope_scaling_factor     = param_.rope_scaling_factor;
-    params.attention_scaling       = 1.0;
-    params.rope_ti_scale           = 1.f;
-    if (param_.rope_scaling_type == "linear") {
-        params.rope_ti_scale /= param_.rope_scaling_factor;
-    }
-    if (param_.rope_scaling_type == "llama3") {
-        const double PI = 3.14159265358979323846;
-        float inv_diff_freq_factor = 1.0 / (param_.high_freq_factor - param_.low_freq_factor);
-        params.llama3_inv_scaling_factor = 1.0 / param_.rope_scaling_factor;
-        params.llama3_alpha = param_.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
-        params.llama3_beta  = param_.low_freq_factor * inv_diff_freq_factor;
-    }
-    if (param_.rope_scaling_type == "yarn") {
-        const double PI = 3.14159265358979323846;
-        auto find_correction_dim = [&](float num_rotations) {
-            return (param_.rotary_embedding_dim
-                    * std::log(param_.max_position_embeddings / (num_rotations * 2 * PI)))
-                   / (2 * std::log(param_.rotary_embedding_base));
-        };
-        auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
-            low  = std::floor(find_correction_dim(low_rot));
-            high = std::ceil(find_correction_dim(high_rot));
-            low  = std::max(low, 0.f);
-            high = std::min(high, param_.rotary_embedding_dim - 1.f);
-        };
-        float low, high;
-        find_correction_range(param_.beta_fast, param_.beta_slow, low, high);
-        if (low == high) {
-            high += 0.01f;
-        }
-        params.yarn_ramp_inv_factor_div_2   = 1.0 / (high - low) / 2.0;
-        params.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low;
-        params.yarn_inv_scaling_factor      = (1 - 1.0 / param_.rope_scaling_factor);
-        if (param_.attention_factor < 0) {
-            params.attention_scaling = 0.1 * std::log(param_.rope_scaling_factor) + 1.0;
-        }
-        else {
-            params.attention_scaling = param_.attention_factor;
-        }
-    }
-
-    params.cos_sin = cos_sin;
-
-    params.use_logn_attn = param_.use_logn_attn;
+    params.cos_sin                 = cos_sin;
+    params.use_logn_attn           = param_.use_logn_attn;
 
     // Decoding use only for now
     FT_CHECK(barriers_);
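
Note (added for context, not part of the patch): the patch drops the per-kernel RoPE math (FastRoPE with base, linear/llama3/yarn scaling and attention scaling) from the attention and KV-cache kernels. The kernels now consume only a precomputed `cos_sin` table plus `rotary_embedding_dim`, indexing it as `token_index * rotary_embedding_dim + di` and applying it through `PrecomputeFastRoPE`, with channels guarded by `di < rotary_embedding_dim`. Where and how the table is filled is outside this diff, so the sketch below is only an illustration of one plausible host-side precompute for plain RoPE; the helper name `make_cos_sin` and the interleaved (cos, sin) per-token layout are assumptions, and in the real code any scaling variant would be folded into whatever actually produces `params.cos_sin`.

// Hypothetical sketch, not the repository's implementation: build a [num_tokens, rope_dim]
// float table of interleaved (cos, sin) pairs, matching the `qi * rotary_embedding_dim + di`
// indexing used by the kernels above. Linear/llama3/yarn adjustments are omitted here.
#include <cmath>
#include <vector>

std::vector<float> make_cos_sin(int num_tokens, int rope_dim, float rope_base)
{
    std::vector<float> table(static_cast<size_t>(num_tokens) * rope_dim);
    for (int t = 0; t < num_tokens; ++t) {
        for (int i = 0; i < rope_dim / 2; ++i) {
            // Standard RoPE inverse frequency for channel pair i.
            const float inv_freq = std::pow(rope_base, -2.0f * i / rope_dim);
            const float angle    = static_cast<float>(t) * inv_freq;
            table[static_cast<size_t>(t) * rope_dim + 2 * i]     = std::cos(angle);
            table[static_cast<size_t>(t) * rope_dim + 2 * i + 1] = std::sin(angle);
        }
    }
    return table;  // would be copied to device memory and passed as params.cos_sin (assumption)
}

Precomputing the table once on the host keeps the per-channel work in the kernels to a couple of fused multiply-adds and removes a dozen scalar parameters from every kernel and launcher signature, which appears to be the motivation for this change.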