未验证 提交 88dfb30f 编写于 作者: D danleifeng 提交者: GitHub

fix hogwild_worker init_place bug (#33078)

* fix hogwild_worker dev_ctx place bug; test=develop
上级 9f6e5fdb
...@@ -195,6 +195,9 @@ class DeviceWorker { ...@@ -195,6 +195,9 @@ class DeviceWorker {
virtual void SetReaderPlace(const paddle::platform::Place& place) { virtual void SetReaderPlace(const paddle::platform::Place& place) {
device_reader_->SetPlace(place); device_reader_->SetPlace(place);
} }
virtual void SetDeviceContext(platform::DeviceContext* dev_ctx) {
dev_ctx_ = dev_ctx;
}
virtual Scope* GetThreadScope() { return thread_scope_; } virtual Scope* GetThreadScope() { return thread_scope_; }
DataFeed* device_reader_ = nullptr; DataFeed* device_reader_ = nullptr;
...@@ -221,6 +224,7 @@ class DeviceWorker { ...@@ -221,6 +224,7 @@ class DeviceWorker {
int dump_mode_ = 0; int dump_mode_ = 0;
int dump_interval_ = 10000; int dump_interval_ = 10000;
ChannelWriter<std::string> writer_; ChannelWriter<std::string> writer_;
platform::DeviceContext* dev_ctx_ = nullptr;
}; };
class CPUWorkerBase : public DeviceWorker { class CPUWorkerBase : public DeviceWorker {
...@@ -266,9 +270,6 @@ class HogwildWorker : public CPUWorkerBase { ...@@ -266,9 +270,6 @@ class HogwildWorker : public CPUWorkerBase {
HogwildWorkerParameter param_; HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_; std::vector<std::string> skip_ops_;
std::map<std::string, int> stat_var_name_map_; std::map<std::string, int> stat_var_name_map_;
#ifdef PADDLE_WITH_HETERPS
platform::DeviceContext* dev_ctx_ = nullptr;
#endif
}; };
class DownpourWorker : public HogwildWorker { class DownpourWorker : public HogwildWorker {
...@@ -622,7 +623,6 @@ class PSGPUWorker : public HogwildWorker { ...@@ -622,7 +623,6 @@ class PSGPUWorker : public HogwildWorker {
gpuStream_t copy_stream_; gpuStream_t copy_stream_;
int batch_cnt_{0}; int batch_cnt_{0};
std::atomic<int> done_cnt_{0}; std::atomic<int> done_cnt_{0};
platform::DeviceContext* dev_ctx_ = nullptr;
double total_time_; double total_time_;
double read_time_; double read_time_;
......
...@@ -39,9 +39,6 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) { ...@@ -39,9 +39,6 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) {
for (int i = 0; i < param_.stat_var_names_size(); ++i) { for (int i = 0; i < param_.stat_var_names_size(); ++i) {
stat_var_name_map_[param_.stat_var_names(i)] = 1; stat_var_name_map_[param_.stat_var_names(i)] = 1;
} }
#ifdef PADDLE_WITH_HETERPS
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
#endif
} }
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) { void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
......
...@@ -112,6 +112,8 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, ...@@ -112,6 +112,8 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
#ifdef PADDLE_WITH_HETERPS #ifdef PADDLE_WITH_HETERPS
workers_[i]->SetPlace(places_[i]); workers_[i]->SetPlace(places_[i]);
workers_[i]->SetReaderPlace(places_[i]); workers_[i]->SetReaderPlace(places_[i]);
workers_[i]->SetDeviceContext(
platform::DeviceContextPool::Instance().Get(places_[i]));
#else #else
workers_[i]->SetPlace(place); workers_[i]->SetPlace(place);
workers_[i]->SetReaderPlace(place); workers_[i]->SetReaderPlace(place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册