Skip to content

Commit

Permalink
support turbomind backend min_p
Browse files Browse the repository at this point in the history
  • Loading branch information
irexyc committed Sep 4, 2024
1 parent 599ce36 commit aeae1cb
Show file tree
Hide file tree
Showing 32 changed files with 1,933 additions and 4,842 deletions.
3 changes: 0 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,9 @@ endif()
########################################

add_library(transformer-shared SHARED
$<TARGET_OBJECTS:BaseSamplingLayer>
$<TARGET_OBJECTS:DynamicDecodeLayer>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TransformerTritonBackend>
$<TARGET_OBJECTS:activation_kernels>
$<TARGET_OBJECTS:ban_bad_words>
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/turbomind/turbomind.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
dtype=np.uint32),
runtime_top_k=_broadcast_np(gen_config.top_k, np.uint32),
runtime_top_p=_broadcast_np(gen_config.top_p, np.float32),
runtime_min_p=_broadcast_np(gen_config.min_p, np.float32),
temperature=_broadcast_np(gen_config.temperature, np.float32),
repetition_penalty=_broadcast_np(gen_config.repetition_penalty,
np.float32),
Expand Down
4 changes: 4 additions & 0 deletions src/turbomind/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

add_library(sampling_kernels STATIC sampling_kernels.cu)
set_property(TARGET sampling_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET sampling_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

if (BUILD_TEST)
add_subdirectory(flash_attention)
endif ()
Expand Down
108 changes: 108 additions & 0 deletions src/turbomind/kernels/sampling_kernels.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11000)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "src/turbomind/kernels/sampling_kernels.h"
#include "src/turbomind/kernels/sampling_topp_kernels.h"
#include "src/turbomind/utils/constant.h"

namespace turbomind {

template<typename T, int BLOCK_SIZE>
__global__ void sampling(T* logits,
const int stride,
int* indices,
int* kept,
curandState_t* curandstate,
bool* finished,
const int* end_ids,
int* output_ids,
int* sequence_length,
float* sampled_logprobs,
uint32_t* sampled_indexes,
uint32_t* sampled_nums)
{
int tid = threadIdx.x;
int batch_id = blockIdx.x;
int n = kept[batch_id];

logits += stride * batch_id;
indices += stride * batch_id;

__shared__ float rand_num_s;
__shared__ int selected;
if (tid == 0) {
rand_num_s = curand_uniform(curandstate + batch_id);
}
__syncthreads();

typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;

float local_rand = rand_num_s;
float prefix_sum = 0.f;
BlockPrefixCallbackOp prefix_op{0};
int end = (n + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
for (int i = tid; i < end; i += BLOCK_SIZE) {
float thread_logit = (i < n) ? static_cast<float>(logits[i]) : 0.f;
BlockScan(temp_storage).InclusiveSum(thread_logit, prefix_sum, prefix_op);
auto count = __syncthreads_count(prefix_sum > local_rand);
if (count != 0 || (i + BLOCK_SIZE) >= end) {
if (tid == min(BLOCK_SIZE - count, BLOCK_SIZE - 1)) {
selected = min(i, n - 1);
output_ids[batch_id] = indices[selected];

if (sequence_length != nullptr && finished != nullptr) {
sequence_length[batch_id] =
finished[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1;
finished[batch_id] = output_ids[batch_id] == end_ids[batch_id] ? 1 : 0;
}
}
break;
}
}

if (sampled_logprobs != nullptr && sampled_indexes != nullptr && sampled_nums != nullptr) {
__syncthreads();
sampled_logprobs += batch_id * kMaxLogProb;
sampled_indexes += batch_id * kMaxLogProb;
int end = min(n, kMaxLogProb);
for (int i = tid; i < end; i += BLOCK_SIZE) {
sampled_logprobs[i] = logf(logits[i]);
sampled_indexes[i] = indices[i];
}
if (n > kMaxLogProb && selected >= kMaxLogProb) {
if ((kMaxLogProb - 1 + BLOCK_SIZE - tid) % BLOCK_SIZE == 0) {
sampled_logprobs[kMaxLogProb - 1] = logf(logits[selected]);
sampled_indexes[kMaxLogProb - 1] = indices[selected];
}
}
sampled_nums[batch_id] = min(n, kMaxLogProb);
}
}

template<typename T>
void invokeSampling(SamplingParams& params, cudaStream_t stream)
{
const int grid = params.batch_size;
const int block = 256;
sampling<T, block><<<grid, block, 0, stream>>>((T*)params.logits,
params.stride,
params.indices,
params.kept,
params.curandstate,
params.finished,
params.end_ids,
params.output_ids,
params.sequence_length,
params.sampled_logprobs,
params.sampled_indexes,
params.sampled_nums);
}

template void invokeSampling<float>(SamplingParams& params, cudaStream_t stream);

} // namespace turbomind
44 changes: 44 additions & 0 deletions src/turbomind/kernels/sampling_kernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <stdint.h>

namespace turbomind {

struct SamplingParams {
void* logits;
int stride;
int* indices;
int* kept;
curandState_t* curandstate;
size_t batch_size;
bool* finished;
int* end_ids;
int* output_ids;
int* sequence_length;
float* sampled_logprobs{};
uint32_t* sampled_indexes{};
uint32_t* sampled_nums{};
};

template<typename T>
void invokeSampling(SamplingParams& params, cudaStream_t stream);

} // namespace turbomind
Loading

0 comments on commit aeae1cb

Please sign in to comment.