From c0001a2433c1058ebfd21df22fe0f86146f16610 Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Mon, 23 May 2022 11:49:19 +0800 Subject: [PATCH] Acc name (#42906) add dymf support of gpups --- paddle/fluid/framework/fleet/heter_context.h | 18 ---- .../framework/fleet/heter_ps/feature_value.h | 14 +++ .../fleet/heter_ps/hashtable_kernel.cu | 15 ++- .../framework/fleet/heter_ps/heter_comm_inl.h | 43 ++++++++- .../fleet/heter_ps/heter_comm_kernel.cu | 1 + .../fluid/framework/fleet/ps_gpu_wrapper.cc | 94 +++++++------------ paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 11 +++ 7 files changed, 114 insertions(+), 82 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h index 11217b6c485..823b60c5ef1 100644 --- a/paddle/fluid/framework/fleet/heter_context.h +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -95,24 +95,6 @@ class HeterContext { } void SetShardNum(uint32_t shard_num) { shard_num_ = shard_num; } uint32_t ShardNum() { return shard_num_; } - void init(int shard_num, int device_num) { - shard_num_ = shard_num; - feature_keys_.resize(shard_num_); - value_ptr_.resize(shard_num_); - device_task_ptr_.resize(shard_num_); - device_task_keys_.resize(shard_num_); - for (size_t i = 0; i < device_task_ptr_.size(); i++) { - device_task_ptr_[i].resize(device_num); - device_task_keys_[i].resize(device_num); - } - - device_values_.resize(device_num); - device_keys_.resize(device_num); - mutex_.resize(device_num); - for (size_t i = 0; i < mutex_.size(); ++i) { - mutex_[i] = new std::mutex(); - } - } void init(int shard_num, int device_num, int dim_num) { shard_num_ = shard_num; diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 682c4568cb7..cb7f3a40d67 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -69,6 +69,20 @@ struct FeaturePushValue { int mf_dim; float mf_g[0]; + __device__ __forceinline__ FeaturePushValue + operator+(const FeaturePushValue& a) const { + FeaturePushValue out; + out.slot = a.slot; + out.mf_dim = a.mf_dim; + out.show = a.show + show; + out.clk = a.clk + clk; + out.lr_g = a.lr_g + lr_g; + // out.mf_g = a.mf_g; + for (int i = 0; i < out.mf_dim; ++i) { + out.mf_g[i] = a.mf_g[i] + mf_g[i]; + } + return out; + } __device__ __forceinline__ void operator=(const FeaturePushValue& in) { show = in.show; clk = in.clk; diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 32dbd98992b..f5807d2fd7e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -86,13 +86,26 @@ __global__ void dy_mf_search_kernel(Table* table, char* vals, size_t len, size_t pull_feature_value_size) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + // return; if (i < len) { auto it = table->find(keys[i]); if (it != table->end()) { uint64_t offset = i * pull_feature_value_size; - FeatureValue& cur = *(FeatureValue*)(vals + offset); + FeatureValue* cur = (FeatureValue*)(vals + offset); FeatureValue& input = *(FeatureValue*)(it->second); + cur->slot = input.slot; + cur->show = input.show; + cur->clk = input.clk; + cur->mf_dim = input.mf_dim; + cur->lr = input.lr; + cur->mf_size = input.mf_size; + cur->cpu_ptr = input.cpu_ptr; + cur->delta_score = input.delta_score; + cur->lr_g2sum = input.lr_g2sum; + for (int j = 0; j < cur->mf_dim + 1; ++j) { + cur->mf[j] = input.mf[j]; + } } } } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 506a0c0b186..64b177abb86 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -26,6 +26,7 @@ namespace framework { template HeterComm::HeterComm( size_t capacity, std::shared_ptr resource) { + VLOG(1) << "Construct new HeterComm"; resource_ = resource; storage_.resize(resource_->total_device()); multi_mf_dim_ = resource->multi_mf(); @@ -364,6 +365,10 @@ HeterComm::~HeterComm() { delete table; table = nullptr; } + for (auto& table : tables_) { + delete table; + table = nullptr; + } } } @@ -473,17 +478,23 @@ void HeterComm::build_ps(int num, KeyType* h_keys, return; } int dev_id = resource_->dev_id(num); + DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); + + // use hbm pool std::vector d_key_bufs; + ppStream streams[stream_num]; // NOLINT for (int i = 0; i < stream_num; ++i) { create_stream(&(streams[i])); auto d_k_buf = memory::Alloc(place, chunk_size * sizeof(KeyType)); d_key_bufs.push_back(std::move(d_k_buf)); } + int cur_len = 0; int cur_stream = 0; + while (cur_len < len) { cur_stream = cur_stream % stream_num; auto cur_use_stream = streams[cur_stream]; @@ -491,8 +502,10 @@ void HeterComm::build_ps(int num, KeyType* h_keys, cur_use_stream = 0; #endif int tmp_len = cur_len + chunk_size > len ? len - cur_len : chunk_size; + auto dst_place = place; auto src_place = platform::CPUPlace(); + memory_copy( dst_place, reinterpret_cast(d_key_bufs[cur_stream]->ptr()), src_place, h_keys + cur_len, sizeof(KeyType) * tmp_len, cur_use_stream); @@ -557,14 +570,20 @@ void HeterComm::dynamic_merge_grad( platform::CUDAPlace place = platform::CUDAPlace(dev_id); platform::CUDADeviceGuard guard(dev_id); auto stream = resource_->local_stream(gpu_num, 0); + size_t temp_storage_bytes; + + // VLOG(1) << "hetercomm merge_grad: max_mf_dim: " << max_mf_dim_; size_t grad_value_size = TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); + auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); + auto d_merge_grads = memory::Alloc(place, len * grad_value_size); GradType* d_merge_grads_ptr = reinterpret_cast(d_merge_grads->ptr()); + auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1)); uint32_t* d_fea_num_info_ptr = reinterpret_cast(d_fea_num_info->ptr()); @@ -836,9 +855,16 @@ void HeterComm::push_sparse(int dev_num, auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); + GradType* d_shard_grads_ptr; - auto d_shard_grads = memory::Alloc(place, len * grad_value_size); - d_shard_grads_ptr = reinterpret_cast(d_shard_grads->ptr()); + if (!multi_mf_dim_) { + auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType)); + d_shard_grads_ptr = reinterpret_cast(d_shard_grads->ptr()); + } else { + auto d_shard_grads = memory::Alloc(place, len * grad_value_size); + d_shard_grads_ptr = reinterpret_cast(d_shard_grads->ptr()); + } + int uniq_len = len; dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len); @@ -846,9 +872,16 @@ void HeterComm::push_sparse(int dev_num, split_input_to_shard(d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, dev_num); - heter_comm_kernel_->dy_mf_fill_shard_grads( - d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr, uniq_len, - grad_value_size, stream); + + if (!multi_mf_dim_) { + heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, d_keys, + d_shard_grads_ptr, d_grads, d_idx_ptr, + uniq_len, stream); + } else { + heter_comm_kernel_->dy_mf_fill_shard_grads( + d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, d_idx_ptr, + uniq_len, grad_value_size, stream); + } sync_stream(stream); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index f44803982a5..94d7929b294 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -136,6 +136,7 @@ __global__ void merge_gradients_kernel(const uint32_t* offset, size_t grad_value_size, DynamicGradMerger& merger_) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { uint32_t start = offset[i]; uint32_t num = fea_num[i]; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index a22704bd1ed..18eec174fe9 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -106,25 +106,17 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { platform::Timer timeline; timeline.Start(); int device_num = heter_devices_.size(); - if (!multi_mf_dim_) { - gpu_task->init(thread_keys_shard_num_, device_num); - } else { - gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_); - } + gpu_task->init(thread_keys_shard_num_, device_num, multi_mf_dim_); std::vector threads; - if (!multi_mf_dim_) { - thread_keys_.resize(thread_keys_thread_num_); - for (int i = 0; i < thread_keys_thread_num_; i++) { - thread_keys_[i].resize(thread_keys_shard_num_); - } - } else { - thread_dim_keys_.resize(thread_keys_thread_num_); - for (int i = 0; i < thread_keys_thread_num_; i++) { - thread_dim_keys_[i].resize(thread_keys_shard_num_); - for (int j = 0; j < thread_keys_shard_num_; j++) { - thread_dim_keys_[i][j].resize(multi_mf_dim_); - } + + // data should be in input channel + + thread_dim_keys_.resize(thread_keys_thread_num_); + for (int i = 0; i < thread_keys_thread_num_; i++) { + thread_dim_keys_[i].resize(thread_keys_shard_num_); + for (int j = 0; j < thread_keys_shard_num_; j++) { + thread_dim_keys_[i][j].resize(multi_mf_dim_); } } @@ -144,18 +136,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { len_per_thread = total_len / thread_keys_thread_num_; remain = total_len % thread_keys_thread_num_; VLOG(0) << "total len: " << total_len; - auto gen_func = [this](const std::deque& total_data, - int begin_index, int end_index, int i) { - for (auto iter = total_data.begin() + begin_index; - iter != total_data.begin() + end_index; iter++) { - const auto& ins = *iter; - const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values; - for (const auto feasign : feasign_v) { - int shard_id = feasign % thread_keys_shard_num_; - this->thread_keys_[i][shard_id].insert(feasign); - } - } - }; auto gen_dynamic_mf_func = [this](const std::deque& total_data, int begin_index, int end_index, int i) { for (auto iter = total_data.begin() + begin_index; @@ -177,17 +157,10 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { } }; for (int i = 0; i < thread_keys_thread_num_; i++) { - if (!multi_mf_dim_) { - VLOG(0) << "yxf::psgpu wrapper genfunc"; - threads.push_back( - std::thread(gen_func, std::ref(vec_data), begin, - begin + len_per_thread + (i < remain ? 1 : 0), i)); - } else { - VLOG(0) << "yxf::psgpu wrapper genfunc with dynamic mf"; - threads.push_back( - std::thread(gen_dynamic_mf_func, std::ref(vec_data), begin, - begin + len_per_thread + (i < remain ? 1 : 0), i)); - } + threads.push_back( + std::thread(gen_dynamic_mf_func, std::ref(vec_data), begin, + begin + len_per_thread + (i < remain ? 1 : 0), i)); + begin += len_per_thread + (i < remain ? 1 : 0); } for (std::thread& t : threads) { @@ -235,12 +208,6 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { threads.clear(); // merge thread_keys to shard_keys - auto merge_ins_func = [this, gpu_task](int shard_num) { - for (int i = 0; i < thread_keys_thread_num_; ++i) { - gpu_task->batch_add_keys(shard_num, thread_keys_[i][shard_num]); - thread_keys_[i][shard_num].clear(); - } - }; auto merge_ins_dynamic_mf_func = [this, gpu_task](int shard_num, int dim_id) { for (int i = 0; i < thread_keys_thread_num_; ++i) { gpu_task->batch_add_keys(shard_num, dim_id, @@ -249,12 +216,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { } }; for (int i = 0; i < thread_keys_shard_num_; ++i) { - if (!multi_mf_dim_) { - threads.push_back(std::thread(merge_ins_func, i)); - } else { - for (int j = 0; j < multi_mf_dim_; j++) { - threads.push_back(std::thread(merge_ins_dynamic_mf_func, i, j)); - } + for (int j = 0; j < multi_mf_dim_; j++) { + threads.push_back(std::thread(merge_ins_dynamic_mf_func, i, j)); } } for (auto& t : threads) { @@ -297,12 +260,12 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { auto& device_dim_keys = gpu_task->device_dim_keys_; auto& device_dim_ptr = gpu_task->device_dim_ptr_; auto& device_dim_mutex = gpu_task->dim_mutex_; - if (multi_mf_dim_) { - for (size_t dev = 0; dev < device_dim_keys.size(); dev++) { - device_dim_keys[dev].resize(multi_mf_dim_); - device_dim_ptr[dev].resize(multi_mf_dim_); - } + + for (size_t dev = 0; dev < device_dim_keys.size(); dev++) { + device_dim_keys[dev].resize(multi_mf_dim_); + device_dim_ptr[dev].resize(multi_mf_dim_); } + // auto& device_mutex = gpu_task->mutex_; std::vector threads(thread_keys_shard_num_); @@ -415,6 +378,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { task_keys[shard].push_back(local_dim_keys[i][j][k]); task_ptrs[shard].push_back(local_dim_ptr[i][j][k]); } + // allocate local keys to devices for (int dev = 0; dev < device_num; dev++) { device_dim_mutex[dev][j]->lock(); int len = task_keys[dev].size(); @@ -619,6 +583,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { << feature_keys_count[i]; size_max = std::max(size_max, feature_keys_count[i]); } + if (HeterPs_) { delete HeterPs_; HeterPs_ = nullptr; @@ -665,6 +630,8 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { ptr_val[paddle::ps::DownpourCtrDymfAccessor:: DownpourCtrDymfFeatureValue::embed_g2sum_index()]; val->cpu_ptr = (uint64_t)(device_dim_ptrs[k]); + + // TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: mf_dim_index()] = float(mf_dim); val->mf_dim = mf_dim; @@ -681,11 +648,15 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { } } } + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + this->hbm_pools_[i * this->multi_mf_dim_ + j] = new HBMMemoryPool(mem_pool); auto& cur_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; + this->HeterPs_->build_ps(i, device_dim_keys.data(), cur_pool->mem(), len, feature_value_size, 500000, 2); + if (device_dim_keys.size() > 0) { VLOG(0) << "show ptr table: " << i << " table kv size: " << device_dim_keys.size() @@ -700,6 +671,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { threads[i + j * device_num] = std::thread(build_dynamic_mf_func, i, j); } } + for (std::thread& t : threads) { t.join(); } @@ -723,7 +695,9 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) { InitSlotInfo(); std::shared_ptr gpu_task = gpu_task_pool_.Get(); gpu_task->Reset(); + data_ready_channel_->Put(gpu_task); + VLOG(3) << "End LoadIntoMemory(), dataset[" << dataset_ << "]"; } @@ -805,6 +779,7 @@ void PSGPUWrapper::EndPass() { timer.Start(); size_t keysize_max = 0; // in case of feasign_num = 0, skip dump_to_cpu + for (size_t i = 0; i < heter_devices_.size(); i++) { for (int j = 0; j < multi_mf_dim_; j++) { keysize_max = @@ -821,9 +796,11 @@ void PSGPUWrapper::EndPass() { VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim; size_t feature_value_size = TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float))); + char* test_build_values = (char*)malloc(feature_value_size * len); cudaMemcpy(test_build_values, hbm_pool->mem(), feature_value_size * len, cudaMemcpyDeviceToHost); + CHECK(len == hbm_pool->capacity()); #ifdef PADDLE_WITH_PSLIB uint64_t unuse_key = std::numeric_limits::max(); @@ -972,7 +949,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, feature_value_size = TYPEALIGN( 8, sizeof(FeatureValue) + sizeof(float) * (index_dim_vec_.back() + 1)); - VLOG(0) << "yxf pull sparse feature_value_size: " << feature_value_size; #ifdef PADDLE_WITH_CUDA VLOG(3) << "Begine Gpu Ps PullSparse"; @@ -1159,6 +1135,8 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, "GPUPS: PushSparseGrad Only Support CUDAPlace Now.")); } all_timer.Pause(); + time_3 += all_timer.ElapsedSec(); + time_4 += push_gpups_timer.ElapsedSec(); VLOG(3) << "PushSparseGrad total cost: " << all_timer.ElapsedSec() << " s, of which GPUPS cost: " << push_gpups_timer.ElapsedSec() << " s"; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 9b556266459..0efec57e59d 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -333,6 +333,11 @@ class PSGPUWrapper { void SetSlotOffsetVector(const std::vector& slot_offset_vector) { slot_offset_vector_ = slot_offset_vector; + std::cout << "yxf set: "; + for (auto s : slot_offset_vector_) { + std::cout << s << " | "; + } + std::cout << " end " << std::endl; } #ifdef PADDLE_WITH_CUDA @@ -431,6 +436,12 @@ class PSGPUWrapper { int max_mf_dim_{0}; size_t val_type_size_{0}; size_t grad_type_size_{0}; + + double time_1 = 0.0; + double time_2 = 0.0; + double time_3 = 0.0; + double time_4 = 0.0; + int multi_node_{0}; int node_size_; uint64_t table_id_; -- GitLab