Unverified commit e64fed86, authored by wangguanqun, committed by GitHub

fix bug in pscore (#35698)

* add trainer desc config to distributed strategy

* code style modified

* data_feed set lod

* fix bug

* code style

* fix bug
Parent commit: a4eadd15
@@ -216,6 +216,7 @@ void HogwildWorker::TrainFiles() {
   // how to accumulate fetched values here
   device_reader_->Start();
   int cur_batch;
+  int batch_cnt = 0;
   while ((cur_batch = device_reader_->Next()) > 0) {
     for (auto &op : ops_) {
       bool need_skip = false;
@@ -230,13 +231,26 @@ void HogwildWorker::TrainFiles() {
       }
     }
+    if (need_dump_field_) {
+      DumpField(*thread_scope_, dump_mode_, dump_interval_);
+    }
+    if (need_dump_param_ && thread_id_ == 0) {
+      DumpParam(*thread_scope_, batch_cnt);
+    }
     total_ins_num += cur_batch;
+    ++batch_cnt;
     PrintFetchVars();
     thread_scope_->DropKids();
   }
   timeline.Pause();
   VLOG(3) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
           << " seconds, ins_num: " << total_ins_num;
+  if (need_dump_field_ || need_dump_param_) {
+    writer_.Flush();
+  }
 #if defined PADDLE_WITH_PSCORE
   if (thread_barrier_) {
     paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
......
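For context, the hunk above adds a per-batch counter to the Hogwild training loop, dumps fields and (on thread 0) parameters each batch when the corresponding flags are set, and flushes the dump writer once after the loop. Below is a minimal, self-contained C++ sketch of that pattern only; `DumpWriter`, the flags, and the fake batch list are invented for illustration and are not Paddle's actual API.

```cpp
// Standalone sketch of the "count batches, dump per batch, flush at the end"
// pattern added to HogwildWorker::TrainFiles (names here are hypothetical).
#include <iostream>
#include <string>
#include <vector>

struct DumpWriter {
  std::vector<std::string> buffer;
  void Write(const std::string& line) { buffer.push_back(line); }
  void Flush() {  // emit everything buffered so far, then clear
    for (const auto& line : buffer) std::cout << line << "\n";
    buffer.clear();
  }
};

int main() {
  const bool need_dump_field = true;  // assumed config flags
  const bool need_dump_param = true;
  const int thread_id = 0;
  DumpWriter writer;

  std::vector<int> batches = {32, 32, 17};  // fake reader output (batch sizes)
  int total_ins_num = 0;
  int batch_cnt = 0;                        // mirrors the added batch counter
  for (int cur_batch : batches) {
    // ... run the ops on this batch here ...
    if (need_dump_field) {
      writer.Write("field dump for batch " + std::to_string(batch_cnt));
    }
    if (need_dump_param && thread_id == 0) {
      writer.Write("param dump at batch " + std::to_string(batch_cnt));
    }
    total_ins_num += cur_batch;
    ++batch_cnt;
  }
  if (need_dump_field || need_dump_param) {
    writer.Flush();  // single flush after the loop, as in the diff
  }
  std::cout << "ins_num: " << total_ins_num << "\n";
  return 0;
}
```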
@@ -214,7 +214,7 @@ void MultiTrainer::Finalize() {
   if (need_dump_field_ || need_dump_param_) {
     FinalizeDumpEnv();
   }
-#ifdef PADDLE_WITH_HETERPS
   for (size_t i = 0; i < need_merge_var_names_.size(); i++) {
     Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]);
     if (root_var == nullptr) {
@@ -222,7 +222,11 @@ void MultiTrainer::Finalize() {
     }
     LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
+#ifdef PADDLE_WITH_HETERPS
     for (size_t j = 0; j < places_.size(); j++) {
+#else
+    for (int j = 1; j < thread_num_; j++) {
+#endif
       Scope* cur_thread_scope = workers_[j]->GetThreadScope();
       Variable* thread_var =
           cur_thread_scope->FindVar(need_merge_var_names_[i]);
@@ -246,8 +250,8 @@ void MultiTrainer::Finalize() {
       _ForEachDataType_(MergeCallback);
     }
   }
+#ifdef PADDLE_WITH_HETERPS
   MergeDenseParam();
 #endif
   root_scope_->DropKids();
 }
......
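The change above stops gating the whole merge block on PADDLE_WITH_HETERPS: the per-thread dense variables are now merged back into the root scope on the plain CPU pscore path too, iterating from thread 1 (presumably because thread 0's copy already backs the root value), while the HETERPS build keeps iterating over device places and calling MergeDenseParam(). A minimal standalone sketch of the CPU-path idea, using plain doubles instead of Paddle's Scope/LoDTensor and invented names:

```cpp
// Sketch of "merge thread copies 1..N-1 into the root copy" (hypothetical data).
#include <iostream>
#include <vector>

int main() {
  // One "tensor" (a single value here) per worker thread; index 0 is the root copy.
  std::vector<double> thread_vals = {1.0, 2.0, 3.0, 4.0};
  double root_val = thread_vals[0];

  for (std::size_t j = 1; j < thread_vals.size(); ++j) {  // start at thread 1
    root_val += thread_vals[j];  // stands in for the _ForEachDataType_ merge callback,
  }                              // which does this element-wise per tensor
  std::cout << "merged value: " << root_val << "\n";  // prints 10
  return 0;
}
```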
@@ -91,7 +91,10 @@ class Hogwild(DeviceWorker):
         trainer_desc.device_worker_name = "HogwildWorker"
         if self._infer:
             # just ignore feed op for inference model
-            trainer_desc.hogwild_param.skip_ops.extend(["feed"])
+            trainer_desc.hogwild_param.skip_ops.extend([
+                "feed", "push_sparse", "push_sparse_v2", "push_dense",
+                "distributed_push_sparse", "send"
+            ])
         dense_table_set = set()
         program_id = str(id(self._program))
......
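With this change, inference mode skips the parameter-server push and communication ops in addition to `feed`. These names populate the skip list consulted by the `need_skip` check in `HogwildWorker::TrainFiles` (see the first hunk). A small standalone C++ sketch of that skip check, with plain strings standing in for operators and a hypothetical op list:

```cpp
// Sketch of filtering ops against a skip list before running them.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // In inference mode, parameter-server push/communication ops must not run.
  const std::vector<std::string> skip_ops = {
      "feed", "push_sparse", "push_sparse_v2", "push_dense",
      "distributed_push_sparse", "send"};
  const std::vector<std::string> ops = {"feed", "mul", "push_dense", "send", "scale"};

  for (const auto& op : ops) {
    bool need_skip =
        std::find(skip_ops.begin(), skip_ops.end(), op) != skip_ops.end();
    if (need_skip) continue;                // skip push/communication ops
    std::cout << "run op: " << op << "\n";  // runs only mul and scale
  }
  return 0;
}
```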