未验证 提交 62ec644f 编写于 作者: D danleifeng 提交者: GitHub

[psgpu]fix pipe bug:save and pull overlap; test=develop (#37233)

上级 f29a3c68
......@@ -490,9 +490,8 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
// Sets running_ and launches the background worker threads. Each thread
// captures `this` and runs a member function of this wrapper, so the wrapper
// must outlive the threads it starts here.
// NOTE(review): this listing looks like a diff whose +/- markers were
// stripped. The two VLOG lines below, and the build_threads_ launch, may be
// before/after variants of one change rather than statements that coexist in
// the real file -- confirm against the actual source before relying on them.
void PSGPUWrapper::start_build_thread() {
// Flag read by the worker loop(s), e.g. `while (running_)` in build_thread().
running_ = true;
VLOG(3) << "start build CPU&GPU ps thread.";
VLOG(3) << "start build CPU ps thread.";
// Worker running pre_build_thread().
pre_build_threads_ = std::thread([this] { pre_build_thread(); });
// Worker running build_thread().
build_threads_ = std::thread([this] { build_thread(); });
}
void PSGPUWrapper::pre_build_thread() {
......@@ -515,30 +514,28 @@ void PSGPUWrapper::pre_build_thread() {
VLOG(3) << "build cpu thread end";
}
void PSGPUWrapper::build_thread() {
// build: build_pull + build_gputask
while (running_) {
std::shared_ptr<HeterContext> gpu_task = nullptr;
if (!gpu_free_channel_->Get(gpu_task)) {
continue;
}
if (!buildcpu_ready_channel_->Get(gpu_task)) {
continue;
}
VLOG(3) << "thread BuildGPUTask start.";
platform::Timer timer;
timer.Start();
BuildPull(gpu_task);
timer.Pause();
timer.Start();
BuildGPUTask(gpu_task);
timer.Pause();
VLOG(1) << "thread BuildGPUTask end, cost time: " << timer.ElapsedSec()
<< "s";
train_ready_channel_->Put(gpu_task);
// Builds one task on the caller's thread (BeginPass() calls this directly):
// fetches an item from gpu_free_channel_ and then from
// buildcpu_ready_channel_, runs BuildPull and BuildGPUTask on the result, and
// stores it in current_task_. If either Get reports failure the function
// returns early and current_task_ is left untouched.
void PSGPUWrapper::build_task() {
// build_task: build_pull + build_gputask
std::shared_ptr<HeterContext> gpu_task = nullptr;
// train end, gpu free
// Bail out if nothing could be fetched from gpu_free_channel_.
// NOTE(review): presumably Get writes its result into gpu_task, in which case
// the value fetched here is replaced by the Get below -- confirm.
if (!gpu_free_channel_->Get(gpu_task)) {
return;
}
// ins and pre_build end
// Bail out if nothing could be fetched from buildcpu_ready_channel_.
if (!buildcpu_ready_channel_->Get(gpu_task)) {
return;
}
// NOTE(review): this "thread end" message sits before the build work starts;
// it may be residue of the old build_thread() in a marker-stripped diff --
// confirm against the actual source.
VLOG(3) << "build gpu thread end";
VLOG(1) << "BuildPull start.";
// A single timer spans both build steps; only the combined time is logged.
platform::Timer timer;
timer.Start();
BuildPull(gpu_task);
BuildGPUTask(gpu_task);
timer.Pause();
VLOG(1) << "BuildPull + BuildGPUTask end, cost time: " << timer.ElapsedSec()
<< "s";
// Publish the built task; BeginPass() checks current_task_ after this call.
current_task_ = gpu_task;
}
void PSGPUWrapper::BeginPass() {
......@@ -548,11 +545,15 @@ void PSGPUWrapper::BeginPass() {
PADDLE_THROW(
platform::errors::Fatal("[BeginPass] current task is not ended."));
}
// load+build done
if (!train_ready_channel_->Get(current_task_)) {
PADDLE_THROW(platform::errors::Fatal("train_ready_channel_ failed."));
}
build_task();
timer.Pause();
// build_task() assigns current_task_ only when both channel Gets succeed and
// returns early otherwise, so a null current_task_ here means no task was
// built. The message must describe that condition (task is null), not its
// opposite.
if (current_task_ == nullptr) {
PADDLE_THROW(platform::errors::Fatal(
"[BeginPass] after build_task, current task is null."));
}
VLOG(1) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s";
}
......
......@@ -91,7 +91,7 @@ class PSGPUWrapper {
void EndPass();
void start_build_thread();
void pre_build_thread();
void build_thread();
void build_task();
void Finalize() {
VLOG(3) << "PSGPUWrapper Begin Finalize.";
......@@ -101,7 +101,6 @@ class PSGPUWrapper {
data_ready_channel_->Close();
buildcpu_ready_channel_->Close();
gpu_free_channel_->Close();
train_ready_channel_->Close();
running_ = false;
VLOG(3) << "begin stop pre_build_threads_";
pre_build_threads_.join();
......@@ -169,8 +168,6 @@ class PSGPUWrapper {
buildcpu_ready_channel_->SetCapacity(3);
gpu_free_channel_->Open();
gpu_free_channel_->SetCapacity(1);
train_ready_channel_->Open();
train_ready_channel_->SetCapacity(1);
current_task_ = nullptr;
gpu_free_channel_->Put(current_task_);
......@@ -306,10 +303,6 @@ class PSGPUWrapper {
paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
gpu_free_channel_ =
paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
std::shared_ptr<
paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
train_ready_channel_ =
paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
std::shared_ptr<HeterContext> current_task_ = nullptr;
std::thread pre_build_threads_;
std::thread build_threads_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册