Skip to content

Commit

Permalink
fix overflow
Browse files (browse the repository at this point in the history)
  • Loading branch information
irexyc committed Dec 28, 2023
1 parent 5c9b832 commit e551dad
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/turbomind/kernels/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ __global__ void generic_activation(T* out,
using Float_T = typename packed_as<float, packed_elems>::type;
using Packed_Int8_t = typename packed_as<int8_t, packed_elems>::type;

for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
for (int64_t id = blockIdx.x * blockDim.x + threadIdx.x; id < 1LL * m * n; id += blockDim.x * gridDim.x) {
T val;
if (int8_mode == 2) {
// val = cuda_cast<T>(cuda_cast<Float_T>(reinterpret_cast<Packed_Int8_t*>(out)[id]) * activation_in[0]);
Expand Down Expand Up @@ -275,7 +275,7 @@ void invokeGenericActivation(T* out,
}
else {
block.x = n_threads;
grid.x = ceil(m * n / double(n_threads));
grid.x = ceil(1LL * m * n / double(n_threads));
}
TM_LOG_DEBUG("%d %d", grid.x, block.x);
sync_check_cuda_error();
Expand Down

0 comments on commit e551dad

Please sign in to comment.