From 88dfb30f2e407eec4fb78c772be35c8bd38d28f0 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Tue, 25 May 2021 11:56:23 +0800 Subject: [PATCH] fix hogwild_worker init_place bug (#33078) * fix hogwild_worker dev_ctx place bug; test=develop --- paddle/fluid/framework/device_worker.h | 8 ++++---- paddle/fluid/framework/hogwild_worker.cc | 3 --- paddle/fluid/framework/multi_trainer.cc | 2 ++ 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 84369011476..db83cd55889 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -195,6 +195,9 @@ class DeviceWorker { virtual void SetReaderPlace(const paddle::platform::Place& place) { device_reader_->SetPlace(place); } + virtual void SetDeviceContext(platform::DeviceContext* dev_ctx) { + dev_ctx_ = dev_ctx; + } virtual Scope* GetThreadScope() { return thread_scope_; } DataFeed* device_reader_ = nullptr; @@ -221,6 +224,7 @@ class DeviceWorker { int dump_mode_ = 0; int dump_interval_ = 10000; ChannelWriter writer_; + platform::DeviceContext* dev_ctx_ = nullptr; }; class CPUWorkerBase : public DeviceWorker { @@ -266,9 +270,6 @@ class HogwildWorker : public CPUWorkerBase { HogwildWorkerParameter param_; std::vector<std::string> skip_ops_; std::map<std::string, int> stat_var_name_map_; -#ifdef PADDLE_WITH_HETERPS - platform::DeviceContext* dev_ctx_ = nullptr; -#endif }; class DownpourWorker : public HogwildWorker { @@ -622,7 +623,6 @@ class PSGPUWorker : public HogwildWorker { gpuStream_t copy_stream_; int batch_cnt_{0}; std::atomic<int> done_cnt_{0}; - platform::DeviceContext* dev_ctx_ = nullptr; double total_time_; double read_time_; diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index b2d170888e2..0c66622ed7b 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -39,9 +39,6 @@ void 
HogwildWorker::Initialize(const TrainerDesc &desc) { for (int i = 0; i < param_.stat_var_names_size(); ++i) { stat_var_name_map_[param_.stat_var_names(i)] = 1; } -#ifdef PADDLE_WITH_HETERPS - dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); -#endif } void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) { diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 7afa76c3fbd..c0ccc196348 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -112,6 +112,8 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, #ifdef PADDLE_WITH_HETERPS workers_[i]->SetPlace(places_[i]); workers_[i]->SetReaderPlace(places_[i]); + workers_[i]->SetDeviceContext( + platform::DeviceContextPool::Instance().Get(places_[i])); #else workers_[i]->SetPlace(place); workers_[i]->SetReaderPlace(place); -- GitLab