diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index ac5e03bd4f..566150af9e 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -21,6 +21,7 @@
 #include "src/turbomind/models/llama/LlamaWeight.h"
 #include "src/turbomind/utils/memory_utils.h"
 #include <cuda_runtime.h>
+#include <filesystem>
 
 namespace turbomind {
 
@@ -51,6 +52,17 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
         vocab_size_padded_ = (vocab_size_padded_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_;
         TM_LOG_WARNING("pad vocab size from %d to %d", vocab_size_, vocab_size_padded_);
     }
+
+    // try to split embedding table along hidden dim
+    if (hidden_units_ % tensor_para_size_ == 0) {
+        embedding_table_size_ = vocab_size_padded_ * hidden_units_ / tensor_para_size_;
+    }
+    else {
+        embedding_table_size_ = vocab_size_padded_ * hidden_units_;
+        TM_LOG_WARNING(
+            "Can not split embedding table along hidden_units %d with tp %d", hidden_units_, tensor_para_size_);
+    }
+
     decoder_layer_weights.reserve(num_layer_);
     for (unsigned l = 0; l < num_layer_; ++l) {
         decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(l,
@@ -78,6 +90,7 @@ LlamaWeight<T>::~LlamaWeight()
     cudaFree((void*)post_decoder_embedding_kernel);
 
     pre_decoder_embedding_table   = nullptr;
+    output_norm_weight            = nullptr;
     post_decoder_embedding_kernel = nullptr;
 
     for (auto& p : decoder_layer_weights) {
@@ -89,14 +102,55 @@ template<typename T>
 void LlamaWeight<T>::mallocWeights()
 {
     FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);
-    size_t embedding_table_size = (hidden_units_ % tensor_para_size_ == 0) ?
-                                      vocab_size_padded_ * hidden_units_ / tensor_para_size_ :
-                                      vocab_size_padded_ * hidden_units_;
-    deviceMalloc((T**)&pre_decoder_embedding_table, embedding_table_size);
+    deviceMalloc((T**)&pre_decoder_embedding_table, embedding_table_size_);
     deviceMalloc((T**)&output_norm_weight, hidden_units_);
     deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_);
 }
 
+template<typename T>
+void loadLinearWeights(T*             weights,
+                       std::string    prefix,
+                       int            rank,
+                       size_t         tensor_para_size,
+                       size_t         dim0,
+                       size_t         dim1,
+                       FtCudaDataType type,
+                       size_t         split_dim)
+{
+    FT_CHECK(split_dim == 0 || split_dim == 1);
+    auto max_prefix   = prefix + "." + std::to_string(tensor_para_size - 1);
+    bool enable_slice = true;
+    if (tensor_para_size <= 1 || std::filesystem::exists(max_prefix + ".weight")) {
+        enable_slice = false;
+    }
+
+    std::vector<size_t> dims = {dim0, dim1};
+    if (dims[split_dim] % tensor_para_size != 0) {
+        enable_slice = false;
+    }
+    else if (!enable_slice && dims[split_dim] % tensor_para_size == 0) {
+        dims[split_dim] /= tensor_para_size;
+    }
+
+    prefix += "." + (enable_slice ? std::to_string(0) : std::to_string(rank));
+    std::vector<ConcateSlice> weight_slices{};
+    if (enable_slice) {
+        if (split_dim == 0) {
+            size_t       stride = dim0 / tensor_para_size;
+            ConcateSlice slice0{{{stride * rank, stride * (rank + 1)}}};
+            ConcateSlice slice1{{{0, dim1}}};
+            weight_slices = {slice0, slice1};
+        }
+        else if (split_dim == 1) {
+            size_t       stride = dim1 / tensor_para_size;
+            ConcateSlice slice0{{{0, dim0}}};
+            ConcateSlice slice1{{{stride * rank, stride * (rank + 1)}}};
+            weight_slices = {slice0, slice1};
+        }
+    }
+    loadWeightFromBin(weights, {dims[0], dims[1]}, prefix + ".weight", type, weight_slices);
+}
+
 template<typename T>
 void LlamaWeight<T>::loadModel(std::string dir_path)
 {
@@ -106,18 +160,25 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
     }
     dir_path += '/';
 
-    size_t embedding_table_size = (hidden_units_ % tensor_para_size_ == 0) ?
-                                      vocab_size_padded_ * hidden_units_ / tensor_para_size_ :
-                                      vocab_size_padded_ * hidden_units_;
-    loadWeightFromBin(
-        (T*)pre_decoder_embedding_table, {embedding_table_size}, dir_path + "tok_embeddings.weight", model_file_type);
+    loadLinearWeights((T*)pre_decoder_embedding_table,
+                      dir_path + "tok_embeddings",
+                      tensor_para_rank_,
+                      tensor_para_size_,
+                      vocab_size_padded_,
+                      hidden_units_,
+                      model_file_type,
+                      1);
 
     loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);
 
-    loadWeightFromBin((T*)post_decoder_embedding_kernel,
-                      {hidden_units_ * vocab_size_padded_ / tensor_para_size_},
-                      dir_path + "output.weight",
-                      model_file_type);
+    loadLinearWeights((T*)post_decoder_embedding_kernel,
+                      dir_path + "output",
+                      tensor_para_rank_,
+                      tensor_para_size_,
+                      vocab_size_padded_,
+                      hidden_units_,
+                      model_file_type,
+                      0);
 
     for (unsigned layer = 0; layer < num_layer_; ++layer) {
         decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
@@ -130,12 +191,9 @@ TensorMap LlamaWeight<T>::getParams()
     TensorMap output;
 
     FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);
-    size_t embedding_table_size = (hidden_units_ % tensor_para_size_ == 0) ?
-                                      vocab_size_padded_ * hidden_units_ / tensor_para_size_ :
-                                      vocab_size_padded_ * hidden_units_;
     output.insert(
         "tok_embeddings." + std::to_string(tensor_para_rank_) + ".weight",
-        Tensor{MEMORY_GPU, getTensorType<T>(), {embedding_table_size * sizeof(T)}, pre_decoder_embedding_table});
+        Tensor{MEMORY_GPU, getTensorType<T>(), {embedding_table_size_ * sizeof(T)}, pre_decoder_embedding_table});
 
     output.insert("norm.weight",
                   Tensor{MEMORY_GPU, getTensorType<T>(), {hidden_units_ * sizeof(T)}, output_norm_weight});
diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h
index 8c94925ce7..e3480d406b 100644
--- a/src/turbomind/models/llama/LlamaWeight.h
+++ b/src/turbomind/models/llama/LlamaWeight.h
@@ -66,6 +66,7 @@ struct LlamaWeight {
     size_t     inter_size_;
     size_t     vocab_size_;
     size_t     vocab_size_padded_;
+    size_t     embedding_table_size_;
     size_t     num_layer_;
     WeightType weight_type_;
     size_t     tensor_para_size_;
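
For reference, a minimal standalone sketch of the per-rank slice arithmetic that the new `loadLinearWeights` helper encodes in its `ConcateSlice` ranges. The `sliceRange` helper and the example shapes below are illustrative assumptions, not code from this diff.

```cpp
// Illustrative sketch only: mirrors the stride math loadLinearWeights uses
// to build per-rank ConcateSlice ranges. sliceRange is a hypothetical helper.
#include <cstddef>
#include <cstdio>
#include <utility>

using std::size_t;

// Half-open [begin, end) range that `rank` loads along `split_dim`
// (0 = split dim0 / rows, 1 = split dim1 / columns).
static std::pair<size_t, size_t>
sliceRange(size_t rank, size_t tensor_para_size, size_t dim0, size_t dim1, size_t split_dim)
{
    const size_t extent = (split_dim == 0) ? dim0 : dim1;
    const size_t stride = extent / tensor_para_size;  // caller ensures extent % tp == 0
    return {stride * rank, stride * (rank + 1)};
}

int main()
{
    // tok_embeddings in the patch: [vocab_size_padded, hidden_units], split_dim = 1
    const size_t vocab = 32000, hidden = 4096, tp = 2;
    for (size_t rank = 0; rank < tp; ++rank) {
        const auto [begin, end] = sliceRange(rank, tp, vocab, hidden, 1);
        std::printf("rank %zu loads columns [%zu, %zu)\n", rank, begin, end);
    }
    return 0;
}
```

With `tp = 2`, rank 0 prints `[0, 2048)` and rank 1 prints `[2048, 4096)`: the `{stride * rank, stride * (rank + 1)}` range the `split_dim == 1` branch assigns to `slice1`, while `slice0` keeps all `dim0` rows.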