diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h
index b6dfaa596c..f0f2e1af46 100644
--- a/src/turbomind/kernels/attention/attention_params.h
+++ b/src/turbomind/kernels/attention/attention_params.h
@@ -56,6 +56,7 @@ struct AttentionParams {
     int   size_per_head;
     float inv_sqrt_dh;
+    float* cos_sin;

     // rotary embedding
     int   rotary_embedding_dim;
     float rotary_embedding_base;
diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h
index 5fb583bd1f..5fbeb75819 100644
--- a/src/turbomind/kernels/attention/attention_universal.h
+++ b/src/turbomind/kernels/attention/attention_universal.h
@@ -194,6 +194,8 @@ struct AttentionUniversal {
         Vec vec_K[1][ITER_C];
         Vec vec_V[1][ITER_C];

+        Array<float, kVecSize> vec_cs[ITER_S][ITER_C];  // precomputed cos sin
+
         const int2 offset = Map::get_offset(warp_id, lane_id);

         // Load Q
@@ -217,12 +219,38 @@ struct AttentionUniversal {
                             Ldg(vec_V[0][c], &params.v[k_idx]);
                         }
                     }
+                    if (params.cos_sin) {
+                        float*        cos_sin = params.cos_sin;
+                        const int64_t index   = qi * kHeadDim + di;
+                        PRAGMA_UNROLL
+                        for (int k = 0; k < kVecSize; k += 4) {
+                            (float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]);
+                        }
+                    }
                 }
             }
         }

         ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset);

+        if (params.cos_sin) {
+            PrecomputeFastRoPE rope{};
+            PRAGMA_UNROLL
+            for (int c = 0; c < ITER_C; ++c) {
+                const int di = offset.x + c * Map::kDeltaC;
+                PRAGMA_UNROLL
+                for (int s = 0; s < ITER_S; ++s) {
+                    rope.apply(vec_Q[s][c], vec_cs[s][c]);
+                    if constexpr (kProcessKV) {
+                        if (s == 0) {
+                            rope.apply(vec_K[0][c], vec_cs[s][c]);
+                        }
+                    }
+                }
+            }
+        }
+
+#if 0
         const float rope_base = params.rope_theta ? params.rope_theta[batch_idx] : params.rotary_embedding_base;
         PRAGMA_UNROLL
         for (int c = 0; c < ITER_C; ++c) {
@@ -251,6 +279,7 @@ struct AttentionUniversal {
                 }
             }
         }
+#endif

         if (params.use_logn_attn) {
             PRAGMA_UNROLL
diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
index 20bb00fde8..d552a63801 100644
--- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
+++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
@@ -20,6 +20,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
                                                     const int*   cu_q_len,
                                                     const int*   cu_k_len,
                                                     const int*   cu_block_num,
+                                                    const float* cos_sin,
                                                     const float* rope_base,
                                                     int          rope_dim,
                                                     float        rope_ti_scale,
@@ -124,6 +125,35 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
         }
     }

+    if (cos_sin) {
+        Array<float, kVecSize> vec_cs[ITER_S][ITER_C];
+        PRAGMA_UNROLL
+        for (int s = 0; s < ITER_S; ++s) {
+            const int qi = offset.y + s * Map::kDeltaS + token_idx;
+            PRAGMA_UNROLL
+            for (int c = 0; c < ITER_C; ++c) {
+                const int     di    = offset.x + c * Map::kDeltaC;
+                const int64_t index = (qi_beg + qi) * HeadDim + di;
+                PRAGMA_UNROLL
+                for (int k = 0; k < kVecSize; k += 4) {
+                    if (qi < q_len) {
+                        (float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]);
+                    }
+                }
+            }
+        }
+
+        PrecomputeFastRoPE rope;
+        PRAGMA_UNROLL
+        for (int s = 0; s < ITER_S; ++s) {
+            PRAGMA_UNROLL
+            for (int c = 0; c < ITER_C; ++c) {
+                rope.apply(vec_K[s][c], vec_cs[s][c]);
+            }
+        }
+    }
+
+#if 0
     if (rope_base) {
         float base = rope_base[batch_idx];
         PRAGMA_UNROLL
@@ -149,6 +179,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
             }
         }
     }
+#endif

     Array<T, 2> param_K[ITER_S];
     Array<T, 2> param_V[ITER_S];
@@ -211,6 +242,7 @@ void invokeProcessKV_v2(char** blocks,
                         const int*   cu_q_len,
                         const int*   cu_k_len,
                        const int*   cu_block_num,
+                        const float* cos_sin,
                        const float* rope_base,
                        int          rope_dim,
                        float        rope_ti_scale,
@@ -257,6 +289,7 @@ void invokeProcessKV_v2(char** blocks,
                   cu_q_len,
                   cu_k_len,
                   cu_block_num,
+                  cos_sin,
                   rope_base,
                   rope_dim,
                   rope_ti_scale,
@@ -306,6 +339,7 @@ void invokeProcessKV_v2(char** blocks,
                             const int*   cu_q_len,      \
                             const int*   cu_k_len,      \
                             const int*   cu_block_num,  \
+                            const float* cos_sin,       \
                             const float* rope_base,     \
                             int          rope_dim,      \
                             float        rope_ti_scale, \
diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h
index fe45ad7be7..408310ba95 100644
--- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h
+++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h
@@ -16,6 +16,7 @@ void invokeProcessKV_v2(char** blocks,
                         const int*   cu_q_len,
                         const int*   cu_k_len,
                         const int*   cu_block_num,
+                        const float* cos_sin,
                         const float* rope_base,
                         int          rope_dim,
                         float        rope_ti_scale,
@@ -51,6 +52,7 @@ void invokeProcessKV_v2_(const AttentionParams<T>& params)
                        params.cu_q_len,
                        params.cu_k_len,
                        params.block_iter_params.cu_block_nums,
+                       params.cos_sin,
                        params.rope_theta,
                        params.rotary_embedding_dim,
                        params.rope_ti_scale,
diff --git a/src/turbomind/kernels/attention/rotary_embedding.h b/src/turbomind/kernels/attention/rotary_embedding.h
index 8e09da22cd..8490275081 100644
--- a/src/turbomind/kernels/attention/rotary_embedding.h
+++ b/src/turbomind/kernels/attention/rotary_embedding.h
@@ -67,6 +67,21 @@ __device__ void ApplyRotaryEmbedding(Array<T, N>& x, float base, int dims, int t
     }
 }

+struct PrecomputeFastRoPE {
+
+    template<typename T, int N>
+    __device__ void apply(Array<T, N>& x, Array<float, N>& cs)
+    {
+        PRAGMA_UNROLL
+        for (int i = 0; i < N; i += 2) {
+            float tmp0 = cs[i] * (float)x[i] - cs[i + 1] * (float)x[i + 1];
+            float tmp1 = cs[i] * (float)x[i + 1] + cs[i + 1] * (float)x[i];
+            x[i]       = (T)tmp0;
+            x[i + 1]   = (T)tmp1;
+        }
+    }
+};
+
 template
 struct FastRoPE {
diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu
index c6d7b40637..1d8e511030 100644
--- a/src/turbomind/kernels/attention/test_attention.cu
+++ b/src/turbomind/kernels/attention/test_attention.cu
@@ -147,6 +147,7 @@ void TestBlocks(const thrust::universal_vector<T>& k_cache,  // [B, H, S,
                        cu_seq_lens.data().get(),
                        cu_block_cnts.data().get(),
                        nullptr,
+                       nullptr,
                        rope_dim,
                        1.,
                        0.,
diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index 285fcea31f..10d3d36b0d 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -20,6 +20,7 @@ add_library(Llama STATIC
         unified_attention_layer.cc
         llama_kernels.cu
         llama_decoder_kernels.cu
+        rotary_emb.cu
         llama_utils.cu)
 set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE  ON)
 set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS  ON)
diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu
new file mode 100644
index 0000000000..2ecec40a79
--- /dev/null
+++ b/src/turbomind/models/llama/rotary_emb.cu
@@ -0,0 +1,226 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+#include "src/turbomind/models/llama/rotary_emb.h"
+#include <map>
+
+namespace turbomind {
+
+__device__ int get_batch_id(int qi, int* q_len, int batch_size)
+{
+    int result{};
+    int end = (batch_size + blockDim.x - 1) / blockDim.x * blockDim.x;
+    for (int i = threadIdx.x; i < end; i += blockDim.x) {
+        int prefix_sum = (i < batch_size) ? q_len[i + 1] : q_len[batch_size];
+        auto count = __syncthreads_count(prefix_sum > qi);
+        if (count != 0) {
+            result = i / blockDim.x * blockDim.x + blockDim.x - count + 1;
+            break;
+        }
+    }
+    return result;
+}
+
+__inline__ __device__ float compute_default_parameters(float base, float dim, int di, float factor)
+{
+    float scale_factor = -log2f(base) / dim;
+    float inv_freq = exp2f(di * scale_factor) * factor;
+    return inv_freq;
+}
+
+__global__ void computeCosSinDefault(const float* rope_base,
+                                     int* q_len,
+                                     int* k_len,
+                                     int token_num,
+                                     int batch_size,
+                                     int dim,
+                                     float factor,
+                                     float* cos_sin)
+{
+    int qi = blockIdx.x;
+    int di = threadIdx.x;
+
+    int bid = get_batch_id(qi, q_len, batch_size);
+    int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]);
+    float base = rope_base[bid - 1];
+    float ti = history_len + qi - q_len[bid - 1];
+
+    float inv_freq = compute_default_parameters(base, dim, di * 2, factor);
+    float c, s;
+    sincosf(ti * inv_freq, &s, &c);
+    (float2&)cos_sin[dim * qi + 2 * di] = {c, s};
+}
+
+__global__ void computeCosSinLlama3(const float* rope_base,
+                                    int* q_len,
+                                    int* k_len,
+                                    int token_num,
+                                    int batch_size,
+                                    int dim,
+                                    float llama3_inv_scaling_factor,
+                                    float llama3_alpha,
+                                    float llama3_beta,
+                                    float* cos_sin)
+{
+    int qi = blockIdx.x;
+    int di = threadIdx.x;
+
+    int bid = get_batch_id(qi, q_len, batch_size);
+    int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]);
+    float base = rope_base[bid - 1];
+    float ti = history_len + qi - q_len[bid - 1];
+
+    float inv_freq = compute_default_parameters(base, dim, di * 2, 1.0f);
+    auto smooth = fmaxf(0.f, fminf(1.f, llama3_alpha * inv_freq - llama3_beta));
+    inv_freq = (1 - smooth) * inv_freq * llama3_inv_scaling_factor + smooth * inv_freq;
+    float c, s;
+    sincosf(ti * inv_freq, &s, &c);
+    (float2&)cos_sin[dim * qi + 2 * di] = {c, s};
+}
+
+__global__ void computeCosSinYarn(const float* rope_base,
+                                  int* q_len,
+                                  int* k_len,
+                                  int token_num,
+                                  int batch_size,
+                                  int dim,
+                                  float yarn_ramp_inv_factor_div_2,
+                                  float yarn_ramp_inv_factor_mul_min,
+                                  float yarn_inv_scaling_factor,
+                                  float attention_scaling,
+                                  float* cos_sin)
+{
+    int qi = blockIdx.x;
+    int di = threadIdx.x;
+
+    int bid = get_batch_id(qi, q_len, batch_size);
+    int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]);
+    float base = rope_base[bid - 1];
+    float ti = history_len + qi - q_len[bid - 1];
+
+    float inv_freq = compute_default_parameters(base, dim, di * 2, 1.0f);
+    float alpha = 2 * di * yarn_ramp_inv_factor_div_2 - yarn_ramp_inv_factor_mul_min;
+    alpha = fmaxf(0.f, fminf(1.f, alpha));
+    inv_freq = inv_freq - inv_freq * alpha * yarn_inv_scaling_factor;
+
+    float c, s;
+    sincosf(ti * inv_freq, &s, &c);
+    c *= attention_scaling;
+    s *= attention_scaling;
+    (float2&)cos_sin[dim * qi + 2 * di] = {c, s};
+}
+
+RotaryScalingType GetRoPEType(const std::string& type)
+{
+    std::map<std::string, RotaryScalingType> lookup = {{"", RotaryScalingType::kDefault},
+                                                       {"linear", RotaryScalingType::kLinear},
+                                                       {"dynamic", RotaryScalingType::kDynamic},
+                                                       {"yarn", RotaryScalingType::kYarn},
+                                                       {"llama3", RotaryScalingType::kLlama3},
+                                                       {"mrope", RotaryScalingType::kMrope}};
+    return lookup.at(type);
+}
+
+void RotaryEmbeddingV2::freeBuffer()
+{
+    allocator_->free((void**)&cos_sin_);
+}
+
+void RotaryEmbeddingV2::allocateBuffer(size_t token_num)
+{
+    cos_sin_ = (float*)allocator_->reMalloc(cos_sin_, sizeof(float) * token_num * dim_);
+}
+
+RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator):
+    stream_(stream), allocator_(allocator)
+{
+    type_ = GetRoPEType(param.rope_scaling_type);
+    dim_ = param.rotary_embedding_dim;
+    rope_scaling_factor_ = 1.0f;
+    attention_factor_ = 1.0f;
+
+    if (type_ == RotaryScalingType::kLinear) {
+        rope_scaling_factor_ /= param.rope_scaling_factor;
+    }
+    else if (type_ == RotaryScalingType::kLlama3) {
+        const double PI = 3.14159265358979323846;
+        float inv_diff_freq_factor = 1.0 / (param.high_freq_factor - param.low_freq_factor);
+        llama3_inv_scaling_factor_ = 1.0 / param.rope_scaling_factor;
+        llama3_alpha_ = param.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
+        llama3_beta_ = param.low_freq_factor * inv_diff_freq_factor;
+    }
+    else if (type_ == RotaryScalingType::kYarn) {
+        const double PI = 3.14159265358979323846;
+        auto find_correction_dim = [&](float num_rotations) {
+            return (param.rotary_embedding_dim * std::log(param.max_position_embeddings / (num_rotations * 2 * PI)))
+                   / (2 * std::log(param.rotary_embedding_base));
+        };
+        auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
+            low = std::floor(find_correction_dim(low_rot));
+            high = std::ceil(find_correction_dim(high_rot));
+            low = std::max(low, 0.f);
+            high = std::min(high, param.rotary_embedding_dim - 1.f);
+        };
+        float low, high;
+        find_correction_range(param.beta_fast, param.beta_slow, low, high);
+        if (low == high) {
+            high += 0.01f;
+        }
+        yarn_ramp_inv_factor_div_2_ = 1.0 / (high - low) / 2.0;
+        yarn_ramp_inv_factor_mul_min_ = 1.0 / (high - low) * low;
+        yarn_inv_scaling_factor_ = (1 - 1.0 / param.rope_scaling_factor);
+        attention_factor_ = param.attention_factor;
+    }
+}
+
+void RotaryEmbeddingV2::forward(const RotaryEmbeddingV2Params& params)
+{
+    allocateBuffer(params.token_num);
+
+    const int grid = params.token_num;
+    const int block = dim_ / 2;
+
+    switch (type_) {
+        case RotaryScalingType::kDefault:
+        case RotaryScalingType::kLinear:
+        case RotaryScalingType::kDynamic:
+            computeCosSinDefault<<<grid, block, 0, stream_>>>(params.rope_theta,
+                                                              params.q_len,
+                                                              params.k_ken,
+                                                              params.token_num,
+                                                              params.batch_size,
+                                                              dim_,
+                                                              rope_scaling_factor_,
+                                                              cos_sin_);
+            break;
+        case RotaryScalingType::kLlama3:
+            computeCosSinLlama3<<<grid, block, 0, stream_>>>(params.rope_theta,
+                                                             params.q_len,
+                                                             params.k_ken,
+                                                             params.token_num,
+                                                             params.batch_size,
+                                                             dim_,
+                                                             llama3_inv_scaling_factor_,
+                                                             llama3_alpha_,
+                                                             llama3_beta_,
+                                                             cos_sin_);
+            break;
+        case RotaryScalingType::kYarn:
+            computeCosSinYarn<<<grid, block, 0, stream_>>>(params.rope_theta,
+                                                           params.q_len,
+                                                           params.k_ken,
+                                                           params.token_num,
+                                                           params.batch_size,
+                                                           dim_,
+                                                           yarn_ramp_inv_factor_div_2_,
+                                                           yarn_ramp_inv_factor_mul_min_,
+                                                           yarn_inv_scaling_factor_,
+                                                           attention_factor_,
+                                                           cos_sin_);
+            break;
+        case RotaryScalingType::kMrope:
+            FT_CHECK(0);
+        default:
+            FT_CHECK(0);
+    }
+}
+
+}  // namespace turbomind
diff --git a/src/turbomind/models/llama/rotary_emb.h b/src/turbomind/models/llama/rotary_emb.h
new file mode 100644
index 0000000000..ffe81752e4
--- /dev/null
+++ b/src/turbomind/models/llama/rotary_emb.h
@@ -0,0 +1,65 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+#pragma once
+#include "src/turbomind/models/llama/llama_params.h"
+#include "src/turbomind/utils/allocator.h"
+
+namespace turbomind {
+
+enum class RotaryScalingType
+{
+    kDefault,
+    kLinear,
+    kDynamic,
+    kYarn,
+    kLlama3,
+    kMrope
+};
+
+struct RotaryEmbeddingV2Params {
+    float* rope_theta;
+    int*   q_len;
+    int*   k_ken;
+    int    batch_size;
+    int    token_num;
+};
+
+struct RotaryEmbeddingV2 {
+
+    RotaryEmbeddingV2(const AttentionParam& param, cudaStream_t stream, IAllocator* allocator);
+
+    void freeBuffer();
+
+    void allocateBuffer(size_t token_num);
+
+    ~RotaryEmbeddingV2()
+    {
+        freeBuffer();
+    }
+
+    void forward(const RotaryEmbeddingV2Params& params);
+
+    RotaryScalingType  type_;
+    cudaStream_t const stream_;
+    IAllocator* const  allocator_;
+
+    // output
+    float* cos_sin_;  // num_token x dim, (cos, sin, ...)
+
+    int dim_;
+    // default, linear, dynamic
+    float attention_factor_;
+    float rope_scaling_factor_;
+    float inv_scale_factor_;
+    // llama3
+    float llama3_inv_scaling_factor_;
+    float llama3_alpha_;
+    float llama3_beta_;
+    // yarn
+    float yarn_ramp_inv_factor_div_2_;
+    float yarn_ramp_inv_factor_mul_min_;
+    float yarn_inv_scaling_factor_;
+    // mrope
+    int3 mrope_section_;
+};
+
+};  // namespace turbomind
diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc
index 2f99b0c2ce..c80c1cb6b9 100644
--- a/src/turbomind/models/llama/unified_attention_layer.cc
+++ b/src/turbomind/models/llama/unified_attention_layer.cc
@@ -187,6 +187,8 @@ inline void UnifiedAttentionLayer<T>::forward(TensorMap* outputs, const TensorMa
     bool*  is_finished = inputs->getPtr<bool>("finished");
     float* rope_theta  = inputs->getPtr<float>("rope_theta");

+    float* cos_sin = inputs->at("cos_sin", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr<float>();
+
     void** block_ptrs     = outputs->getPtr<void*>("block_ptrs");
     int*   cu_block_count = inputs->getPtr<int>("cu_block_counts");

@@ -338,6 +340,8 @@ inline void UnifiedAttentionLayer<T>::forward(TensorMap* outputs, const TensorMa
         }
     }

+    params.cos_sin = cos_sin;
+
     params.use_logn_attn = param_.use_logn_attn;

     // Decoding use only for now
diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc
index 68392215f6..cc37743878 100644
--- a/src/turbomind/models/llama/unified_decoder.cc
+++ b/src/turbomind/models/llama/unified_decoder.cc
@@ -29,6 +29,7 @@ UnifiedDecoder<T>::UnifiedDecoder(const ModelParam& model,
     attn_layer_    = std::make_unique<UnifiedAttentionLayer<T>>(model, attn, lora, tp, ctx);
     ffn_layer_     = std::make_unique<LlamaFfnLayer<T>>(model, tp, ctx, true);
     moe_ffn_layer_ = std::make_unique<MoeFfnLayer<T>>(model, moe, tp, ctx);
+    rotary_emb_    = std::make_unique<RotaryEmbeddingV2>(attn, ctx.stream, ctx.allocator.get());

     check_cuda_error(cudaEventCreateWithFlags(&ev_h_cu_x_, cudaEventDisableTiming));
 }
@@ -75,6 +76,11 @@ void UnifiedDecoder<T>::forwardSelfAttn(T* attn_io,
     inputs.insert("h_cu_q_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_q_len_});
     inputs.insert("h_cu_k_len", {MEMORY_CPU, TYPE_INT32, {batch_size + 1}, h_cu_k_len_});

+    if (rotary_emb_) {
+        inputs.insert("cos_sin",
+                      {MEMORY_GPU, TYPE_FP32, {token_num, (size_t)rotary_emb_->dim_}, rotary_emb_->cos_sin_});
+    }
+
     TensorMap outputs(*_outputs);
     outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io});

@@ -152,6 +158,16 @@ void UnifiedDecoder<T>::forward(TensorMap* outputs, const TensorMap* inputs, con

     count_and_fix(decoder_output, token_num * hidden_units_, Concat("norm0", 0), 2);

+    {
+        RotaryEmbeddingV2Params params;
+        params.rope_theta = inputs->getPtr<float>("rope_theta");
+        params.q_len      = cu_q_len_;
+        params.k_ken      = cu_k_len_;
+        params.batch_size = batch_size;
+        params.token_num  = token_num;
+        rotary_emb_->forward(params);
+    }
+
     for (size_t layer = 0; layer < layer_num_; ++layer) {

         /// TODO: do not skip the layers when they are heterogeneous
diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h
index f13b4ba842..e18da41c38 100644
--- a/src/turbomind/models/llama/unified_decoder.h
+++ b/src/turbomind/models/llama/unified_decoder.h
@@ -5,6 +5,7 @@
 #include "src/turbomind/models/llama/context.h"
 #include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/models/llama/moe_ffn_layer.h"
+#include "src/turbomind/models/llama/rotary_emb.h"
 #include "src/turbomind/models/llama/unified_attention_layer.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
 #include "src/turbomind/utils/cuda_utils.h"
@@ -34,6 +35,7 @@ class UnifiedDecoder {
     std::unique_ptr<UnifiedAttentionLayer<T>> attn_layer_;
     std::unique_ptr<LlamaFfnLayer<T>>         ffn_layer_;
     std::unique_ptr<MoeFfnLayer<T>>           moe_ffn_layer_;
+    std::unique_ptr<RotaryEmbeddingV2>        rotary_emb_;

     cudaEvent_t ev_h_cu_x_{};
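
Note on the patch: the per-token rotary factors are now computed once per forward pass (the computeCosSin* kernels in rotary_emb.cu fill an interleaved [token_num, dim] table of (cos, sin) pairs), and the attention / KV-cache kernels only apply a pairwise rotation through PrecomputeFastRoPE instead of re-deriving base^(-2i/dim) per thread. The standalone C++ sketch below mirrors that math on the host for a single token. It is illustrative only and not part of the patch; dim, base and the position ti are made-up values.

// Host-side sketch (assumed values): build one token's (cos, sin) row the way
// computeCosSinDefault does, then rotate a vector with the same pairwise
// formula as PrecomputeFastRoPE::apply.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int   dim  = 8;        // rotary_embedding_dim (illustrative)
    const float base = 10000.f;  // rotary_embedding_base (illustrative)
    const int   ti   = 5;        // absolute position = history_len + local index

    // cos_sin holds dim floats per token, interleaved as (cos0, sin0, cos1, sin1, ...).
    std::vector<float> cos_sin(dim);
    for (int di = 0; di < dim / 2; ++di) {
        // matches compute_default_parameters(base, dim, 2 * di, 1.f): base^(-2*di/dim)
        float inv_freq      = std::exp2(2 * di * (-std::log2(base) / dim));
        cos_sin[2 * di]     = std::cos(ti * inv_freq);
        cos_sin[2 * di + 1] = std::sin(ti * inv_freq);
    }

    // Rotate a dummy query vector in place, two lanes at a time.
    std::vector<float> q(dim, 1.f);
    for (int i = 0; i < dim; i += 2) {
        float c = cos_sin[i], s = cos_sin[i + 1];
        float x0 = q[i], x1 = q[i + 1];
        q[i]     = c * x0 - s * x1;
        q[i + 1] = c * x1 + s * x0;
    }

    for (int i = 0; i < dim; ++i) {
        std::printf("%f\n", q[i]);
    }
    return 0;
}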
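get_batch_id in rotary_emb.cu maps a flattened token index onto its request by scanning the cumulative query-length array cooperatively with __syncthreads_count. A host-side equivalent of that lookup, sketched below with std::upper_bound, may help when reading the bid / bid - 1 indexing in the kernels; the lengths here are invented for illustration and this is not code from the patch.

// Host-side sketch: which request does flattened token qi belong to,
// given cumulative query lengths cu_q_len = {0, len0, len0+len1, ...}?
#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> cu_q_len = {0, 3, 7, 12};  // three requests of lengths 3, 4, 5
    for (int qi : {0, 2, 3, 11}) {
        // Index of the first entry strictly greater than qi; this is the 1-based
        // batch id the kernels use before reading q_len[bid - 1] / rope_base[bid - 1].
        int bid = int(std::upper_bound(cu_q_len.begin(), cu_q_len.end(), qi) - cu_q_len.begin());
        std::printf("token %2d -> batch id %d\n", qi, bid);
    }
    return 0;
}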