diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h
old mode 100755
new mode 100644
index 8e51f0e2405bfe6ab218148ca5006c210aaa34e7..6d3a4c5d9c0b96d8927a45eeac83350af4983015
--- a/paddle/fluid/framework/fleet/heter_context.h
+++ b/paddle/fluid/framework/fleet/heter_context.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #ifdef PADDLE_WITH_HETERPS
 
+#include
 #include
 #include
 #include
@@ -38,7 +39,7 @@ namespace framework {
 
 class HeterContext {
  public:
-  ~HeterContext() {
+  virtual ~HeterContext() {
     if (!multi_mf_dim_) {
       for (size_t i = 0; i < mutex_.size(); ++i) {
         delete mutex_[i];
@@ -56,9 +57,12 @@ class HeterContext {
   Scope* scope_{nullptr};
   std::vector<std::vector<FeatureKey>> feature_keys_;
   std::vector<std::vector<std::vector<FeatureKey>>> feature_dim_keys_;
+  std::vector<std::vector<std::vector<FeatureKey>>> device_task_keys_;
 
 #ifdef PADDLE_WITH_PSLIB
   std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> value_ptr_;
+  std::vector<std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>>>
+      device_task_ptr_;
   std::vector<std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>>>
       value_dim_ptr_;
   std::vector<std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>>>
@@ -68,6 +72,8 @@ class HeterContext {
   std::vector<std::vector<paddle::distributed::FixedFeatureValue*>> value_ptr_;
   std::vector<std::vector<std::vector<paddle::distributed::FixedFeatureValue*>>>
       value_dim_ptr_;
+  std::vector<std::vector<std::vector<paddle::distributed::FixedFeatureValue*>>>
+      device_task_ptr_;
   std::vector<std::vector<std::vector<paddle::distributed::FixedFeatureValue*>>>
       device_dim_ptr_;
 #endif
@@ -93,6 +99,12 @@ class HeterContext {
     shard_num_ = shard_num;
     feature_keys_.resize(shard_num_);
     value_ptr_.resize(shard_num_);
+    device_task_ptr_.resize(shard_num_);
+    device_task_keys_.resize(shard_num_);
+    for (size_t i = 0; i < device_task_ptr_.size(); i++) {
+      device_task_ptr_[i].resize(device_num);
+      device_task_keys_[i].resize(device_num);
+    }
     device_values_.resize(device_num);
     device_keys_.resize(device_num);
 
@@ -108,6 +120,12 @@ class HeterContext {
     feature_dim_keys_.resize(shard_num_);
     value_ptr_.resize(shard_num_);
     value_dim_ptr_.resize(shard_num_);
+    device_task_ptr_.resize(shard_num_);
+    device_task_keys_.resize(shard_num_);
+    for (size_t i = 0; i < device_task_ptr_.size(); i++) {
+      device_task_ptr_[i].resize(device_num);
+      device_task_keys_[i].resize(device_num);
+    }
     for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
       feature_dim_keys_[i].resize(dim_num);
       value_dim_ptr_[i].resize(dim_num);
@@ -151,6 +169,12 @@ class HeterContext {
       for (size_t i = 0; i < device_keys_.size(); ++i) {
         device_keys_[i].clear();
       }
+      for (size_t i = 0; i < device_task_ptr_.size(); ++i) {
+        for (size_t j = 0; j < device_task_ptr_[i].size(); ++j) {
+          device_task_ptr_[i][j].clear();
+          device_task_keys_[i][j].clear();
+        }
+      }
     } else {
       VLOG(3) << "Reset gpu task with dynamic mf dimention";
       for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
old mode 100755
new mode 100644
index e167a39caa526f8c8bb81fca96ef007db007b51a..115ec4d0102cc872bf8136df06cb6fcaf39f3a9f
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -298,6 +298,7 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
 
 void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
   platform::Timer timeline;
+  std::vector<std::future<void>> task_futures;
   int device_num = heter_devices_.size();
   auto& local_keys = gpu_task->feature_keys_;
   auto& local_ptr = gpu_task->value_ptr_;
@@ -316,7 +317,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
       device_dim_ptr[dev].resize(multi_mf_dim_);
     }
   }
-  auto& device_mutex = gpu_task->mutex_;
+  // auto& device_mutex = gpu_task->mutex_;
   std::vector<std::thread> threads(thread_keys_shard_num_);
 
 #ifdef PADDLE_WITH_PSLIB
@@ -502,6 +503,8 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
         table_id_, pass_id, pass_values);
   }
 #endif
+  auto& device_task_keys = gpu_task->device_task_keys_;
+  auto& device_task_ptrs = gpu_task->device_task_ptr_;
   auto build_dynamic_mf_func = [this, device_num, &local_dim_keys,
                                 &local_dim_ptr, &device_dim_keys,
                                 &device_dim_ptr,
@@ -534,17 +537,14 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
 #endif
   };
   auto build_func = [device_num, record_status, &pass_values, &local_keys,
-                     &local_ptr, &device_keys, &device_vals,
-                     &device_mutex](int i) {
-    std::vector<std::vector<uint64_t>> task_keys(device_num);
+                     &local_ptr, &device_task_keys, &device_task_ptrs](int i) {
+    auto& task_keys = device_task_keys[i];
 #ifdef PADDLE_WITH_PSLIB
-    std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> task_ptrs(
-        device_num);
+    auto& task_ptrs = device_task_ptrs[i];
 #endif
 
 #ifdef PADDLE_WITH_PSCORE
-    std::vector<std::vector<paddle::distributed::FixedFeatureValue*>> task_ptrs(
-        device_num);
+    auto& task_ptrs = device_task_ptrs[i];
 #endif
 
     for (size_t j = 0; j < local_keys[i].size(); j++) {
@@ -569,88 +569,139 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
       }
     }
 #endif
-    for (int dev = 0; dev < device_num; dev++) {
-      device_mutex[dev]->lock();
+  };
+  if (!multi_mf_dim_) {
+    for (int i = 0; i < thread_keys_shard_num_; i++) {
+      task_futures.emplace_back(hbm_thread_pool_[i]->enqueue(build_func, i));
+    }
+    for (auto& f : task_futures) {
+      f.wait();
+    }
+    task_futures.clear();
+    VLOG(0) << "GpuPs build hbmps done";
+  }
+  std::vector<std::vector<int>> prefix_sum;
+  prefix_sum.resize(device_num);
+  for (int i = 0; i < device_num; i++) {
+    prefix_sum[i].resize(thread_keys_shard_num_ + 1);
+    prefix_sum[i][0] = 0;
+  }
+  auto calc_prefix_func = [this, &prefix_sum, &device_keys, &device_vals,
+                           &device_task_keys](int device_num) {
+    for (int j = 0; j < thread_keys_shard_num_; j++) {
+      prefix_sum[device_num][j + 1] =
+          prefix_sum[device_num][j] + device_task_keys[j][device_num].size();
+    }
+    device_keys[device_num].resize(
+        prefix_sum[device_num][thread_keys_shard_num_]);
+    device_vals[device_num].resize(
+        prefix_sum[device_num][thread_keys_shard_num_]);
+  };
+  if (!multi_mf_dim_) {
+    for (int i = 0; i < device_num; i++) {
+      task_futures.emplace_back(
+          hbm_thread_pool_[i]->enqueue(calc_prefix_func, i));
+    }
+    for (auto& f : task_futures) {
+      f.wait();
+    }
+    task_futures.clear();
+  }
+  VLOG(0) << "prefix done";
+  auto prepare_dev_value_func = [device_num, &prefix_sum, &device_keys,
+                                 &device_vals, &device_task_keys,
+                                 &device_task_ptrs](int dev, int shard_id) {
+    auto& task_keys = device_task_keys[shard_id];
+#ifdef PADDLE_WITH_PSLIB
+    auto& task_ptrs = device_task_ptrs[shard_id];
+#endif
+
+#ifdef PADDLE_WITH_PSCORE
+    auto& task_ptrs = device_task_ptrs[dev];
+#endif
 
-      int len = task_keys[dev].size();
-      int cur = device_keys[dev].size();
-      device_keys[dev].resize(device_keys[dev].size() + len);
-      device_vals[dev].resize(device_vals[dev].size() + len);
+    int len = prefix_sum[dev][shard_id + 1] - prefix_sum[dev][shard_id];
+    int cur = prefix_sum[dev][shard_id];
 #ifdef PADDLE_WITH_PSLIB
-      for (int j = 0; j < len; ++j) {
-        device_keys[dev][cur + j] = task_keys[dev][j];
-        float* ptr_val = task_ptrs[dev][j]->data();
-        FeatureValue& val = device_vals[dev][cur + j];
-        size_t dim = task_ptrs[dev][j]->size();
-
-        val.delta_score = ptr_val[1];
-        val.show = ptr_val[2];
-        val.clk = ptr_val[3];
-        val.slot = ptr_val[6];
-        val.lr = ptr_val[4];
-        val.lr_g2sum = ptr_val[5];
-        val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]);
-
-        if (dim > 7) {
-          val.mf_size = MF_DIM + 1;
-          for (int x = 0; x < val.mf_size; x++) {
-            val.mf[x] = ptr_val[x + 7];
-          }
-        } else {
-          val.mf_size = 0;
-          for (int x = 0; x < MF_DIM + 1; x++) {
-            val.mf[x] = 0;
-          }
+    for (int j = 0; j < len; ++j) {
+      device_keys[dev][cur + j] = task_keys[dev][j];
+      float* ptr_val = task_ptrs[dev][j]->data();
+      FeatureValue& val = device_vals[dev][cur + j];
+      size_t dim = task_ptrs[dev][j]->size();
+
+      val.delta_score = ptr_val[1];
+      val.show = ptr_val[2];
+      val.clk = ptr_val[3];
+      val.slot = ptr_val[6];
+      val.lr = ptr_val[4];
+      val.lr_g2sum = ptr_val[5];
+      val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]);
+
+      if (dim > 7) {
+        val.mf_size = MF_DIM + 1;
+        for (int x = 0; x < val.mf_size; x++) {
+          val.mf[x] = ptr_val[x + 7];
+        }
+      } else {
+        val.mf_size = 0;
+        for (int x = 0; x < MF_DIM + 1; x++) {
+          val.mf[x] = 0;
         }
       }
+    }
 #endif
 #ifdef PADDLE_WITH_PSCORE
-      for (int j = 0; j < len; ++j) {
-        device_keys[dev][cur + j] = task_keys[dev][j];
-        float* ptr_val = task_ptrs[dev][j]->data();
-        FeatureValue& val = device_vals[dev][cur + j];
-        size_t dim = task_ptrs[dev][j]->size();
-        val.delta_score = ptr_val[2];
-        val.show = ptr_val[3];
-        val.clk = ptr_val[4];
-        val.slot = ptr_val[0];
-        val.lr = ptr_val[5];
-        val.lr_g2sum = ptr_val[6];
-        val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]);
-
-        if (dim > 7) {
-          val.mf_size = MF_DIM + 1;
-          for (int x = 0; x < val.mf_size; x++) {
-            val.mf[x] = ptr_val[x + 7];
-          }
-        } else {
-          val.mf_size = 0;
-          for (int x = 0; x < MF_DIM + 1; x++) {
-            val.mf[x] = 0;
-          }
+    for (int j = 0; j < len; ++j) {
+      device_keys[dev][cur + j] = task_keys[dev][j];
+      float* ptr_val = task_ptrs[dev][j]->data();
+      FeatureValue& val = device_vals[dev][cur + j];
+      size_t dim = task_ptrs[dev][j]->size();
+      val.delta_score = ptr_val[2];
+      val.show = ptr_val[3];
+      val.clk = ptr_val[4];
+      val.slot = ptr_val[0];
+      val.lr = ptr_val[5];
+      val.lr_g2sum = ptr_val[6];
+      val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]);
+
+      if (dim > 7) {
+        val.mf_size = MF_DIM + 1;
+        for (int x = 0; x < val.mf_size; x++) {
+          val.mf[x] = ptr_val[x + 7];
+        }
+      } else {
+        val.mf_size = 0;
+        for (int x = 0; x < MF_DIM + 1; x++) {
+          val.mf[x] = 0;
         }
       }
+    }
 #endif
-      VLOG(3) << "GpuPs build hbmps done";
+    VLOG(3) << "GpuPs build hbmps done";
 
-      device_mutex[dev]->unlock();
-    }
   };
-  if (!multi_mf_dim_) {
-    for (size_t i = 0; i < threads.size(); i++) {
-      threads[i] = std::thread(build_func, i);
-    }
-  } else {
+  if (multi_mf_dim_) {
     for (int i = 0; i < thread_keys_shard_num_; i++) {
       for (int j = 0; j < multi_mf_dim_; j++) {
         threads[i * multi_mf_dim_ + j] =
             std::thread(build_dynamic_mf_func, i, j);
       }
     }
-  }
-  for (std::thread& t : threads) {
-    t.join();
+    for (std::thread& t : threads) {
+      t.join();
+    }
+  } else {
+    for (int i = 0; i < thread_keys_shard_num_; i++) {
+      for (int j = 0; j < device_num; j++) {
+        task_futures.emplace_back(
+            hbm_thread_pool_[i]->enqueue(prepare_dev_value_func, j, i));
+      }
+    }
+    for (auto& f : task_futures) {
+      f.wait();
+    }
+    task_futures.clear();
   }
   timeline.Pause();
   VLOG(0) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec()
@@ -750,7 +801,7 @@ void PSGPUWrapper::pre_build_thread() {
     PreBuildTask(gpu_task);
     timer.Pause();
     VLOG(0) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec()
-            << "s";
+            << " s";
     buildcpu_ready_channel_->Put(gpu_task);
   }
   VLOG(3) << "build cpu thread end";
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index 9b7d6de082d1c1c6e2b3d16f91522304524df99a..9551e49b6b77b46db81ac5dc224bfbf0fe2b43e1 100755
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -83,6 +83,10 @@ class PSGPUWrapper {
   PSGPUWrapper() {
     HeterPs_ = NULL;
     sleep_seconds_before_fail_exit_ = 300;
+    hbm_thread_pool_.resize(thread_keys_shard_num_);
+    for (size_t i = 0; i < hbm_thread_pool_.size(); i++) {
+      hbm_thread_pool_[i].reset(new ::ThreadPool(1));
+    }
   }
 
   void PullSparse(const paddle::platform::Place& place, const int table_id,
@@ -399,6 +403,7 @@ class PSGPUWrapper {
   std::shared_ptr<HeterContext> current_task_ = nullptr;
   std::thread pre_build_threads_;
   bool running_ = false;
+  std::vector<std::shared_ptr<::ThreadPool>> hbm_thread_pool_;
 
  protected:
   static bool is_initialized_;
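
The core of this patch is the removal of the per-device mutex in BuildPull: build_func now only fills the pre-allocated gpu_task->device_task_keys_[shard][dev] / device_task_ptr_[shard][dev] buckets, calc_prefix_func turns the bucket sizes into per-device prefix sums that pre-size device_keys / device_vals, and prepare_dev_value_func copies each (shard, device) bucket into its own disjoint range, with every step fanned out over the per-shard hbm_thread_pool_. Below is a minimal, self-contained sketch of that prefix-sum scheme, not Paddle code: it uses plain uint64_t keys instead of FeatureKey / FeatureValue, std::async in place of ::ThreadPool, assumes keys are routed to devices by key % device_num, and the names PartitionShard / BuildPrefixSum / GatherDeviceKeys are illustrative only.

#include <cstdint>
#include <future>
#include <vector>

// Shard-local buckets: buckets[shard][dev] is written by shard "shard" only,
// so the partition phase needs no mutex.
using ShardBuckets = std::vector<std::vector<std::vector<uint64_t>>>;

// Phase 1: one shard partitions its keys by target device (key % device_num).
void PartitionShard(const std::vector<uint64_t>& shard_keys, int device_num,
                    std::vector<std::vector<uint64_t>>& out) {
  out.assign(device_num, std::vector<uint64_t>());
  for (uint64_t key : shard_keys) {
    out[key % device_num].push_back(key);
  }
}

// Phase 2: per-device prefix sums over bucket sizes; prefix[dev][shard] is the
// offset at which shard "shard" writes into device_keys[dev].
std::vector<std::vector<size_t>> BuildPrefixSum(const ShardBuckets& buckets,
                                                int shard_num, int device_num) {
  std::vector<std::vector<size_t>> prefix(
      device_num, std::vector<size_t>(shard_num + 1, 0));
  for (int dev = 0; dev < device_num; ++dev) {
    for (int shard = 0; shard < shard_num; ++shard) {
      prefix[dev][shard + 1] = prefix[dev][shard] + buckets[shard][dev].size();
    }
  }
  return prefix;
}

// Phase 3: one copy task per (shard, device) pair; destination ranges are
// disjoint, so the tasks run in parallel with no locking beyond the final wait.
void GatherDeviceKeys(const ShardBuckets& buckets,
                      const std::vector<std::vector<size_t>>& prefix,
                      std::vector<std::vector<uint64_t>>& device_keys) {
  const int shard_num = static_cast<int>(buckets.size());
  const int device_num = static_cast<int>(device_keys.size());
  for (int dev = 0; dev < device_num; ++dev) {
    device_keys[dev].resize(prefix[dev][shard_num]);
  }
  std::vector<std::future<void>> futures;
  for (int shard = 0; shard < shard_num; ++shard) {
    for (int dev = 0; dev < device_num; ++dev) {
      futures.push_back(std::async(std::launch::async, [&, shard, dev] {
        const auto& src = buckets[shard][dev];
        size_t cur = prefix[dev][shard];
        for (size_t j = 0; j < src.size(); ++j) {
          device_keys[dev][cur + j] = src[j];
        }
      }));
    }
  }
  for (auto& f : futures) {
    f.wait();
  }
}

Because the ranges [prefix[dev][shard], prefix[dev][shard + 1]) never overlap, the copy tasks need no synchronization beyond the final wait, which is what lets the patch comment out gpu_task->mutex_ (the old device_mutex) entirely.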