diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
index d3990c1f3dd769b1882385b572ece9717c30887b..4fb98e526d5fc45aeb9f674dbb2a3447464e2819 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -40,7 +40,7 @@ namespace framework {
 std::shared_ptr<PSGPUWrapper> PSGPUWrapper::s_instance_ = NULL;
 bool PSGPUWrapper::is_initialized_ = false;
 
-void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
+void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
   VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin";
   platform::Timer timeline;
   timeline.Start();
@@ -49,17 +49,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
   auto& local_keys = gpu_task->feature_keys_;
   auto& local_ptr = gpu_task->value_ptr_;
 
-  auto& device_keys = gpu_task->device_keys_;
-  auto& device_vals = gpu_task->device_values_;
-  auto& device_mutex = gpu_task->mutex_;
-
   std::vector<std::thread> threads;
-#ifdef PADDLE_WITH_PSLIB
-  auto fleet_ptr = FleetWrapper::GetInstance();
-#endif
-#ifdef PADDLE_WITH_PSCORE
-  auto fleet_ptr = paddle::distributed::Communicator::GetInstance();
-#endif
 
   // data should be in input channel
   thread_keys_.resize(thread_keys_thread_num_);
@@ -181,6 +171,25 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
     VLOG(3) << "GpuPs shard: " << i << " key len: " << local_keys[i].size();
     local_ptr[i].resize(local_keys[i].size());
   }
+}
+
+void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
+  platform::Timer timeline;
+  int device_num = heter_devices_.size();
+  auto& local_keys = gpu_task->feature_keys_;
+  auto& local_ptr = gpu_task->value_ptr_;
+
+  auto& device_keys = gpu_task->device_keys_;
+  auto& device_vals = gpu_task->device_values_;
+  auto& device_mutex = gpu_task->mutex_;
+
+  std::vector<std::thread> threads(thread_keys_shard_num_);
+#ifdef PADDLE_WITH_PSLIB
+  auto fleet_ptr = FleetWrapper::GetInstance();
+#endif
+#ifdef PADDLE_WITH_PSCORE
+  auto fleet_ptr = paddle::distributed::Communicator::GetInstance();
+#endif
 
 #ifdef PADDLE_WITH_PSLIB
   // get day_id: day nums from 1970
@@ -482,29 +491,32 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
 void PSGPUWrapper::start_build_thread() {
   running_ = true;
   VLOG(3) << "start build CPU&GPU ps thread.";
-  build_cpu_threads_ = std::thread([this] { build_cpu_thread(); });
-  build_gpu_threads_ = std::thread([this] { build_gpu_thread(); });
+  pre_build_threads_ = std::thread([this] { pre_build_thread(); });
+  build_threads_ = std::thread([this] { build_thread(); });
 }
 
-void PSGPUWrapper::build_cpu_thread() {
+void PSGPUWrapper::pre_build_thread() {
+  // prebuild: process load_data
   while (running_) {
     std::shared_ptr<HeterContext> gpu_task = nullptr;
     if (!data_ready_channel_->Get(gpu_task)) {
      continue;
    }
-    VLOG(3) << "thread BuildTask start.";
+    VLOG(3) << "thread PreBuildTask start.";
     platform::Timer timer;
     timer.Start();
     // build cpu ps data process
-    BuildTask(gpu_task);
+    PreBuildTask(gpu_task);
     timer.Pause();
-    VLOG(1) << "thread BuildTask end, cost time: " << timer.ElapsedSec() << "s";
+    VLOG(1) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec()
+            << "s";
     buildcpu_ready_channel_->Put(gpu_task);
   }
   VLOG(3) << "build cpu thread end";
 }
 
-void PSGPUWrapper::build_gpu_thread() {
+void PSGPUWrapper::build_thread() {
+  // build: build_pull + build_gputask
   while (running_) {
     std::shared_ptr<HeterContext> gpu_task = nullptr;
     if (!gpu_free_channel_->Get(gpu_task)) {
@@ -516,12 +528,14 @@ void PSGPUWrapper::build_gpu_thread() {
     VLOG(3) << "thread BuildGPUTask start.";
     platform::Timer timer;
     timer.Start();
+    BuildPull(gpu_task);
+    timer.Pause();
+    timer.Start();
     BuildGPUTask(gpu_task);
     timer.Pause();
     VLOG(1) << "thread BuildGPUTask end, cost time: " << timer.ElapsedSec()
             << "s";
-    gpu_task_pool_.Push(gpu_task);
     train_ready_channel_->Put(gpu_task);
   }
   VLOG(3) << "build gpu thread end";
@@ -557,6 +571,8 @@ void PSGPUWrapper::EndPass() {
   if (keysize_max != 0) {
     HeterPs_->end_pass();
   }
+
+  gpu_task_pool_.Push(current_task_);
   current_task_ = nullptr;
   gpu_free_channel_->Put(current_task_);
   timer.Pause();
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index 6f785cad33e2d249629cb4dca5c3dca8bb65b08f..c1f83d2fe9274d47fd1fdfb2f5a4a033d1ef7ca7 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -84,13 +84,14 @@ class PSGPUWrapper {
                       const int batch_size);
 
   void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
-  void BuildTask(std::shared_ptr<HeterContext> gpu_task);
+  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
+  void BuildPull(std::shared_ptr<HeterContext> gpu_task);
   void LoadIntoMemory(bool is_shuffle);
   void BeginPass();
   void EndPass();
   void start_build_thread();
-  void build_cpu_thread();
-  void build_gpu_thread();
+  void pre_build_thread();
+  void build_thread();
 
   void Finalize() {
     VLOG(3) << "PSGPUWrapper Begin Finalize.";
@@ -102,10 +103,10 @@ class PSGPUWrapper {
     gpu_free_channel_->Close();
     train_ready_channel_->Close();
     running_ = false;
-    VLOG(3) << "begin stop build_cpu_threads_";
-    build_cpu_threads_.join();
-    VLOG(3) << "begin stop build_gpu_threads_";
-    build_gpu_threads_.join();
+    VLOG(3) << "begin stop pre_build_threads_";
+    pre_build_threads_.join();
+    VLOG(3) << "begin stop build_threads_";
+    build_threads_.join();
     s_instance_ = nullptr;
     VLOG(3) << "PSGPUWrapper Finalize Finished.";
   }
@@ -310,8 +311,8 @@ class PSGPUWrapper {
       train_ready_channel_ =
           paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
   std::shared_ptr<HeterContext> current_task_ = nullptr;
-  std::thread build_cpu_threads_;
-  std::thread build_gpu_threads_;
+  std::thread pre_build_threads_;
+  std::thread build_threads_;
   bool running_ = false;
 
  protected:
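For reference, the refactor above splits the old single BuildTask step into PreBuildTask (CPU-side key extraction, run by pre_build_thread) and BuildPull plus BuildGPUTask (run by build_thread), with the two threads handing a task over through buildcpu_ready_channel_. Below is a minimal, standalone C++ sketch of that two-stage producer/consumer hand-off; SimpleChannel and Task are simplified stand-ins for paddle::framework::ChannelObject and HeterContext, not Paddle APIs.

// Standalone sketch of the pre_build_thread -> channel -> build_thread pipeline.
#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

// Minimal blocking channel standing in for paddle::framework::ChannelObject.
template <typename T>
class SimpleChannel {
 public:
  void Put(T v) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      q_.push(std::move(v));
    }
    cv_.notify_one();
  }
  // Blocks until an item is available; returns false once closed and drained.
  bool Get(T* v) {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return closed_ || !q_.empty(); });
    if (q_.empty()) return false;
    *v = std::move(q_.front());
    q_.pop();
    return true;
  }
  void Close() {
    {
      std::lock_guard<std::mutex> lk(mu_);
      closed_ = true;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> q_;
  bool closed_ = false;
};

struct Task {  // stand-in for HeterContext
  int pass_id;
};

int main() {
  SimpleChannel<std::shared_ptr<Task>> data_ready, buildcpu_ready;

  // Stage 1: CPU-side preprocessing (the PreBuildTask step).
  std::thread pre_build([&] {
    std::shared_ptr<Task> t;
    while (data_ready.Get(&t)) {
      std::cout << "PreBuildTask for pass " << t->pass_id << "\n";
      buildcpu_ready.Put(t);
    }
    buildcpu_ready.Close();
  });

  // Stage 2: parameter pull and GPU table build (BuildPull + BuildGPUTask).
  std::thread build([&] {
    std::shared_ptr<Task> t;
    while (buildcpu_ready.Get(&t)) {
      std::cout << "BuildPull + BuildGPUTask for pass " << t->pass_id << "\n";
    }
  });

  for (int i = 0; i < 3; ++i) data_ready.Put(std::make_shared<Task>(Task{i}));
  data_ready.Close();
  pre_build.join();
  build.join();
  return 0;
}

The sketch mirrors one design point of the diff: the expensive pull/build work sits entirely in the second thread, so the first thread can already start preprocessing the next pass while the current one is still being built.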