diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 4fb98e526d5fc45aeb9f674dbb2a3447464e2819..0d0c9048d00f3b49d15a84a9e789260dda7d65b6 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -490,9 +490,8 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) { void PSGPUWrapper::start_build_thread() { running_ = true; - VLOG(3) << "start build CPU&GPU ps thread."; + VLOG(3) << "start build CPU ps thread."; pre_build_threads_ = std::thread([this] { pre_build_thread(); }); - build_threads_ = std::thread([this] { build_thread(); }); } void PSGPUWrapper::pre_build_thread() { @@ -515,30 +514,28 @@ void PSGPUWrapper::pre_build_thread() { VLOG(3) << "build cpu thread end"; } -void PSGPUWrapper::build_thread() { - // build: build_pull + build_gputask - while (running_) { - std::shared_ptr gpu_task = nullptr; - if (!gpu_free_channel_->Get(gpu_task)) { - continue; - } - if (!buildcpu_ready_channel_->Get(gpu_task)) { - continue; - } - VLOG(3) << "thread BuildGPUTask start."; - platform::Timer timer; - timer.Start(); - BuildPull(gpu_task); - timer.Pause(); - timer.Start(); - BuildGPUTask(gpu_task); - timer.Pause(); - VLOG(1) << "thread BuildGPUTask end, cost time: " << timer.ElapsedSec() - << "s"; - - train_ready_channel_->Put(gpu_task); +void PSGPUWrapper::build_task() { + // build_task: build_pull + build_gputask + std::shared_ptr gpu_task = nullptr; + // train end, gpu free + if (!gpu_free_channel_->Get(gpu_task)) { + return; + } + // ins and pre_build end + if (!buildcpu_ready_channel_->Get(gpu_task)) { + return; } - VLOG(3) << "build gpu thread end"; + + VLOG(1) << "BuildPull start."; + platform::Timer timer; + timer.Start(); + BuildPull(gpu_task); + BuildGPUTask(gpu_task); + timer.Pause(); + VLOG(1) << "BuildPull + BuildGPUTask end, cost time: " << timer.ElapsedSec() + << "s"; + + current_task_ = gpu_task; } void PSGPUWrapper::BeginPass() { @@ -548,11 +545,15 @@ void PSGPUWrapper::BeginPass() { PADDLE_THROW( platform::errors::Fatal("[BeginPass] current task is not ended.")); } - // load+build done - if (!train_ready_channel_->Get(current_task_)) { - PADDLE_THROW(platform::errors::Fatal("train_ready_channel_ failed.")); - } + + build_task(); timer.Pause(); + + if (current_task_ == nullptr) { + PADDLE_THROW(platform::errors::Fatal( + "[BeginPass] after build_task, current task is not null.")); + } + VLOG(1) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s"; } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index c1f83d2fe9274d47fd1fdfb2f5a4a033d1ef7ca7..b726a629586e18a58acac4d9995c2fbd4e7a12e5 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -91,7 +91,7 @@ class PSGPUWrapper { void EndPass(); void start_build_thread(); void pre_build_thread(); - void build_thread(); + void build_task(); void Finalize() { VLOG(3) << "PSGPUWrapper Begin Finalize."; @@ -101,7 +101,6 @@ class PSGPUWrapper { data_ready_channel_->Close(); buildcpu_ready_channel_->Close(); gpu_free_channel_->Close(); - train_ready_channel_->Close(); running_ = false; VLOG(3) << "begin stop pre_build_threads_"; pre_build_threads_.join(); @@ -169,8 +168,6 @@ class PSGPUWrapper { buildcpu_ready_channel_->SetCapacity(3); gpu_free_channel_->Open(); gpu_free_channel_->SetCapacity(1); - train_ready_channel_->Open(); - train_ready_channel_->SetCapacity(1); current_task_ = nullptr; gpu_free_channel_->Put(current_task_); @@ -306,10 +303,6 @@ class PSGPUWrapper { paddle::framework::ChannelObject>> gpu_free_channel_ = paddle::framework::MakeChannel>(); - std::shared_ptr< - paddle::framework::ChannelObject>> - train_ready_channel_ = - paddle::framework::MakeChannel>(); std::shared_ptr current_task_ = nullptr; std::thread pre_build_threads_; std::thread build_threads_;