
Commit

fix loading from workspace
irexyc committed Sep 2, 2024
1 parent 78c06e7 commit 918acda
Showing 2 changed files with 76 additions and 17 deletions.
92 changes: 75 additions & 17 deletions src/turbomind/models/llama/LlamaWeight.cc
@@ -21,6 +21,7 @@
#include "src/turbomind/models/llama/LlamaWeight.h"
#include "src/turbomind/utils/memory_utils.h"
#include <cuda_runtime.h>
#include <filesystem>

namespace turbomind {

@@ -51,6 +52,17 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
vocab_size_padded_ = (vocab_size_padded_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_;
TM_LOG_WARNING("pad vocab size from %d to %d", vocab_size_, vocab_size_padded_);
}

// try to split embedding table along hidden dim
if (hidden_units_ % tensor_para_size_ == 0) {
embedding_table_size_ = vocab_size_padded_ * hidden_units_ / tensor_para_size_;
}
else {
embedding_table_size_ = vocab_size_padded_ * hidden_units_;
TM_LOG_WARNING(
"Can not split embedding table along hidden_units %d with tp %d", hidden_units_, tensor_para_size_);
}

decoder_layer_weights.reserve(num_layer_);
for (unsigned l = 0; l < num_layer_; ++l) {
decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(l,
@@ -78,6 +90,7 @@ LlamaWeight<T>::~LlamaWeight()
cudaFree((void*)post_decoder_embedding_kernel);

pre_decoder_embedding_table = nullptr;
output_norm_weight = nullptr;
post_decoder_embedding_kernel = nullptr;

for (auto& p : decoder_layer_weights) {
@@ -89,14 +102,55 @@ template<typename T>
void LlamaWeight<T>::mallocWeights()
{
FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);
size_t embedding_table_size = (hidden_units_ % tensor_para_size_ == 0) ?
vocab_size_padded_ * hidden_units_ / tensor_para_size_ :
vocab_size_padded_ * hidden_units_;
deviceMalloc((T**)&pre_decoder_embedding_table, embedding_table_size);
deviceMalloc((T**)&pre_decoder_embedding_table, embedding_table_size_);
deviceMalloc((T**)&output_norm_weight, hidden_units_);
deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_);
}

template<typename T>
void loadLinearWeights(T* weights,
std::string prefix,
int rank,
size_t tensor_para_size,
size_t dim0,
size_t dim1,
FtCudaDataType type,
size_t split_dim)
{
FT_CHECK(split_dim == 0 || split_dim == 1);
auto max_prefix = prefix + "." + std::to_string(tensor_para_size - 1);
bool enable_slice = true;
if (tensor_para_size <= 1 || std::filesystem::exists(max_prefix + ".weight")) {
enable_slice = false;
}

std::vector<std::reference_wrapper<size_t>> dims = {dim0, dim1};
if (dims[split_dim] % tensor_para_size != 0) {
enable_slice = false;
}
else if (!enable_slice && dims[split_dim] % tensor_para_size == 0) {
dims[split_dim] /= tensor_para_size;
}

prefix += "." + (enable_slice ? std::to_string(0) : std::to_string(rank));
std::vector<ConcateSlice> weight_slices{};
if (enable_slice) {
if (split_dim == 0) {
size_t stride = dim0 / tensor_para_size;
ConcateSlice slice0{{{stride * rank, stride * (rank + 1)}}};
ConcateSlice slice1{{{0, dim1}}};
weight_slices = {slice0, slice1};
}
else if (split_dim == 1) {
size_t stride = dim1 / tensor_para_size;
ConcateSlice slice0{{{0, dim0}}};
ConcateSlice slice1{{{stride * rank, stride * (rank + 1)}}};
weight_slices = {slice0, slice1};
}
}
loadWeightFromBin(weights, {dim0, dim1}, prefix + ".weight", type, weight_slices);
}

template<typename T>
void LlamaWeight<T>::loadModel(std::string dir_path)
{
@@ -106,18 +160,25 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
}
dir_path += '/';

size_t embedding_table_size = (hidden_units_ % tensor_para_size_ == 0) ?
vocab_size_padded_ * hidden_units_ / tensor_para_size_ :
vocab_size_padded_ * hidden_units_;
loadWeightFromBin(
(T*)pre_decoder_embedding_table, {embedding_table_size}, dir_path + "tok_embeddings.weight", model_file_type);
loadLinearWeights((T*)pre_decoder_embedding_table,
dir_path + "tok_embeddings",
tensor_para_rank_,
tensor_para_size_,
vocab_size_padded_,
hidden_units_,
model_file_type,
1);

loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);

loadWeightFromBin((T*)post_decoder_embedding_kernel,
{hidden_units_ * vocab_size_padded_ / tensor_para_size_},
dir_path + "output.weight",
model_file_type);
loadLinearWeights((T*)post_decoder_embedding_kernel,
dir_path + "output",
tensor_para_rank_,
tensor_para_size_,
vocab_size_padded_,
hidden_units_,
model_file_type,
0);

for (unsigned layer = 0; layer < num_layer_; ++layer) {
decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
@@ -130,12 +191,9 @@ TensorMap LlamaWeight<T>::getParams()
TensorMap output;
FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);

size_t embedding_table_size = (hidden_units_ % tensor_para_size_ == 0) ?
vocab_size_padded_ * hidden_units_ / tensor_para_size_ :
vocab_size_padded_ * hidden_units_;
output.insert(
"tok_embeddings." + std::to_string(tensor_para_rank_) + ".weight",
Tensor{MEMORY_GPU, getTensorType<T>(), {embedding_table_size * sizeof(T)}, pre_decoder_embedding_table});
Tensor{MEMORY_GPU, getTensorType<T>(), {embedding_table_size_ * sizeof(T)}, pre_decoder_embedding_table});

output.insert("norm.weight",
Tensor{MEMORY_GPU, getTensorType<T>(), {hidden_units_ * sizeof(T)}, output_norm_weight});
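For context, the new loadLinearWeights helper above slices a single unsplit weight file across tensor-parallel ranks when no pre-split per-rank files are present. The following is a minimal standalone sketch of that slice arithmetic only; computeSliceRanges is a hypothetical name, and the real code builds ConcateSlice objects and hands them to loadWeightFromBin instead.

#include <array>
#include <cstddef>
#include <cstdio>
#include <utility>

// Hypothetical helper mirroring the slice arithmetic in loadLinearWeights:
// for a [dim0, dim1] tensor split along split_dim across `tp` ranks, return
// the half-open [begin, end) range each dimension should read for `rank`
// when the single unsplit file is sliced on the fly.
std::array<std::pair<std::size_t, std::size_t>, 2>
computeSliceRanges(std::size_t dim0, std::size_t dim1, std::size_t split_dim,
                   std::size_t tp, std::size_t rank)
{
    std::array<std::pair<std::size_t, std::size_t>, 2> ranges = {{{0, dim0}, {0, dim1}}};
    const std::size_t dims[2] = {dim0, dim1};
    const std::size_t stride  = dims[split_dim] / tp;  // caller checks divisibility
    ranges[split_dim] = {stride * rank, stride * (rank + 1)};
    return ranges;
}

int main()
{
    // e.g. the token-embedding table [vocab_size_padded, hidden_units] split
    // along hidden_units (split_dim == 1) across 2 ranks, viewed from rank 1
    const auto r = computeSliceRanges(32000, 4096, /*split_dim=*/1, /*tp=*/2, /*rank=*/1);
    std::printf("dim0: [%zu, %zu)  dim1: [%zu, %zu)\n",
                r[0].first, r[0].second, r[1].first, r[1].second);
    return 0;
}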
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaWeight.h
@@ -66,6 +66,7 @@ struct LlamaWeight {
size_t inter_size_;
size_t vocab_size_;
size_t vocab_size_padded_;
size_t embedding_table_size_;
size_t num_layer_;
WeightType weight_type_;
size_t tensor_para_size_;
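And a compact sketch of the load-path decision itself, assuming a workspace layout where already-split shards are named like "output.1.weight"; LoadPlan and planLinearLoad are hypothetical names used for illustration only.

#include <cstddef>
#include <filesystem>
#include <string>

// Hypothetical sketch of the dispatch in loadLinearWeights: read the
// already-split per-rank file (old workspace layout) when it exists,
// otherwise open the single unsplit ".0.weight" file and slice out this
// rank's part, provided the split dimension divides evenly by tp.
struct LoadPlan {
    std::string file;   // file this rank should open
    bool        slice;  // true -> take only this rank's slab from the tensor
};

LoadPlan planLinearLoad(const std::string& prefix,  // e.g. "workspace/output"
                        std::size_t split_len,      // length of the split dimension
                        std::size_t tp,
                        std::size_t rank)
{
    const std::string last_shard = prefix + "." + std::to_string(tp - 1) + ".weight";
    const bool pre_split = (tp <= 1) || std::filesystem::exists(last_shard);
    const bool divisible = (split_len % tp == 0);

    if (!pre_split && divisible) {
        // one full file on disk; every rank opens ".0.weight" and slices its part
        return {prefix + ".0.weight", /*slice=*/true};
    }
    // per-rank shards already exist (or slicing is impossible): open own shard
    return {prefix + "." + std::to_string(rank) + ".weight", /*slice=*/false};
}

With tp == 2 and per-rank shards present in the workspace, the plan falls back to those shards rather than slicing, which appears to be the loading path this commit repairs.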
