compute sin/cos in advance
irexyc committed Dec 10, 2024
1 parent d1eb613 commit 795c56f
Showing 13 changed files with 264 additions and 146 deletions.
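For orientation: the hunks below thread two new buffers through the attention path — a precomputed cos/sin table (`cos_sin`) and a query-to-position map (`q2p`) — so the kernels gather rotary factors computed in advance instead of recomputing sin/cos per element. As a rough, host-side illustration only (names, layout, and the cos/sin interleaving are assumptions, not taken from this commit), such a table might be built like this:

```cpp
#include <cmath>
#include <vector>

// Hypothetical sketch: one row per position, rope_dim floats per row,
// storing the cos/sin value for each rotary channel. RotaryEmbeddingV2
// may use a different layout; the kernels below only assume
// index = position * rope_dim + channel.
std::vector<float> build_cos_sin(int max_position, int rope_dim, float base)
{
    std::vector<float> table(static_cast<size_t>(max_position) * rope_dim);
    for (int p = 0; p < max_position; ++p) {
        for (int i = 0; i < rope_dim / 2; ++i) {
            const float inv_freq = std::pow(base, -2.0f * i / rope_dim);  // RoPE frequency
            const float angle    = p * inv_freq;
            table[p * rope_dim + 2 * i]     = std::cos(angle);  // assumed cos slot
            table[p * rope_dim + 2 * i + 1] = std::sin(angle);  // assumed sin slot
        }
    }
    return table;
}
```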
7 changes: 4 additions & 3 deletions src/turbomind/kernels/attention/attention_params.h
@@ -57,9 +57,10 @@ struct AttentionParams {
float inv_sqrt_dh;

// rotary embedding
- T* cos_sin;
- int rotary_embedding_dim;
- int max_position_embeddings;
+ T* cos_sin;
+ int* q2p;
+ int rotary_embedding_dim;
+ int max_position_embeddings;
// log(n) attention
bool use_logn_attn;

2 changes: 1 addition & 1 deletion src/turbomind/kernels/attention/attention_universal.h
@@ -221,7 +221,7 @@ struct AttentionUniversal {
}
if (params.cos_sin) {
T* cos_sin = params.cos_sin;
- const int64_t index = qi * params.rotary_embedding_dim + di;
+ const int64_t index = params.q2p[qi] * params.rotary_embedding_dim + di;
Ldg(vec_cs[s][c], &cos_sin[index]);
}
}
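The only functional change here is the table row: previously the kernel used the flattened query index `qi` directly; now it goes through `params.q2p`, which apparently maps each query token to the position whose precomputed cos/sin row should be used (the KV-cache kernel below does the same with `q2p[qi_beg + qi]`). A hedged, per-pair scalar sketch of the lookup and the rotation it feeds — names and the cos/sin slot layout are assumptions, and the real kernel works on vectors:

```cpp
// Assumed layout: cos_sin is [num_positions, rope_dim] with adjacent
// cos/sin slots, q2p is [num_query_tokens]; qi is the flattened query
// index and di the channel offset of the pair being rotated.
inline void apply_rope_pair(const float* cos_sin, const int* q2p,
                            int qi, int di, int rope_dim,
                            float& x0, float& x1)
{
    const long long row = static_cast<long long>(q2p[qi]) * rope_dim;
    const float c  = cos_sin[row + di];      // assumed cos slot
    const float s  = cos_sin[row + di + 1];  // assumed sin slot
    const float y0 = x0 * c - x1 * s;        // standard RoPE rotation
    const float y1 = x0 * s + x1 * c;
    x0 = y0;
    x1 = y1;
}
```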
6 changes: 5 additions & 1 deletion src/turbomind/kernels/attention/kv_cache_utils_v2.cu
@@ -21,6 +21,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
const int* cu_k_len,
const int* cu_block_num,
const T* cos_sin,
+ const int* q2p,
int rope_dim,
int64_t stride_b,
int64_t stride_c,
@@ -123,7 +124,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
PRAGMA_UNROLL
for (int c = 0; c < ITER_C; ++c) {
const int di = offset.x + c * Map::kDeltaC;
- const int64_t index = (qi_beg + qi) * rope_dim + di;
+ const int64_t index = q2p[qi_beg + qi] * rope_dim + di;
if (qi < q_len) {
Ldg(vec_cs[s][c], &cos_sin[index]);
}
@@ -205,6 +206,7 @@ void invokeProcessKV_v2(char** blocks,
const int* cu_k_len,
const int* cu_block_num,
const T* cos_sin,
+ const int* q2p,
int rope_dim,
int64_t stride_b,
int64_t stride_c,
@@ -242,6 +244,7 @@ void invokeProcessKV_v2(char** blocks,
cu_k_len,
cu_block_num,
cos_sin,
+ q2p,
rope_dim,
stride_b,
stride_c,
@@ -285,6 +288,7 @@ void invokeProcessKV_v2(char** blocks,
const int* cu_k_len, \
const int* cu_block_num, \
const type* cos_sin, \
+ const int* q2p, \
int rope_dim, \
int64_t stride_b, \
int64_t stride_c, \
2 changes: 2 additions & 0 deletions src/turbomind/kernels/attention/kv_cache_utils_v2.h
@@ -17,6 +17,7 @@ void invokeProcessKV_v2(char** blocks,
const int* cu_k_len,
const int* cu_block_num,
const T* cos_sin,
+ const int* q2p,
int rope_dim,
int64_t stride_b,
int64_t stride_c,
@@ -43,6 +44,7 @@ void invokeProcessKV_v2_(const AttentionParams<T>& params)
params.cu_k_len,
params.block_iter_params.cu_block_nums,
params.cos_sin,
+ params.q2p,
params.rotary_embedding_dim,
0, // stride b
params.stride / params.size_per_head, // stride c
8 changes: 6 additions & 2 deletions src/turbomind/kernels/attention/test_attention.cu
@@ -148,6 +148,7 @@ void TestBlocks(const thrust::universal_vector<T>& k_cache, // [B, H, S,
cu_seq_lens.data().get(),
cu_block_cnts.data().get(),
(T*)nullptr,
+ (int*)nullptr,
rope_dim,
2 * head_num * seq_len,
0,
@@ -386,16 +387,19 @@ int test_attention()
attn_param.rope.base = kRoPEBase;
attn_param.rope.dim = kRoPEDim;
attn_param.rope.factor = 1.0f;
- auto rotary_emb = std::make_unique<RotaryEmbeddingV2<T>>(attn_param, nullptr, allocator.get());
+ auto rotary_emb = std::make_unique<RotaryEmbeddingV2<T>>(attn_param, kInputLen, nullptr, allocator.get());

RotaryEmbeddingV2Param rotary_param;
rotary_param.rope_theta = rope_base.data().get();
rotary_param.q_len = cu_seqlens.data().get();
- rotary_param.k_ken = cu_kv_lens.data().get();
+ rotary_param.k_len = cu_kv_lens.data().get();
+ rotary_param.h_q_len = cu_seqlens.data().get();
+ rotary_param.h_k_len = cu_kv_lens.data().get();
rotary_param.batch_size = kBatchSize;
rotary_param.token_num = kTokenNum;
rotary_emb->forward(rotary_param);
params.cos_sin = rotary_emb->cos_sin_;
+ params.q2p = rotary_emb->q2p_;

// getchar();

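The test now passes `kInputLen` to the `RotaryEmbeddingV2` constructor, fixes the `k_ken` typo to `k_len`, supplies host-side copies of the length arrays (`h_q_len`, `h_k_len`), and wires the `cos_sin_` / `q2p_` buffers produced by `rotary_emb->forward` into the attention params. The commit does not show `RotaryEmbeddingV2::forward` itself; a plausible host-side construction of the q2p mapping — assuming each query token's position is its sequence's history length plus its offset — might look like:

```cpp
#include <vector>

// Hypothetical helper: cu_q_len / cu_k_len are cumulative query / key
// lengths per sequence (size batch_size + 1), mirroring the cu_seqlens /
// cu_kv_lens buffers the test passes in.
std::vector<int> build_q2p(const std::vector<int>& cu_q_len,
                           const std::vector<int>& cu_k_len,
                           int batch_size)
{
    std::vector<int> q2p(cu_q_len[batch_size]);
    for (int b = 0; b < batch_size; ++b) {
        const int q_len       = cu_q_len[b + 1] - cu_q_len[b];
        const int k_len       = cu_k_len[b + 1] - cu_k_len[b];
        const int history_len = k_len - q_len;  // tokens already in the KV cache
        for (int i = 0; i < q_len; ++i) {
            q2p[cu_q_len[b] + i] = history_len + i;  // absolute position of this query token
        }
    }
    return q2p;
}
```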
6 changes: 3 additions & 3 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -65,7 +65,7 @@ LlamaV2<T>::LlamaV2(const ModelParam& model,
const LoraParam& lora,
const NcclParam& tp,
const Context<T>& ctx,
- int max_batch_size,
+ const EngineParam& engine,
std::shared_ptr<LlamaWeight<T>> weights):
param_(model),
attn_param_(attn),
@@ -93,7 +93,7 @@ LlamaV2<T>::LlamaV2(const ModelParam& model,
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);

- unified_decoder_ = std::make_unique<UnifiedDecoder<T>>(model, attn, moe, lora, tp, ctx);
+ unified_decoder_ = std::make_unique<UnifiedDecoder<T>>(model, attn, moe, lora, tp, engine, ctx);

dynamic_decode_layer_ = std::make_unique<DynamicDecodeLayer<float>>(vocab_size_,
vocab_size_padded_,
@@ -104,7 +104,7 @@ LlamaV2<T>::LlamaV2(const ModelParam& model,
is_free_buffer_after_forward_,
(cudaDeviceProp*)&ctx.cuda_device_prop);

- unified_decoder_->allocateBuffer(max_batch_size);
+ unified_decoder_->allocateBuffer(engine.max_batch_size);
}

template<typename T>
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/LlamaV2.h
@@ -49,7 +49,7 @@ class LlamaV2 {
const LoraParam& lora,
const NcclParam& tp,
const Context<T>& ctx,
- int max_batch_size,
+ const EngineParam& engine,
std::shared_ptr<LlamaWeight<T>> weights);

size_t vocab_size() const noexcept