From 27a5f52bbd3374b4b090d245f21bdb5960649c12 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Mon, 29 Nov 2021 11:51:50 +0800 Subject: [PATCH] [HeterPs] fix allocation (#37476) * auc temp * cuballocator * code format * code format --- .../framework/fleet/heter_ps/heter_comm.h | 10 ++- .../framework/fleet/heter_ps/heter_comm_inl.h | 75 ++++++++++++------- 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 2ec2a8a1f1e..c3fe2dabea9 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include "cub/cub.cuh" +#include "cub/util_allocator.cuh" #include "hashtable.h" #include "heter_resource.h" #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" @@ -163,9 +164,9 @@ class HeterComm { }; void init_path(); - void create_storage( - int start_index, int end_index, int keylen, int vallen, - std::vector>& local_strorage); + + void create_storage(int start_index, int end_index, int keylen, int vallen); + void destroy_storage(int start_index, int end_index); void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key, GradType* src_val); void walk_to_src(int start_index, int gpu_num, int* h_left, int* h_right, @@ -178,7 +179,7 @@ class HeterComm { std::vector tables_; std::shared_ptr resource_; CustomGradMerger merger_; - int topo_aware_{1}; + int topo_aware_{0}; std::vector> path_; std::vector storage_; int feanum_{1800 * 2048}; @@ -186,6 +187,7 @@ class HeterComm { std::vector nccl_inner_comms_; std::vector nccl_inter_comms_; int node_size_; + std::vector> allocators_; }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index d199a39162b..ec852ec83ca 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -100,6 +100,8 @@ HeterComm::HeterComm( storage_.resize(resource_->total_gpu()); for (int i = 0; i < resource_->total_gpu(); ++i) { platform::CUDADeviceGuard guard(resource_->dev_id(i)); + allocators_.push_back(std::make_shared( + 8, 1, (unsigned int)-1, (size_t)-1, false, false)); auto table = new Table(capacity / load_factor_); tables_.push_back(table); if (multi_node_) { @@ -115,14 +117,14 @@ void HeterComm::init_path() { path_.resize(total_gpu); if (!topo_aware_) { - VLOG(3) << "init path without topo aware"; + VLOG(0) << "init path without topo aware"; for (int i = 0; i < total_gpu; ++i) { path_[i].resize(total_gpu); for (int j = 0; j < total_gpu; ++j) { auto& nodes = path_[i][j].nodes_; nodes.resize(1); nodes[0].in_stream = resource_->comm_stream(i, j); - nodes[0].out_stream = resource_->comm_stream(j, i); + nodes[0].out_stream = resource_->comm_stream(i, j); nodes[0].key_storage = NULL; nodes[0].val_storage = NULL; nodes[0].sync = 0; @@ -130,7 +132,7 @@ void HeterComm::init_path() { } } } else { - VLOG(3) << "init path with topo aware"; + VLOG(0) << "init path with topo aware"; for (int i = 0; i < total_gpu; ++i) { path_[i].resize(total_gpu); for (int j = 0; j < total_gpu; ++j) { @@ -163,26 +165,41 @@ void HeterComm::init_path() { } template -void HeterComm::create_storage( - int start_index, int end_index, int keylen, int vallen, - std::vector>& local_storage) { +void HeterComm::create_storage(int start_index, + int end_index, + int keylen, + int vallen) { + auto& allocator = allocators_[start_index]; auto& nodes = path_[start_index][end_index].nodes_; for (size_t i = 0; i < nodes.size(); ++i) { platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num)); - platform::CUDAPlace remote_place = - platform::CUDAPlace(resource_->dev_id(nodes[i].gpu_num)); - auto key_mem = memory::AllocShared(remote_place, keylen); - local_storage.push_back(key_mem); - nodes[i].key_storage = reinterpret_cast(key_mem->ptr()); - - auto val_mem = memory::AllocShared(remote_place, vallen); - local_storage.push_back(val_mem); - nodes[i].val_storage = reinterpret_cast(val_mem->ptr()); + allocator->DeviceAllocate( + resource_->dev_id(nodes[i].gpu_num), (void**)&(nodes[i].key_storage), + keylen, resource_->remote_stream(nodes[i].gpu_num, start_index)); + allocator->DeviceAllocate( + resource_->dev_id(nodes[i].gpu_num), (void**)&(nodes[i].val_storage), + vallen, resource_->remote_stream(nodes[i].gpu_num, start_index)); + nodes[i].key_bytes_len = keylen; nodes[i].val_bytes_len = vallen; } } +template +void HeterComm::destroy_storage(int start_index, + int end_index) { + auto& allocator = allocators_[start_index]; + auto& nodes = path_[start_index][end_index].nodes_; + for (size_t i = 0; i < nodes.size(); ++i) { + platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num)); + + allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num), + nodes[i].key_storage); + allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num), + nodes[i].val_storage); + } +} + template void HeterComm::walk_to_dest( int start_index, int gpu_num, int* h_left, int* h_right, KeyType* src_key, @@ -482,8 +499,8 @@ void HeterComm::pull_sparse(int num, int* d_left_ptr = reinterpret_cast(d_left->ptr()); int* d_right_ptr = reinterpret_cast(d_right->ptr()); - cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int)); - cudaMemset(d_right_ptr, -1, total_gpu * sizeof(int)); + cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream); + cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream); // auto d_idx = memory::AllocShared(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); @@ -505,15 +522,13 @@ void HeterComm::pull_sparse(int num, cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost); - std::vector> local_storage; - for (int i = 0; i < total_gpu; ++i) { int shard_len = h_right[i] - h_left[i] + 1; if (shard_len == 0) { continue; } create_storage(num, i, shard_len * sizeof(KeyType), - shard_len * sizeof(ValType), local_storage); + shard_len * sizeof(ValType)); } walk_to_dest(num, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL); @@ -533,6 +548,9 @@ void HeterComm::pull_sparse(int num, } for (int i = 0; i < total_gpu; ++i) { cudaStreamSynchronize(resource_->remote_stream(i, num)); + if (h_left[i] == -1) { + continue; + } tables_[i]->rwlock_->UNLock(); } @@ -546,6 +564,9 @@ void HeterComm::pull_sparse(int num, fill_dvals<<>>(d_shard_vals_ptr, d_vals, d_idx_ptr, len); cudaStreamSynchronize(stream); + for (int i = 0; i < total_gpu; ++i) { + destroy_storage(num, i); + } } template @@ -572,8 +593,8 @@ void HeterComm::push_sparse(int gpu_num, int* d_left_ptr = reinterpret_cast(d_left->ptr()); int* d_right_ptr = reinterpret_cast(d_right->ptr()); - cudaMemset(d_left_ptr, -1, total_gpu * sizeof(int)); - cudaMemset(d_right_ptr, -1, total_gpu * sizeof(int)); + cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream); + cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream); // auto d_idx = memory::AllocShared(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); @@ -603,14 +624,13 @@ void HeterComm::push_sparse(int gpu_num, cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost); - std::vector> local_storage; for (int i = 0; i < total_gpu; ++i) { int shard_len = h_right[i] - h_left[i] + 1; if (h_left[i] == -1 || h_right[i] == -1) { continue; } create_storage(gpu_num, i, shard_len * sizeof(KeyType), - shard_len * sizeof(GradType), local_storage); + shard_len * sizeof(GradType)); } walk_to_dest(gpu_num, total_gpu, h_left, h_right, d_shard_keys_ptr, @@ -632,7 +652,12 @@ void HeterComm::push_sparse(int gpu_num, } for (int i = 0; i < total_gpu; ++i) { cudaStreamSynchronize(resource_->remote_stream(i, gpu_num)); - tables_[i]->rwlock_->UNLock(); + if (h_left[i] != -1) { + tables_[i]->rwlock_->UNLock(); + } + } + for (int i = 0; i < total_gpu; ++i) { + destroy_storage(gpu_num, i); } } -- GitLab