diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h
index 6f063e830c2da736f584d1c6b37aac609b2ca109..1fb2f0fab4aff9926f6ed6c30fa72cb9e9c93cf6 100644
--- a/paddle/fluid/framework/fleet/heter_context.h
+++ b/paddle/fluid/framework/fleet/heter_context.h
@@ -77,6 +77,21 @@ class HeterContext {
       mutex_[i] = new std::mutex();
     }
   }
+
+  void Reset() {
+    for (size_t i = 0; i < feature_keys_.size(); ++i) {
+      feature_keys_[i].clear();
+    }
+    for (size_t i = 0; i < value_ptr_.size(); ++i) {
+      value_ptr_[i].clear();
+    }
+    for (size_t i = 0; i < device_values_.size(); ++i) {
+      device_values_[i].clear();
+    }
+    for (size_t i = 0; i < device_keys_.size(); ++i) {
+      device_keys_[i].clear();
+    }
+  }
   void batch_add_keys(
       const std::vector<std::unordered_set<uint64_t>>& thread_keys) {
     assert(thread_keys.size() == feature_keys_.size());
@@ -90,6 +105,15 @@ class HeterContext {
     }
   }
 
+  void batch_add_keys(int shard_num,
+                      const std::unordered_set<uint64_t>& shard_keys) {
+    int idx = feature_keys_[shard_num].size();
+    feature_keys_[shard_num].resize(feature_keys_[shard_num].size() +
+                                    shard_keys.size());
+    std::copy(shard_keys.begin(), shard_keys.end(),
+              feature_keys_[shard_num].begin() + idx);
+  }
+
   void UniqueKeys() {
     std::vector<std::thread> threads;
     auto unique_func = [this](int i) {
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
index b7bb5110744649a78c0f2502e7999bb5d3a073f0..67ff6b6acaefb26adc1389559a763b98f41a533a 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -103,12 +103,26 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
   timeline.Start();
 
+  threads.clear();
   // merge thread_keys to shard_keys
-  for (size_t i = 0; i < thread_keys_.size(); i++) {
-    gpu_task->batch_add_keys(thread_keys_[i]);
-    for (int j = 0; j < thread_keys_thread_num_; j++) {
-      thread_keys_[i][j].clear();
+  auto merge_ins_func = [this, gpu_task](int shard_num) {
+    for (int i = 0; i < thread_keys_thread_num_; ++i) {
+      gpu_task->batch_add_keys(shard_num, thread_keys_[i][shard_num]);
+      thread_keys_[i][shard_num].clear();
     }
+  };
+
+  // for (size_t i = 0; i < thread_keys_.size(); i++) {
+  //   gpu_task->batch_add_keys(thread_keys_[i]);
+  //   for (int j = 0; j < thread_keys_thread_num_; j++) {
+  //     thread_keys_[i][j].clear();
+  //   }
+  //}
+  for (int i = 0; i < thread_keys_shard_num_; ++i) {
+    threads.push_back(std::thread(merge_ins_func, i));
+  }
+  for (auto& t : threads) {
+    t.join();
   }
   timeline.Pause();
@@ -261,6 +275,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
 void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
   int device_num = heter_devices_.size();
   std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
+  gpu_task->Reset();
   BuildTask(gpu_task, table_id, feature_dim);
   platform::Timer timeline;
   timeline.Start();
@@ -273,8 +288,8 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
     size_max = std::max(size_max, feature_keys_count[i]);
   }
   if (HeterPs_) {
-    HeterPs_->show_one_table(0);
-    return;
+    delete HeterPs_;
+    HeterPs_ = nullptr;
   }
   std::vector<std::thread> threads(device_num);
   HeterPs_ = HeterPsBase::get_instance(size_max, resource_);
@@ -295,6 +310,7 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
   timeline.Pause();
   VLOG(1) << "GpuPs build table total costs: " << timeline.ElapsedSec()
           << " s.";
+  gpu_task_pool_.Push(gpu_task);
 }
 
 void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
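
Note: the central change in BuildTask is that the single-threaded merge of thread_keys_ into the per-shard feature_keys_ is replaced by one merge thread per shard, each calling the new batch_add_keys(shard_num, shard_keys) overload; BuildGPUPS additionally calls Reset() on the pooled HeterContext, recycles it through gpu_task_pool_.Push(), and rebuilds HeterPs_ instead of reusing the previous table. The standalone sketch below only illustrates the per-shard merge pattern under stated assumptions; it is not Paddle code, and the names (kThreads, kShards, the synthetic key-generation loop) are made up for the example.

// Illustrative sketch of the per-shard parallel merge used by merge_ins_func.
// Build with: g++ -std=c++14 -pthread sketch.cc
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <thread>
#include <unordered_set>
#include <vector>

int main() {
  const int kThreads = 4;  // stands in for thread_keys_thread_num_
  const int kShards = 8;   // stands in for thread_keys_shard_num_

  // thread_keys[t][s]: keys collected by worker t that fall into shard s.
  std::vector<std::vector<std::unordered_set<uint64_t>>> thread_keys(
      kThreads, std::vector<std::unordered_set<uint64_t>>(kShards));

  // Phase 1: each worker buckets its own (synthetic) keys by shard.
  // No locking: every worker writes only to its own row.
  std::vector<std::thread> workers;
  for (int t = 0; t < kThreads; ++t) {
    workers.emplace_back([&, t] {
      const uint64_t base = static_cast<uint64_t>(t) * 1000;
      for (uint64_t key = base; key < base + 100; ++key) {
        thread_keys[t][key % kShards].insert(key);
      }
    });
  }
  for (auto& w : workers) w.join();

  // Phase 2: one merge thread per shard. Shard s concatenates column s of
  // every worker row into a flat vector, as batch_add_keys(shard_num, ...)
  // does; de-duplication would happen afterwards (UniqueKeys in the diff).
  std::vector<std::vector<uint64_t>> shard_keys(kShards);
  std::vector<std::thread> mergers;
  for (int s = 0; s < kShards; ++s) {
    mergers.emplace_back([&, s] {
      for (int t = 0; t < kThreads; ++t) {
        auto& dst = shard_keys[s];
        size_t idx = dst.size();
        dst.resize(idx + thread_keys[t][s].size());
        std::copy(thread_keys[t][s].begin(), thread_keys[t][s].end(),
                  dst.begin() + idx);
        thread_keys[t][s].clear();
      }
    });
  }
  for (auto& m : mergers) m.join();

  for (int s = 0; s < kShards; ++s) {
    std::cout << "shard " << s << ": " << shard_keys[s].size() << " keys\n";
  }
  return 0;
}

The pattern parallelizes without a mutex because ownership is disjoint in both phases: each worker writes only its own row of thread_keys, and each merger reads one column and writes one shard_keys entry, which is why the diff's merge_ins_func can clear thread_keys_[i][shard_num] as it goes.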