Unverified commit e64fed86, authored by W wangguanqun, committed by GitHub

fix bug in pscore (#35698)

* add trainer desc config to distributed strategy

* code style modified

* data_feed set lod

* fix bug

* code style

* fix bug
Parent a4eadd15
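The first bullet, "add trainer desc config to distributed strategy", is the user-facing half of this fix: dump settings travel through fleet's DistributedStrategy into the trainer desc that the patched workers below consume. A minimal sketch of such a configuration; the exact keys and values are assumptions inferred from this PR's description, not guaranteed by the diff:

```python
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

strategy = fleet.DistributedStrategy()
# Assumed keys, matching the "trainer desc config" bullet: they tell the
# patched HogwildWorker what to dump and where to dump it.
strategy.trainer_desc_configs = {
    "dump_fields_path": "hdfs://path/for/dump",  # consumed by DumpField
    "dump_fields": ["click", "predict"],         # variables dumped per batch
    "dump_param": ["fc_0.w_0"],                  # dumped by thread 0 only
}

# In a parameter-server training script this strategy would then be passed
# to fleet.distributed_optimizer(inner_optimizer, strategy=strategy).
```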
@@ -216,6 +216,7 @@ void HogwildWorker::TrainFiles() {
   // how to accumulate fetched values here
   device_reader_->Start();
   int cur_batch;
+  int batch_cnt = 0;
   while ((cur_batch = device_reader_->Next()) > 0) {
     for (auto &op : ops_) {
       bool need_skip = false;
@@ -230,13 +231,26 @@ void HogwildWorker::TrainFiles() {
       }
     }
+    if (need_dump_field_) {
+      DumpField(*thread_scope_, dump_mode_, dump_interval_);
+    }
+    if (need_dump_param_ && thread_id_ == 0) {
+      DumpParam(*thread_scope_, batch_cnt);
+    }
     total_ins_num += cur_batch;
+    ++batch_cnt;
     PrintFetchVars();
     thread_scope_->DropKids();
   }
   timeline.Pause();
   VLOG(3) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
           << " seconds, ins_num: " << total_ins_num;
+  if (need_dump_field_ || need_dump_param_) {
+    writer_.Flush();
+  }
 #if defined PADDLE_WITH_PSCORE
   if (thread_barrier_) {
     paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
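Read as a whole, the two hunks above (from HogwildWorker::TrainFiles(), presumably paddle/fluid/framework/hogwild_worker.cc) add a batch counter, dump the configured fields on every batch, let only thread 0 dump parameters, and flush the shared writer once the reader is exhausted. A Python sketch of that control flow; every name here is illustrative, not Paddle's C++ API:

```python
def train_files_sketch(reader, ops, writer, thread_id,
                       need_dump_field=False, need_dump_param=False):
    """Illustrative rendering of the patched TrainFiles() loop."""
    total_ins_num = 0
    batch_cnt = 0                         # counter introduced by this commit
    reader.start()
    while True:
        cur_batch = reader.next()         # instances in the batch, 0 when done
        if cur_batch <= 0:
            break
        for op in ops:                    # ops in skip_ops are filtered earlier
            op.run()
        if need_dump_field:               # per-batch field dump
            writer.dump_field()
        if need_dump_param and thread_id == 0:
            writer.dump_param(batch_cnt)  # only worker 0 dumps parameters
        total_ins_num += cur_batch
        batch_cnt += 1
    if need_dump_field or need_dump_param:
        writer.flush()                    # mirrors the added writer_.Flush()
    return total_ins_num
```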
@@ -214,7 +214,7 @@ void MultiTrainer::Finalize() {
   if (need_dump_field_ || need_dump_param_) {
     FinalizeDumpEnv();
   }
-#ifdef PADDLE_WITH_HETERPS
   for (size_t i = 0; i < need_merge_var_names_.size(); i++) {
     Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]);
     if (root_var == nullptr) {
@@ -222,7 +222,11 @@ void MultiTrainer::Finalize() {
     }
     LoDTensor* root_tensor = root_var->GetMutable<LoDTensor>();
+#ifdef PADDLE_WITH_HETERPS
     for (size_t j = 0; j < places_.size(); j++) {
+#else
+    for (int j = 1; j < thread_num_; j++) {
+#endif
       Scope* cur_thread_scope = workers_[j]->GetThreadScope();
       Variable* thread_var =
           cur_thread_scope->FindVar(need_merge_var_names_[i]);
@@ -246,8 +250,8 @@ void MultiTrainer::Finalize() {
       _ForEachDataType_(MergeCallback);
     }
   }
+#ifdef PADDLE_WITH_HETERPS
   MergeDenseParam();
 #endif
   root_scope_->DropKids();
 }
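The MultiTrainer::Finalize() hunks (presumably paddle/fluid/framework/multi_trainer.cc) stop guarding the whole merge with PADDLE_WITH_HETERPS: every build now folds the per-thread copies of need_merge_var_names_ into the root scope, HETERPS builds iterating over all device places and other builds over worker threads 1 through thread_num_ - 1, while MergeDenseParam() stays HETERPS-only. A Python sketch of that merge, with stand-in types for Scope and LoDTensor:

```python
def finalize_merge_sketch(root_scope, workers, need_merge_var_names,
                          with_heterps, num_places):
    """Illustrative rendering of the patched Finalize() merge logic."""
    for name in need_merge_var_names:
        root_tensor = root_scope.find_var(name)
        if root_tensor is None:
            continue
        # HETERPS merges every device place's copy (j = 0..num_places-1);
        # other builds merge the extra worker threads (j = 1..) into the
        # root copy, which already reflects thread 0.
        indices = range(num_places) if with_heterps else range(1, len(workers))
        for j in indices:
            thread_tensor = workers[j].thread_scope().find_var(name)
            if thread_tensor is not None:
                # stand-in for the _ForEachDataType_ MergeCallback, which
                # accumulates thread_tensor into root_tensor element-wise
                root_tensor.merge_from(thread_tensor)
    if with_heterps:
        root_scope.merge_dense_param()  # MergeDenseParam(), HETERPS-only
    root_scope.drop_kids()
```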
@@ -91,7 +91,10 @@ class Hogwild(DeviceWorker):
         trainer_desc.device_worker_name = "HogwildWorker"
         if self._infer:
             # just ignore feed op for inference model
-            trainer_desc.hogwild_param.skip_ops.extend(["feed"])
+            trainer_desc.hogwild_param.skip_ops.extend([
+                "feed", "push_sparse", "push_sparse_v2", "push_dense",
+                "distributed_push_sparse", "send"
+            ])
         dense_table_set = set()
         program_id = str(id(self._program))
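The Python hunk (presumably python/paddle/fluid/device_worker.py) widens the inference skip list so a Hogwild worker running in infer mode no longer executes the parameter-server push/communication ops. A hypothetical run that exercises this path; infer_from_dataset is what puts the device worker into infer mode, and infer_program/dataset are assumed to be prepared elsewhere:

```python
import paddle

paddle.enable_static()

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())

# With _infer set, the worker skips "feed" plus the push/send ops from the
# widened list above instead of trying to push gradients to the servers.
exe.infer_from_dataset(program=infer_program, dataset=dataset)
```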