Unverified commit ded3e705 authored by danleifeng, committed by GitHub

[heterps]fix heterps pipeline training (#36512)

* split into PreBuildTask and BuildPull; solve endpass bug;test=develop

* change buildcpu into prebuild and buildgpu into build;test=develop
Parent 6a3941e3
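Not part of the patch: for orientation, below is a minimal, self-contained sketch of the two-stage build pipeline this commit sets up. `BlockingQueue` is a simplified stand-in for `paddle::framework::Channel`, the stage functions are empty placeholders named after the real ones, and the queue names mirror the channel members visible in the diff (`data_ready_channel_`, `buildcpu_ready_channel_`, `gpu_free_channel_`, `train_ready_channel_`); none of this is the actual PSGPUWrapper implementation.

```cpp
// Sketch only: two-stage CPU/GPU build pipeline wired through blocking queues.
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

struct HeterContext {};  // placeholder for the real HeterContext
using Ctx = std::shared_ptr<HeterContext>;

template <typename T>
class BlockingQueue {  // simplified stand-in for paddle::framework::Channel
 public:
  void Put(T v) {
    { std::lock_guard<std::mutex> lk(m_); q_.push(std::move(v)); }
    cv_.notify_one();
  }
  // Returns false once the queue is closed and drained.
  bool Get(T& v) {
    std::unique_lock<std::mutex> lk(m_);
    cv_.wait(lk, [&] { return !q_.empty() || closed_; });
    if (q_.empty()) return false;
    v = std::move(q_.front());
    q_.pop();
    return true;
  }
  void Close() {
    { std::lock_guard<std::mutex> lk(m_); closed_ = true; }
    cv_.notify_all();
  }
 private:
  std::mutex m_;
  std::condition_variable cv_;
  std::queue<T> q_;
  bool closed_ = false;
};

BlockingQueue<Ctx> data_ready, buildcpu_ready, gpu_free, train_ready;

void PreBuildTask(const Ctx&) { /* CPU-side key sharding */ }
void BuildPull(const Ctx&)    { /* pull values from the parameter server */ }
void BuildGPUTask(const Ctx&) { /* build the GPU-side tables */ }

// Stage 1: consume loaded data, run only the CPU-side key preparation.
void pre_build_thread() {
  Ctx task;
  while (data_ready.Get(task)) {
    PreBuildTask(task);
    buildcpu_ready.Put(task);
  }
}

// Stage 2: wait for a free GPU slot, then pull values and build GPU tables.
void build_thread() {
  Ctx slot, task;
  while (gpu_free.Get(slot)) {
    if (!buildcpu_ready.Get(task)) break;
    BuildPull(task);
    BuildGPUTask(task);
    train_ready.Put(task);  // trainer consumes; EndPass recycles the context
  }
}

int main() {
  std::thread t1(pre_build_thread), t2(build_thread);
  gpu_free.Put(nullptr);                             // one free build slot
  data_ready.Put(std::make_shared<HeterContext>());  // one loaded pass
  Ctx trained;
  train_ready.Get(trained);  // training + EndPass would happen here
  data_ready.Close();
  gpu_free.Close();
  buildcpu_ready.Close();
  train_ready.Close();
  t1.join();
  t2.join();
  return 0;
}
```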
@@ -40,7 +40,7 @@ namespace framework {
 std::shared_ptr<PSGPUWrapper> PSGPUWrapper::s_instance_ = NULL;
 bool PSGPUWrapper::is_initialized_ = false;
-void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
+void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
   VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin";
   platform::Timer timeline;
   timeline.Start();
@@ -49,17 +49,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
   auto& local_keys = gpu_task->feature_keys_;
   auto& local_ptr = gpu_task->value_ptr_;
-  auto& device_keys = gpu_task->device_keys_;
-  auto& device_vals = gpu_task->device_values_;
-  auto& device_mutex = gpu_task->mutex_;
   std::vector<std::thread> threads;
-#ifdef PADDLE_WITH_PSLIB
-  auto fleet_ptr = FleetWrapper::GetInstance();
-#endif
-#ifdef PADDLE_WITH_PSCORE
-  auto fleet_ptr = paddle::distributed::Communicator::GetInstance();
-#endif
   // data should be in input channel
   thread_keys_.resize(thread_keys_thread_num_);
@@ -181,6 +171,25 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
     VLOG(3) << "GpuPs shard: " << i << " key len: " << local_keys[i].size();
     local_ptr[i].resize(local_keys[i].size());
   }
+}
+void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
+  platform::Timer timeline;
+  int device_num = heter_devices_.size();
+  auto& local_keys = gpu_task->feature_keys_;
+  auto& local_ptr = gpu_task->value_ptr_;
+  auto& device_keys = gpu_task->device_keys_;
+  auto& device_vals = gpu_task->device_values_;
+  auto& device_mutex = gpu_task->mutex_;
+  std::vector<std::thread> threads(thread_keys_shard_num_);
+#ifdef PADDLE_WITH_PSLIB
+  auto fleet_ptr = FleetWrapper::GetInstance();
+#endif
+#ifdef PADDLE_WITH_PSCORE
+  auto fleet_ptr = paddle::distributed::Communicator::GetInstance();
+#endif
 #ifdef PADDLE_WITH_PSLIB
   // get day_id: day nums from 1970
@@ -482,29 +491,32 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
 void PSGPUWrapper::start_build_thread() {
   running_ = true;
   VLOG(3) << "start build CPU&GPU ps thread.";
-  build_cpu_threads_ = std::thread([this] { build_cpu_thread(); });
-  build_gpu_threads_ = std::thread([this] { build_gpu_thread(); });
+  pre_build_threads_ = std::thread([this] { pre_build_thread(); });
+  build_threads_ = std::thread([this] { build_thread(); });
 }
-void PSGPUWrapper::build_cpu_thread() {
+void PSGPUWrapper::pre_build_thread() {
+  // prebuild: process load_data
   while (running_) {
     std::shared_ptr<HeterContext> gpu_task = nullptr;
     if (!data_ready_channel_->Get(gpu_task)) {
       continue;
     }
-    VLOG(3) << "thread BuildTask start.";
+    VLOG(3) << "thread PreBuildTask start.";
     platform::Timer timer;
     timer.Start();
     // build cpu ps data process
-    BuildTask(gpu_task);
+    PreBuildTask(gpu_task);
     timer.Pause();
-    VLOG(1) << "thread BuildTask end, cost time: " << timer.ElapsedSec() << "s";
+    VLOG(1) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec()
+            << "s";
     buildcpu_ready_channel_->Put(gpu_task);
   }
   VLOG(3) << "build cpu thread end";
 }
-void PSGPUWrapper::build_gpu_thread() {
+void PSGPUWrapper::build_thread() {
+  // build: build_pull + build_gputask
   while (running_) {
     std::shared_ptr<HeterContext> gpu_task = nullptr;
     if (!gpu_free_channel_->Get(gpu_task)) {
@@ -516,12 +528,14 @@ void PSGPUWrapper::build_gpu_thread() {
     VLOG(3) << "thread BuildGPUTask start.";
     platform::Timer timer;
     timer.Start();
+    BuildPull(gpu_task);
+    timer.Pause();
+    timer.Start();
     BuildGPUTask(gpu_task);
     timer.Pause();
     VLOG(1) << "thread BuildGPUTask end, cost time: " << timer.ElapsedSec()
             << "s";
-    gpu_task_pool_.Push(gpu_task);
     train_ready_channel_->Put(gpu_task);
   }
   VLOG(3) << "build gpu thread end";
@@ -557,6 +571,8 @@ void PSGPUWrapper::EndPass() {
   if (keysize_max != 0) {
     HeterPs_->end_pass();
   }
+  gpu_task_pool_.Push(current_task_);
   current_task_ = nullptr;
   gpu_free_channel_->Put(current_task_);
   timer.Pause();
...
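The endpass fix named in the commit message is the pair of changes above: `gpu_task_pool_.Push(gpu_task)` is removed from `build_thread()` and `gpu_task_pool_.Push(current_task_)` is added to `EndPass()`, so a `HeterContext` is returned to the pool only after the pass that trained on it has finished, not as soon as it has been built. A rough sketch of the corrected lifecycle follows; `SimplePool` is a hypothetical stand-in for `gpu_task_pool_`, not Paddle's actual object pool.

```cpp
// Sketch of the corrected HeterContext lifecycle across one pass.
#include <memory>
#include <stack>

struct HeterContext {};
using Ctx = std::shared_ptr<HeterContext>;

class SimplePool {  // hypothetical stand-in for gpu_task_pool_
 public:
  Ctx Get() {
    if (pool_.empty()) return std::make_shared<HeterContext>();
    Ctx c = pool_.top();
    pool_.pop();
    return c;
  }
  void Push(Ctx c) { pool_.push(std::move(c)); }

 private:
  std::stack<Ctx> pool_;
};

int main() {
  SimplePool gpu_task_pool;  // plays the role of gpu_task_pool_
  Ctx current_task = gpu_task_pool.Get();

  // build_thread(): BuildPull + BuildGPUTask, then hand the task to training.
  // Before this fix the context was pushed back into the pool at this point,
  // while the trainer could still be using it.

  // ... training on current_task happens here ...

  // EndPass(): only now is the context recycled and a free slot signalled.
  gpu_task_pool.Push(current_task);
  current_task = nullptr;
  // gpu_free_channel_->Put(current_task);  // as in the diff: free a slot
  return 0;
}
```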
@@ -84,13 +84,14 @@ class PSGPUWrapper {
                   const int batch_size);
   void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
-  void BuildTask(std::shared_ptr<HeterContext> gpu_task);
+  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
+  void BuildPull(std::shared_ptr<HeterContext> gpu_task);
   void LoadIntoMemory(bool is_shuffle);
   void BeginPass();
   void EndPass();
   void start_build_thread();
-  void build_cpu_thread();
-  void build_gpu_thread();
+  void pre_build_thread();
+  void build_thread();
   void Finalize() {
     VLOG(3) << "PSGPUWrapper Begin Finalize.";
@@ -102,10 +103,10 @@ class PSGPUWrapper {
     gpu_free_channel_->Close();
     train_ready_channel_->Close();
     running_ = false;
-    VLOG(3) << "begin stop build_cpu_threads_";
-    build_cpu_threads_.join();
-    VLOG(3) << "begin stop build_gpu_threads_";
-    build_gpu_threads_.join();
+    VLOG(3) << "begin stop pre_build_threads_";
+    pre_build_threads_.join();
+    VLOG(3) << "begin stop build_threads_";
+    build_threads_.join();
     s_instance_ = nullptr;
     VLOG(3) << "PSGPUWrapper Finalize Finished.";
   }
@@ -310,8 +311,8 @@ class PSGPUWrapper {
     train_ready_channel_ =
         paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
   std::shared_ptr<HeterContext> current_task_ = nullptr;
-  std::thread build_cpu_threads_;
-  std::thread build_gpu_threads_;
+  std::thread pre_build_threads_;
+  std::thread build_threads_;
   bool running_ = false;
  protected:
...