diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 84369011476c77765dc5396830adc34f775fbb50..db83cd55889c43feadaab2dd4170b5e90d117435 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -195,6 +195,9 @@ class DeviceWorker { virtual void SetReaderPlace(const paddle::platform::Place& place) { device_reader_->SetPlace(place); } + virtual void SetDeviceContext(platform::DeviceContext* dev_ctx) { + dev_ctx_ = dev_ctx; + } virtual Scope* GetThreadScope() { return thread_scope_; } DataFeed* device_reader_ = nullptr; @@ -221,6 +224,7 @@ class DeviceWorker { int dump_mode_ = 0; int dump_interval_ = 10000; ChannelWriter writer_; + platform::DeviceContext* dev_ctx_ = nullptr; }; class CPUWorkerBase : public DeviceWorker { @@ -266,9 +270,6 @@ class HogwildWorker : public CPUWorkerBase { HogwildWorkerParameter param_; std::vector skip_ops_; std::map stat_var_name_map_; -#ifdef PADDLE_WITH_HETERPS - platform::DeviceContext* dev_ctx_ = nullptr; -#endif }; class DownpourWorker : public HogwildWorker { @@ -622,7 +623,6 @@ class PSGPUWorker : public HogwildWorker { gpuStream_t copy_stream_; int batch_cnt_{0}; std::atomic done_cnt_{0}; - platform::DeviceContext* dev_ctx_ = nullptr; double total_time_; double read_time_; diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index b2d170888e28fc4e9918c26f000a5983c09811ee..0c66622ed7b9a6a6e9fb5112001009c2b95e367a 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -39,9 +39,6 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) { for (int i = 0; i < param_.stat_var_names_size(); ++i) { stat_var_name_map_[param_.stat_var_names(i)] = 1; } -#ifdef PADDLE_WITH_HETERPS - dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); -#endif } void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) { diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 7afa76c3fbd23a395e6769d83e939e0d36424471..c0ccc196348a5761ea4dedf1aab5ce8754eb74b5 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -112,6 +112,8 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, #ifdef PADDLE_WITH_HETERPS workers_[i]->SetPlace(places_[i]); workers_[i]->SetReaderPlace(places_[i]); + workers_[i]->SetDeviceContext( + platform::DeviceContextPool::Instance().Get(places_[i])); #else workers_[i]->SetPlace(place); workers_[i]->SetReaderPlace(place);