未验证 提交 88dfb30f 编写于 作者: D danleifeng 提交者: GitHub

fix hogwild_worker init_place bug (#33078)

* fix hogwild_worker dev_ctx place bug; test=develop
上级 9f6e5fdb
......@@ -195,6 +195,9 @@ class DeviceWorker {
virtual void SetReaderPlace(const paddle::platform::Place& place) {
device_reader_->SetPlace(place);
}
virtual void SetDeviceContext(platform::DeviceContext* dev_ctx) {
dev_ctx_ = dev_ctx;
}
virtual Scope* GetThreadScope() { return thread_scope_; }
DataFeed* device_reader_ = nullptr;
......@@ -221,6 +224,7 @@ class DeviceWorker {
int dump_mode_ = 0;
int dump_interval_ = 10000;
ChannelWriter<std::string> writer_;
platform::DeviceContext* dev_ctx_ = nullptr;
};
class CPUWorkerBase : public DeviceWorker {
......@@ -266,9 +270,6 @@ class HogwildWorker : public CPUWorkerBase {
HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_;
std::map<std::string, int> stat_var_name_map_;
#ifdef PADDLE_WITH_HETERPS
platform::DeviceContext* dev_ctx_ = nullptr;
#endif
};
class DownpourWorker : public HogwildWorker {
......@@ -622,7 +623,6 @@ class PSGPUWorker : public HogwildWorker {
gpuStream_t copy_stream_;
int batch_cnt_{0};
std::atomic<int> done_cnt_{0};
platform::DeviceContext* dev_ctx_ = nullptr;
double total_time_;
double read_time_;
......
......@@ -39,9 +39,6 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) {
for (int i = 0; i < param_.stat_var_names_size(); ++i) {
stat_var_name_map_[param_.stat_var_names(i)] = 1;
}
#ifdef PADDLE_WITH_HETERPS
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
#endif
}
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
......
......@@ -112,6 +112,8 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
#ifdef PADDLE_WITH_HETERPS
workers_[i]->SetPlace(places_[i]);
workers_[i]->SetReaderPlace(places_[i]);
workers_[i]->SetDeviceContext(
platform::DeviceContextPool::Instance().Get(places_[i]));
#else
workers_[i]->SetPlace(place);
workers_[i]->SetReaderPlace(place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册