From e64fed86c8c09cd40e62e6312665f8b35e20c374 Mon Sep 17 00:00:00 2001 From: wangguanqun Date: Thu, 16 Sep 2021 22:31:05 +0800 Subject: [PATCH] fix bug in pscore (#35698) * add trainer desc config to distributed strategy * code style modified * data_feed set lod * fix bug * code style * fix bug --- paddle/fluid/framework/hogwild_worker.cc | 14 ++++++++++++++ paddle/fluid/framework/multi_trainer.cc | 8 ++++++-- python/paddle/fluid/device_worker.py | 5 ++++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index 0c66622ed7b..f4660751b58 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -216,6 +216,7 @@ void HogwildWorker::TrainFiles() { // how to accumulate fetched values here device_reader_->Start(); int cur_batch; + int batch_cnt = 0; while ((cur_batch = device_reader_->Next()) > 0) { for (auto &op : ops_) { bool need_skip = false; @@ -230,13 +231,26 @@ void HogwildWorker::TrainFiles() { } } + if (need_dump_field_) { + DumpField(*thread_scope_, dump_mode_, dump_interval_); + } + if (need_dump_param_ && thread_id_ == 0) { + DumpParam(*thread_scope_, batch_cnt); + } + total_ins_num += cur_batch; + ++batch_cnt; PrintFetchVars(); thread_scope_->DropKids(); } timeline.Pause(); VLOG(3) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec() << " seconds, ins_num: " << total_ins_num; + + if (need_dump_field_ || need_dump_param_) { + writer_.Flush(); + } + #if defined PADDLE_WITH_PSCORE if (thread_barrier_) { paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement(); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index c0ccc196348..2a022ea4bb9 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -214,7 +214,7 @@ void MultiTrainer::Finalize() { if (need_dump_field_ || need_dump_param_) { FinalizeDumpEnv(); } -#ifdef PADDLE_WITH_HETERPS + for (size_t i = 0; i < need_merge_var_names_.size(); i++) { Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]); if (root_var == nullptr) { @@ -222,7 +222,11 @@ void MultiTrainer::Finalize() { } LoDTensor* root_tensor = root_var->GetMutable(); +#ifdef PADDLE_WITH_HETERPS for (size_t j = 0; j < places_.size(); j++) { +#else + for (int j = 1; j < thread_num_; j++) { +#endif Scope* cur_thread_scope = workers_[j]->GetThreadScope(); Variable* thread_var = cur_thread_scope->FindVar(need_merge_var_names_[i]); @@ -246,8 +250,8 @@ void MultiTrainer::Finalize() { _ForEachDataType_(MergeCallback); } } +#ifdef PADDLE_WITH_HETERPS MergeDenseParam(); - #endif root_scope_->DropKids(); } diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index 7fed27ee459..a246474e21e 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -91,7 +91,10 @@ class Hogwild(DeviceWorker): trainer_desc.device_worker_name = "HogwildWorker" if self._infer: # just ignore feed op for inference model - trainer_desc.hogwild_param.skip_ops.extend(["feed"]) + trainer_desc.hogwild_param.skip_ops.extend([ + "feed", "push_sparse", "push_sparse_v2", "push_dense", + "distributed_push_sparse", "send" + ]) dense_table_set = set() program_id = str(id(self._program)) -- GitLab