
Commit

use precomputed cos sin
irexyc committed Nov 25, 2024
1 parent 67a8538 commit 05d011c
Showing 12 changed files with 396 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/turbomind/kernels/attention/attention_params.h
@@ -56,6 +56,7 @@ struct AttentionParams {
int size_per_head;
float inv_sqrt_dh;

float* cos_sin;
// rotary embedding
int rotary_embedding_dim;
float rotary_embedding_base;
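An aside on the new field, inferred from how the kernels below index it (the commit itself does not document the layout): the buffer appears to hold one row of size_per_head floats per flattened token, with cos and sin interleaved per rotated channel pair. A minimal host-side sketch of that layout; cosSinAt and its parameter names are hypothetical, not repository code:

    #include <cstddef>
    #include <utility>

    // Hypothetical accessor matching the indexing used in the kernels of this commit:
    // attention_universal.h reads cos_sin[qi * kHeadDim + di] and
    // kv_cache_utils_v2.cu reads cos_sin[(qi_beg + qi) * HeadDim + di],
    // with cs[2*i] = cos and cs[2*i + 1] = sin for the channel pair (2*i, 2*i + 1).
    inline std::pair<float, float> cosSinAt(const float* cos_sin, std::size_t token, int head_dim, int pair)
    {
        const float* row = cos_sin + token * head_dim;  // one row per flattened token
        return {row[2 * pair], row[2 * pair + 1]};      // (cos, sin) for that pair
    }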
29 changes: 29 additions & 0 deletions src/turbomind/kernels/attention/attention_universal.h
@@ -194,6 +194,8 @@ struct AttentionUniversal {
Vec vec_K[1][ITER_C];
Vec vec_V[1][ITER_C];

Array<float, kVecSize> vec_cs[ITER_S][ITER_C]; // precomputed cos sin

const int2 offset = Map::get_offset(warp_id, lane_id);

// Load Q
@@ -217,12 +219,38 @@
Ldg(vec_V[0][c], &params.v[k_idx]);
}
}
if (params.cos_sin) {
float* cos_sin = params.cos_sin;
const int64_t index = qi * kHeadDim + di;
PRAGMA_UNROLL
for (int k = 0; k < kVecSize; k += 4) {
(float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]);
}
}
}
}
}

ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset);

if (params.cos_sin) {
PrecomputeFastRoPE rope{};
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
const int di = offset.x + c * Map::kDeltaC;
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
rope.apply(vec_Q[s][c], vec_cs[s][c]);
if constexpr (kProcessKV) {
if (s == 0) {
rope.apply(vec_K[0][c], vec_cs[s][c]);
}
}
}
}
}

#if 0
const float rope_base = params.rope_theta ? params.rope_theta[batch_idx] : params.rotary_embedding_base;
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
@@ -251,6 +279,7 @@ struct AttentionUniversal {
}
}
}
#endif

if (params.use_logn_attn) {
PRAGMA_UNROLL
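The hunks above replace the in-kernel angle computation (now under #if 0) with a table lookup: each (s, c) tile loads its cos/sin values in float4 chunks, and PrecomputeFastRoPE applies them to Q, and to K at s == 0 when kProcessKV holds. Below is a toy standalone kernel with the same load-then-rotate shape; it is not the commit's code, and it assumes chunks of 4 floats and a 16-byte-aligned table:

    #include <cmath>
    #include <cstdio>
    #include <cuda_runtime.h>

    // Toy kernel: each thread rotates 4 channels (= two cos/sin pairs) of `x` in
    // place, fetching the precomputed table with a single float4 load, mirroring
    // the vectorized __ldg pattern above.
    __global__ void rotate_with_table(float* x, const float* cos_sin)
    {
        const int    di = threadIdx.x * 4;                     // first of 4 consecutive channels
        const float4 cs = __ldg((const float4*)&cos_sin[di]);  // (cos0, sin0, cos1, sin1)
        const float4 v  = *(const float4*)&x[di];
        const float4 r  = {cs.x * v.x - cs.y * v.y, cs.x * v.y + cs.y * v.x,
                           cs.z * v.z - cs.w * v.w, cs.z * v.w + cs.w * v.z};
        *(float4*)&x[di] = r;
    }

    int main()
    {
        const int head_dim = 8;  // one token, 8 channels = 4 rotated pairs
        float *x, *cs;
        cudaMallocManaged(&x, head_dim * sizeof(float));
        cudaMallocManaged(&cs, head_dim * sizeof(float));
        for (int i = 0; i < head_dim; ++i) {
            x[i] = 1.f;
        }
        for (int i = 0; i < head_dim / 2; ++i) {  // position 1, base 10000 (assumed standard RoPE angles)
            const float angle = std::pow(10000.f, -2.f * i / head_dim);
            cs[2 * i]     = std::cos(angle);
            cs[2 * i + 1] = std::sin(angle);
        }
        rotate_with_table<<<1, head_dim / 4>>>(x, cs);
        cudaDeviceSynchronize();
        for (int i = 0; i < head_dim; ++i) {
            std::printf("x[%d] = %f\n", i, x[i]);
        }
        return 0;
    }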
34 changes: 34 additions & 0 deletions src/turbomind/kernels/attention/kv_cache_utils_v2.cu
@@ -20,6 +20,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
const int* cu_q_len,
const int* cu_k_len,
const int* cu_block_num,
const float* cos_sin,
const float* rope_base,
int rope_dim,
float rope_ti_scale,
@@ -124,6 +125,35 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
}
}

if (cos_sin) {
Array<float, kVecSize> vec_cs[ITER_S][ITER_C];
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
const int qi = offset.y + s * Map::kDeltaS + token_idx;
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
const int di = offset.x + c * Map::kDeltaC;
const int64_t index = (qi_beg + qi) * HeadDim + di;
PRAGMA_UNROLL
for (int k = 0; k < kVecSize; k += 4) {
if (qi < q_len) {
(float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]);
}
}
}
}

PrecomputeFastRoPE rope;
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
rope.apply(vec_K[s][c], vec_cs[s][c]);
}
}
}

#if 0
if (rope_base) {
float base = rope_base[batch_idx];
PRAGMA_UNROLL
@@ -149,6 +179,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
}
}
}
#endif

Array<T, 2> param_K[ITER_S];
Array<T, 2> param_V[ITER_S];
@@ -211,6 +242,7 @@ void invokeProcessKV_v2(char** blocks,
const int* cu_q_len,
const int* cu_k_len,
const int* cu_block_num,
const float* cos_sin,
const float* rope_base,
int rope_dim,
float rope_ti_scale,
@@ -257,6 +289,7 @@ void invokeProcessKV_v2(char** blocks,
cu_q_len,
cu_k_len,
cu_block_num,
cos_sin,
rope_base,
rope_dim,
rope_ti_scale,
@@ -306,6 +339,7 @@ void invokeProcessKV_v2(char** blocks,
const int* cu_q_len, \
const int* cu_k_len, \
const int* cu_block_num, \
const float* cos_sin, \
const float* rope_base, \
int rope_dim, \
float rope_ti_scale, \
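For orientation (not part of the commit): ProcessKV_v2 works on a ragged batch described by the cu_q_len / cu_k_len prefix sums, and qi_beg is presumably this batch's offset into the flattened token dimension (its definition is not visible in these hunks), so (qi_beg + qi) * HeadDim + di selects the same cos_sin row that the attention kernel reads for that token. A small host-side illustration of that indexing, with made-up lengths:

    #include <cstdio>

    int main()
    {
        // Hypothetical ragged batch: 3 sequences with 3, 4 and 5 query tokens.
        const int cu_q_len[] = {0, 3, 7, 12};  // exclusive prefix sums of the q lengths
        const int head_dim   = 128;
        for (int b = 0; b < 3; ++b) {
            const int qi_beg = cu_q_len[b];                  // first flattened token of batch b
            const int q_len  = cu_q_len[b + 1] - cu_q_len[b];
            for (int qi = 0; qi < q_len; ++qi) {
                // First element of this token's cos/sin row, as indexed in ProcessKV_v2.
                const long index = (long)(qi_beg + qi) * head_dim;
                std::printf("batch %d, token %d -> cos_sin row offset %ld\n", b, qi, index);
            }
        }
        return 0;
    }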
2 changes: 2 additions & 0 deletions src/turbomind/kernels/attention/kv_cache_utils_v2.h
@@ -16,6 +16,7 @@ void invokeProcessKV_v2(char** blocks,
const int* cu_q_len,
const int* cu_k_len,
const int* cu_block_num,
const float* cos_sin,
const float* rope_base,
int rope_dim,
float rope_ti_scale,
@@ -51,6 +52,7 @@ void invokeProcessKV_v2_(const AttentionParams<T>& params)
params.cu_q_len,
params.cu_k_len,
params.block_iter_params.cu_block_nums,
params.cos_sin,
params.rope_theta,
params.rotary_embedding_dim,
params.rope_ti_scale,
15 changes: 15 additions & 0 deletions src/turbomind/kernels/attention/rotary_embedding.h
@@ -67,6 +67,21 @@ __device__ void ApplyRotaryEmbedding(Array<T, 4>& x, float base, int dims, int t
}
}

struct PrecomputeFastRoPE {

template<typename T, int N>
__device__ void apply(Array<T, N>& x, Array<float, N>& cs)
{
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
float tmp0 = cs[i] * (float)x[i] - cs[i + 1] * (float)x[i + 1];
float tmp1 = cs[i] * (float)x[i + 1] + cs[i + 1] * (float)x[i];
x[i] = (T)tmp0;
x[i + 1] = (T)tmp1;
}
}
};

template<class D, int N>
struct FastRoPE {

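PrecomputeFastRoPE appears to perform the same pairwise rotation as the existing FastRoPE above, just with cos/sin read from the table instead of derived from base and dims: the update is the complex rotation (x[i] + j·x[i+1])·(cos θ + j·sin θ), with cs[i] = cos θ and cs[i+1] = sin θ. A tiny host-side check of that identity (made-up values, not repository code):

    #include <cmath>
    #include <complex>
    #include <cstdio>

    int main()
    {
        const float theta = 0.3f;
        const float x[2]  = {1.5f, -2.0f};                         // one channel pair
        const float cs[2] = {std::cos(theta), std::sin(theta)};    // precomputed (cos, sin)

        // Same arithmetic as PrecomputeFastRoPE::apply for a single pair.
        const float tmp0 = cs[0] * x[0] - cs[1] * x[1];
        const float tmp1 = cs[0] * x[1] + cs[1] * x[0];

        // Reference: rotate (x0 + j*x1) by e^{j*theta}.
        const std::complex<float> ref = std::complex<float>(x[0], x[1]) * std::polar(1.f, theta);

        std::printf("kernel math: (%f, %f)\n", tmp0, tmp1);
        std::printf("reference  : (%f, %f)\n", ref.real(), ref.imag());
        return 0;
    }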
1 change: 1 addition & 0 deletions src/turbomind/kernels/attention/test_attention.cu
@@ -147,6 +147,7 @@ void TestBlocks(const thrust::universal_vector<T>& k_cache, // [B, H, S,
cu_seq_lens.data().get(),
cu_block_cnts.data().get(),
nullptr,
nullptr,
rope_dim,
1.,
0.,
1 change: 1 addition & 0 deletions src/turbomind/models/llama/CMakeLists.txt
@@ -20,6 +20,7 @@ add_library(Llama STATIC
unified_attention_layer.cc
llama_kernels.cu
llama_decoder_kernels.cu
rotary_emb.cu
llama_utils.cu)
set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
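The new rotary_emb.cu is added to the build here, but its diff is not among the files loaded in this view, so its contents are unknown. Assuming it is what fills the cos_sin table consumed by the kernels above, a precompute kernel could look roughly like the sketch below; the kernel name, launch shape, position-id input, and frequency formula (standard RoPE, theta_i = pos / base^(2i/rope_dim)) are all assumptions, not the file's actual code:

    #include <cuda_runtime.h>

    // Hypothetical sketch only: one thread writes one (cos, sin) pair of one
    // token's row in a [token_num, head_dim] table, matching the layout that
    // attention_universal.h and kv_cache_utils_v2.cu read in this commit.
    // Channels at or beyond rope_dim are left untouched.
    __global__ void computeCosSin(
        float* cos_sin, const int* pos_ids, int token_num, int head_dim, int rope_dim, float base)
    {
        const int ti = blockIdx.x;   // token index
        const int pi = threadIdx.x;  // rotated channel-pair index
        if (ti >= token_num || pi >= rope_dim / 2) {
            return;
        }
        const float inv_freq = powf(base, -2.f * pi / rope_dim);
        const float angle    = pos_ids[ti] * inv_freq;  // assumes absolute position ids
        cos_sin[ti * head_dim + 2 * pi]     = cosf(angle);
        cos_sin[ti * head_dim + 2 * pi + 1] = sinf(angle);
    }

    // A plausible launch: computeCosSin<<<token_num, head_dim / 2>>>(cos_sin, pos_ids, ...);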
(Diffs for the remaining 5 changed files in this commit were not loaded in this view.)
