Unverified commit ded3e705, authored by danleifeng, committed by GitHub

[heterps]fix heterps pipeline training (#36512)

* split BuildTask into PreBuildTask and BuildPull; solve endpass bug; test=develop

* rename build_cpu_thread to pre_build_thread and build_gpu_thread to build_thread; test=develop
Parent 6a3941e3
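The diff below rewires a channel-driven pipeline: HeterContext tasks flow between threads through data_ready_channel_, buildcpu_ready_channel_, gpu_free_channel_, and train_ready_channel_. For orientation, here is a minimal stand-in for the channel these members hold, assuming closable blocking-queue semantics (Get blocks until an item arrives or the channel closes, then returns false once drained); the real type behind paddle::framework::MakeChannel is richer, so treat this purely as a sketch.

```cpp
#include <condition_variable>
#include <mutex>
#include <queue>

// MiniChannel: an assumed, simplified model of the channel used by
// PSGPUWrapper. Put() enqueues; Get() blocks until an item is available or
// the channel is closed, returning false once it is closed and drained.
template <typename T>
class MiniChannel {
 public:
  void Put(T v) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push(std::move(v));
    }
    cv_.notify_one();
  }
  bool Get(T& v) {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return closed_ || !queue_.empty(); });
    if (queue_.empty()) return false;  // closed and fully drained
    v = std::move(queue_.front());
    queue_.pop();
    return true;
  }
  void Close() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      closed_ = true;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> queue_;
  bool closed_ = false;
};
```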
@@ -40,7 +40,7 @@ namespace framework {
 std::shared_ptr<PSGPUWrapper> PSGPUWrapper::s_instance_ = NULL;
 bool PSGPUWrapper::is_initialized_ = false;
-void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
+void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
   VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin";
   platform::Timer timeline;
   timeline.Start();
@@ -49,17 +49,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
   auto& local_keys = gpu_task->feature_keys_;
   auto& local_ptr = gpu_task->value_ptr_;
   auto& device_keys = gpu_task->device_keys_;
-  auto& device_vals = gpu_task->device_values_;
-  auto& device_mutex = gpu_task->mutex_;
-  std::vector<std::thread> threads;
-#ifdef PADDLE_WITH_PSLIB
-  auto fleet_ptr = FleetWrapper::GetInstance();
-#endif
-#ifdef PADDLE_WITH_PSCORE
-  auto fleet_ptr = paddle::distributed::Communicator::GetInstance();
-#endif
   // data should be in input channel
   thread_keys_.resize(thread_keys_thread_num_);
@@ -181,6 +171,25 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
     VLOG(3) << "GpuPs shard: " << i << " key len: " << local_keys[i].size();
     local_ptr[i].resize(local_keys[i].size());
   }
+}
+
+void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
+  platform::Timer timeline;
+  int device_num = heter_devices_.size();
+  auto& local_keys = gpu_task->feature_keys_;
+  auto& local_ptr = gpu_task->value_ptr_;
+  auto& device_keys = gpu_task->device_keys_;
+  auto& device_vals = gpu_task->device_values_;
+  auto& device_mutex = gpu_task->mutex_;
+  std::vector<std::thread> threads(thread_keys_shard_num_);
+#ifdef PADDLE_WITH_PSLIB
+  auto fleet_ptr = FleetWrapper::GetInstance();
+#endif
+#ifdef PADDLE_WITH_PSCORE
+  auto fleet_ptr = paddle::distributed::Communicator::GetInstance();
+#endif
+
 #ifdef PADDLE_WITH_PSLIB
   // get day_id: day nums from 1970
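The members removed from PreBuildTask above (device_vals, device_mutex, the fleet_ptr) reappear in the new BuildPull, which now owns the per-shard pull from the parameter server. The sketch below shows only the fan-out shape implied by threads(thread_keys_shard_num_) and local_ptr[i].resize(...); PullAllShards and the lambda body are illustrative, and the actual fleet RPC is elided.

```cpp
#include <cstdint>
#include <thread>
#include <vector>

// Illustrative fan-out: one thread per key shard sizes its value buffer to
// match the deduplicated keys, then would issue the pull for that shard.
void PullAllShards(int shard_num,
                   std::vector<std::vector<uint64_t>>& local_keys,
                   std::vector<std::vector<void*>>& local_ptr) {
  std::vector<std::thread> threads(shard_num);
  for (int i = 0; i < shard_num; ++i) {
    threads[i] = std::thread([&, i] {
      // Mirrors local_ptr[i].resize(local_keys[i].size()) in the diff:
      // one value slot per key in shard i.
      local_ptr[i].resize(local_keys[i].size());
      // ... pull values for local_keys[i] from the parameter server ...
    });
  }
  for (auto& t : threads) t.join();  // BuildPull proceeds once all shards land
}
```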
@@ -482,29 +491,32 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
 void PSGPUWrapper::start_build_thread() {
   running_ = true;
   VLOG(3) << "start build CPU&GPU ps thread.";
-  build_cpu_threads_ = std::thread([this] { build_cpu_thread(); });
-  build_gpu_threads_ = std::thread([this] { build_gpu_thread(); });
+  pre_build_threads_ = std::thread([this] { pre_build_thread(); });
+  build_threads_ = std::thread([this] { build_thread(); });
 }

-void PSGPUWrapper::build_cpu_thread() {
+void PSGPUWrapper::pre_build_thread() {
+  // prebuild: process load_data
   while (running_) {
     std::shared_ptr<HeterContext> gpu_task = nullptr;
     if (!data_ready_channel_->Get(gpu_task)) {
       continue;
     }
-    VLOG(3) << "thread BuildTask start.";
+    VLOG(3) << "thread PreBuildTask start.";
     platform::Timer timer;
     timer.Start();
     // build cpu ps data process
-    BuildTask(gpu_task);
+    PreBuildTask(gpu_task);
     timer.Pause();
-    VLOG(1) << "thread BuildTask end, cost time: " << timer.ElapsedSec() << "s";
+    VLOG(1) << "thread PreBuildTask end, cost time: " << timer.ElapsedSec()
+            << "s";
     buildcpu_ready_channel_->Put(gpu_task);
   }
   VLOG(3) << "build cpu thread end";
 }

-void PSGPUWrapper::build_gpu_thread() {
+void PSGPUWrapper::build_thread() {
+  // build: build_pull + build_gputask
   while (running_) {
     std::shared_ptr<HeterContext> gpu_task = nullptr;
     if (!gpu_free_channel_->Get(gpu_task)) {
@@ -516,12 +528,14 @@ void PSGPUWrapper::build_gpu_thread() {
     VLOG(3) << "thread BuildGPUTask start.";
     platform::Timer timer;
     timer.Start();
+    BuildPull(gpu_task);
+    timer.Pause();
+    timer.Start();
     BuildGPUTask(gpu_task);
     timer.Pause();
     VLOG(1) << "thread BuildGPUTask end, cost time: " << timer.ElapsedSec()
             << "s";
-    gpu_task_pool_.Push(gpu_task);
     train_ready_channel_->Put(gpu_task);
   }
   VLOG(3) << "build gpu thread end";
@@ -557,6 +571,8 @@ void PSGPUWrapper::EndPass() {
   if (keysize_max != 0) {
     HeterPs_->end_pass();
   }
+  gpu_task_pool_.Push(current_task_);
+  current_task_ = nullptr;
   gpu_free_channel_->Put(current_task_);
   timer.Pause();
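This EndPass hunk is the "endpass bug" fix named in the commit message: the diff removes gpu_task_pool_.Push(gpu_task) from build_thread (previous hunk) and performs the recycle here instead, so the context is returned to the pool only after the pass ends rather than right after BuildGPUTask. A sketch of the corrected ordering, with ObjectPool as an assumed stand-in for gpu_task_pool_ and MiniChannel as sketched earlier:

```cpp
#include <memory>
#include <utility>
#include <vector>

// Assumed stand-in for gpu_task_pool_: Push() shelves an object for reuse.
template <typename T>
struct ObjectPool {
  void Push(std::shared_ptr<T> obj) { pool_.push_back(std::move(obj)); }
  std::vector<std::shared_ptr<T>> pool_;
};

struct HeterContextSketch {};  // stand-in for HeterContext

// Recycle order after the fix: shelve the finished task first, then hand the
// freed (null) slot back so build_thread may begin the next pass.
void EndPassSketch(ObjectPool<HeterContextSketch>& pool,
                   std::shared_ptr<HeterContextSketch>& current_task,
                   MiniChannel<std::shared_ptr<HeterContextSketch>>& gpu_free) {
  pool.Push(current_task);  // safe now: training on this pass is complete
  current_task = nullptr;
  gpu_free.Put(current_task);  // re-arm build_thread, as in the diff
}
```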
@@ -84,13 +84,14 @@ class PSGPUWrapper {
                 const int batch_size);
   void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
-  void BuildTask(std::shared_ptr<HeterContext> gpu_task);
+  void PreBuildTask(std::shared_ptr<HeterContext> gpu_task);
+  void BuildPull(std::shared_ptr<HeterContext> gpu_task);
   void LoadIntoMemory(bool is_shuffle);
   void BeginPass();
   void EndPass();
   void start_build_thread();
-  void build_cpu_thread();
-  void build_gpu_thread();
+  void pre_build_thread();
+  void build_thread();
   void Finalize() {
     VLOG(3) << "PSGPUWrapper Begin Finalize.";
@@ -102,10 +103,10 @@ class PSGPUWrapper {
     gpu_free_channel_->Close();
     train_ready_channel_->Close();
     running_ = false;
-    VLOG(3) << "begin stop build_cpu_threads_";
-    build_cpu_threads_.join();
-    VLOG(3) << "begin stop build_gpu_threads_";
-    build_gpu_threads_.join();
+    VLOG(3) << "begin stop pre_build_threads_";
+    pre_build_threads_.join();
+    VLOG(3) << "begin stop build_threads_";
+    build_threads_.join();
     s_instance_ = nullptr;
     VLOG(3) << "PSGPUWrapper Finalize Finished.";
   }
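Finalize tears the pipeline down in a deliberate order: close every channel so blocked Get() calls return, clear running_ so both loops exit, then join the renamed threads. A minimal shutdown sketch under the same MiniChannel assumption (the atomic flag again stands in for the plain bool member):

```cpp
#include <atomic>
#include <thread>

// Shutdown mirroring Finalize(): closing the channels first unblocks any
// waiting Get(), and the cleared flag lets the while-loops fall through.
template <typename T>
void ShutdownSketch(MiniChannel<T>& data_ready, MiniChannel<T>& buildcpu_ready,
                    MiniChannel<T>& gpu_free, MiniChannel<T>& train_ready,
                    std::thread& pre_build_threads, std::thread& build_threads,
                    std::atomic<bool>& running) {
  data_ready.Close();
  buildcpu_ready.Close();
  gpu_free.Close();
  train_ready.Close();
  running.store(false);  // observed by both while (running_) loops
  pre_build_threads.join();
  build_threads.join();
}
```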
@@ -310,8 +311,8 @@ class PSGPUWrapper {
   train_ready_channel_ =
       paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
   std::shared_ptr<HeterContext> current_task_ = nullptr;
-  std::thread build_cpu_threads_;
-  std::thread build_gpu_threads_;
+  std::thread pre_build_threads_;
+  std::thread build_threads_;
   bool running_ = false;

 protected: