
Commit

remove unused
irexyc committed Nov 25, 2024
1 parent 05d011c commit 7b74b72
Showing 6 changed files with 33 additions and 329 deletions.
src/turbomind/kernels/attention/attention_params.h (19 changes: 3 additions & 16 deletions)

@@ -56,23 +56,10 @@ struct AttentionParams {
     int size_per_head;
     float inv_sqrt_dh;

-    float* cos_sin;
-    // rotary embedding
-    int rotary_embedding_dim;
-    float rotary_embedding_base;
-    float rope_scaling_factor;
-    float attention_scaling;
-    int max_position_embeddings;
-    float rope_ti_scale;  // used for linear RoPE scaling
-    // the following 3 parameters are used by llama3
-    float llama3_inv_scaling_factor;
-    float llama3_alpha;
-    float llama3_beta;
-    // the following are use by yarn
-    float yarn_ramp_inv_factor_div_2;
-    float yarn_ramp_inv_factor_mul_min;
-    float yarn_inv_scaling_factor;

+    float* cos_sin;
+    int rotary_embedding_dim;
+    int max_position_embeddings;
     // log(n) attention
     bool use_logn_attn;

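After this change the kernel-facing struct keeps only the precomputed table pointer and its dimensions; the removed fields (rotary_embedding_base, rope_scaling_factor, the llama3_* and yarn_* factors, attention_scaling) only matter when that table is built. The sketch below is a minimal host-side illustration of the idea, assuming plain unscaled RoPE and an interleaved [cos, sin] layout; the function name, signature, and layout are hypothetical and not taken from TurboMind.

```cpp
// Hypothetical host-side sketch (not TurboMind code): precompute a cos/sin table so the
// kernel no longer needs base/scaling parameters. Plain, unscaled RoPE is assumed, with an
// interleaved layout: table[pos * rotary_dim + 2*i] = cos, table[pos * rotary_dim + 2*i + 1] = sin.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> precompute_cos_sin(int max_position_embeddings, int rotary_dim, float base = 10000.f)
{
    std::vector<float> table(std::size_t(max_position_embeddings) * rotary_dim);
    for (int pos = 0; pos < max_position_embeddings; ++pos) {
        for (int i = 0; i < rotary_dim / 2; ++i) {
            // A scaled RoPE variant (linear, llama3, YaRN) would adjust these frequencies
            // (and possibly scale cos/sin) right here, which is why those parameters no
            // longer need to reach the kernel.
            const float inv_freq = std::pow(base, -2.f * i / rotary_dim);
            const float angle    = pos * inv_freq;

            const std::size_t idx = std::size_t(pos) * rotary_dim + 2 * i;
            table[idx]     = std::cos(angle);
            table[idx + 1] = std::sin(angle);
        }
    }
    return table;  // a device copy of this buffer is what `cos_sin` would point to
}
```

With the variant-specific math folded into the table, the kernel side needs exactly what remains in the struct: cos_sin, rotary_embedding_dim, and max_position_embeddings.
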
src/turbomind/kernels/attention/attention_universal.h (49 changes: 10 additions & 39 deletions)

@@ -221,7 +221,7 @@ struct AttentionUniversal {
             }
             if (params.cos_sin) {
                 float* cos_sin = params.cos_sin;
-                const int64_t index = qi * kHeadDim + di;
+                const int64_t index = qi * params.rotary_embedding_dim + di;
                 PRAGMA_UNROLL
                 for (int k = 0; k < kVecSize; k += 4) {
                     (float4&)vec_cs[s][c][k] = __ldg((const float4*)&cos_sin[index + k]);

@@ -236,51 +236,22 @@
         if (params.cos_sin) {
             PrecomputeFastRoPE rope{};
             PRAGMA_UNROLL
-            for (int c = 0; c < ITER_C; ++c) {
-                const int di = offset.x + c * Map::kDeltaC;
+            for (int s = 0; s < ITER_S; ++s) {
                 PRAGMA_UNROLL
-                for (int s = 0; s < ITER_S; ++s) {
-                    rope.apply(vec_Q[s][c], vec_cs[s][c]);
-                    if constexpr (kProcessKV) {
-                        if (s == 0) {
-                            rope.apply(vec_K[0][c], vec_cs[s][c]);
+                for (int c = 0; c < ITER_C; ++c) {
+                    const int di = offset.x + c * Map::kDeltaC;
+                    if (di < params.rotary_embedding_dim) {
+                        rope.apply(vec_Q[s][c], vec_cs[s][c]);
+                        if constexpr (kProcessKV) {
+                            if (s == 0) {
+                                rope.apply(vec_K[0][c], vec_cs[s][c]);
+                            }
                         }
                     }
                 }
             }
         }

-#if 0
-        const float rope_base = params.rope_theta ? params.rope_theta[batch_idx] : params.rotary_embedding_base;
-        PRAGMA_UNROLL
-        for (int c = 0; c < ITER_C; ++c) {
-            const int di = offset.x + c * Map::kDeltaC;
-            FastRoPE rope(di,
-                          params.rotary_embedding_dim,
-                          rope_base,
-                          params.rope_ti_scale,
-                          params.rope_scaling_factor,
-                          params.llama3_inv_scaling_factor,
-                          params.llama3_alpha,
-                          params.llama3_beta,
-                          params.yarn_ramp_inv_factor_div_2,
-                          params.yarn_ramp_inv_factor_mul_min,
-                          params.yarn_inv_scaling_factor,
-                          params.attention_scaling,
-                          std::integral_constant<int, kVecSize>{});
-            PRAGMA_UNROLL
-            for (int s = 0; s < ITER_S; ++s) {
-                const int ti = (offset.y + s * Map::kDeltaS) / CTA_H + query_idx + history_len;
-                rope.apply(vec_Q[s][c], ti);
-                if constexpr (kProcessKV) {
-                    if (s == 0) {
-                        rope.apply(vec_K[0][c], ti);
-                    }
-                }
-            }
-        }
-#endif
-
         if (params.use_logn_attn) {
             PRAGMA_UNROLL
             for (int s = 0; s < ITER_S; ++s) {
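
Two things change in the kernel itself: the cos/sin table is now indexed with a stride of rotary_embedding_dim per query position (qi * params.rotary_embedding_dim + di) instead of kHeadDim, and the new di < params.rotary_embedding_dim guard leaves channels beyond the rotary dimension unrotated (partial rotary embedding). The CPU reference below mirrors that guard and indexing for a single head; the adjacent-pair rotation and the interleaved cos/sin layout are assumptions for illustration and may not match PrecomputeFastRoPE exactly.

```cpp
// Hypothetical CPU reference (not TurboMind code) mirroring the guarded loop above:
// only the first `rotary_dim` channels of a head are rotated; the rest pass through.
// Assumed table layout: cos_sin[qi * rotary_dim + 2*i] = cos, cos_sin[qi * rotary_dim + 2*i + 1] = sin.
#include <cstdint>

void apply_rope_reference(float*       q,        // one head of the query: head_dim floats
                          const float* cos_sin,  // precomputed table, rotary_dim floats per token
                          std::int64_t qi,       // token position used to index the table
                          int          head_dim,
                          int          rotary_dim)
{
    for (int di = 0; di + 1 < head_dim; di += 2) {
        if (di < rotary_dim) {  // same guard as `di < params.rotary_embedding_dim`
            const std::int64_t idx = qi * rotary_dim + di;  // same stride as the new `index`
            const float c = cos_sin[idx];
            const float s = cos_sin[idx + 1];
            const float x = q[di];
            const float y = q[di + 1];
            q[di]     = x * c - y * s;  // 2-D rotation of an adjacent channel pair
            q[di + 1] = x * s + y * c;
        }
    }
}
```
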
4 changed files not shown.

0 comments on commit 7b74b72
