rename function
irexyc committed Dec 4, 2024
1 parent 6019a23 commit d9d5a38
Showing 1 changed file with 47 additions and 47 deletions.
src/turbomind/models/llama/rotary_emb.cu: 94 changes (47 additions & 47 deletions)
@@ -27,7 +27,7 @@ __inline__ __device__ float compute_default_parameters(float base, float dim, in
}

template<typename T>
-__global__ void computeCosSinDefault(
+__global__ void rotaryEmbeddingDefault(
const float* rope_base, int* q_len, int* k_len, int token_num, int batch_size, int dim, float factor, T* cos_sin)
{
int qi = blockIdx.x;
@@ -46,16 +46,16 @@ __global__ void computeCosSinDefault(
}

template<typename T>
-__global__ void computeCosSinLlama3(const float* rope_base,
-int* q_len,
-int* k_len,
-int token_num,
-int batch_size,
-int dim,
-float inv_scaling_factor,
-float alpha,
-float beta,
-T* cos_sin)
+__global__ void rotaryEmbeddingLlama3(const float* rope_base,
+int* q_len,
+int* k_len,
+int token_num,
+int batch_size,
+int dim,
+float inv_scaling_factor,
+float alpha,
+float beta,
+T* cos_sin)
{
int qi = blockIdx.x;
int di = threadIdx.x;
@@ -75,17 +75,17 @@ __global__ void computeCosSinLlama3(const float* rope_base,
}

template<typename T>
-__global__ void computeCosSinYarn(const float* rope_base,
-int* q_len,
-int* k_len,
-int token_num,
-int batch_size,
-int dim,
-float ramp_inv_factor_div_2,
-float ramp_inv_factor_mul_min,
-float inv_scaling_factor,
-float attention_scaling,
-T* cos_sin)
+__global__ void rotaryEmbeddingYarn(const float* rope_base,
+int* q_len,
+int* k_len,
+int token_num,
+int batch_size,
+int dim,
+float ramp_inv_factor_div_2,
+float ramp_inv_factor_mul_min,
+float inv_scaling_factor,
+float attention_scaling,
+T* cos_sin)
{
int qi = blockIdx.x;
int di = threadIdx.x;
@@ -193,40 +193,40 @@ void RotaryEmbeddingV2<T>::forward(const RotaryEmbeddingV2Param& params)
case RopeType::kDefault:
case RopeType::kLinear:
case RopeType::kDynamic:
-computeCosSinDefault<<<grid, block, 0, stream_>>>(params.rope_theta,
-params.q_len,
-params.k_ken,
-params.token_num,
-params.batch_size,
-dim_,
-inv_factor_,
-cos_sin_);
+rotaryEmbeddingDefault<<<grid, block, 0, stream_>>>(params.rope_theta,
+params.q_len,
+params.k_ken,
+params.token_num,
+params.batch_size,
+dim_,
+inv_factor_,
+cos_sin_);
break;
case RopeType::kLlama3:
-computeCosSinLlama3<<<grid, block, 0, stream_>>>(params.rope_theta,
+rotaryEmbeddingLlama3<<<grid, block, 0, stream_>>>(params.rope_theta,
params.q_len,
params.k_ken,
params.token_num,
params.batch_size,
dim_,
llama3_.inv_scaling_factor,
llama3_.alpha,
llama3_.beta,
cos_sin_);
break;
case RopeType::kYarn:
-computeCosSinYarn<<<grid, block, 0, stream_>>>(params.rope_theta,
+rotaryEmbeddingYarn<<<grid, block, 0, stream_>>>(params.rope_theta,
params.q_len,
params.k_ken,
params.token_num,
params.batch_size,
dim_,
yarn_.ramp_inv_factor_div_2,
yarn_.ramp_inv_factor_mul_min,
yarn_.inv_scaling_factor,
yarn_.attention_factor,
cos_sin_);
break;
default:
FT_CHECK(0);
}
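For reference, the kernels touched by this rename precompute a cos/sin table for rotary position embedding (RoPE). The standalone CUDA sketch below shows only the default (unscaled) variant of that computation and is an illustration, not the lmdeploy code: the kernel name ropeCosSinDefaultSketch, the float-only output, the interleaved [cos, sin] layout, and the fixed max_pos grid are assumptions made here, whereas the real rotaryEmbeddingDefault additionally takes per-sequence q_len/k_len, batch metadata, and a scaling factor.

// Hypothetical standalone sketch of default RoPE cos/sin precomputation.
// Assumed layout: cos_sin[p * dim + 2*i] = cos(angle), cos_sin[p * dim + 2*i + 1] = sin(angle).
#include <cuda_runtime.h>
#include <cstdio>

__global__ void ropeCosSinDefaultSketch(float base, int dim, int max_pos, float* cos_sin)
{
    int p = blockIdx.x;   // token position
    int i = threadIdx.x;  // frequency index, i < dim / 2
    if (p >= max_pos || i >= dim / 2) {
        return;
    }
    // Default RoPE frequency: inv_freq_i = base^(-2i / dim)
    float inv_freq = powf(base, -2.0f * i / (float)dim);
    float angle    = p * inv_freq;
    cos_sin[p * dim + 2 * i]     = cosf(angle);
    cos_sin[p * dim + 2 * i + 1] = sinf(angle);
}

int main()
{
    const int   dim = 128, max_pos = 8;
    const float base      = 10000.f;
    float*      d_cos_sin = nullptr;
    cudaMalloc(&d_cos_sin, sizeof(float) * dim * max_pos);
    ropeCosSinDefaultSketch<<<max_pos, dim / 2>>>(base, dim, max_pos, d_cos_sin);
    cudaDeviceSynchronize();
    float h[4];
    cudaMemcpy(h, d_cos_sin, sizeof(h), cudaMemcpyDeviceToHost);
    printf("pos 0: cos=%f sin=%f (expect 1, 0)\n", h[0], h[1]);
    cudaFree(d_cos_sin);
    return 0;
}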
