From f8f80db163da76e5d0b01da54b496ee1a7236773 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Fri, 15 Dec 2017 19:24:44 +0800
Subject: [PATCH] update for multi trainer

---
 paddle/operators/recv_op.cc                   |  8 ++------
 .../paddle/v2/fluid/distribute_transpiler.py | 19 ++++++++++++++-----
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index 2ff6f42c9..07e66492e 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -63,7 +63,7 @@ class RecvOp : public framework::OperatorBase {
   }
 
   std::string GetGradVarNameForTrainer(const std::string &varname) const {
-    if (grads_counter_.find(varname) != grads_counter_.end()) {
+    if (grads_counter_.find(varname) == grads_counter_.end()) {
       grads_counter_[varname] = 0;
     }
     char ret[256];
@@ -96,11 +96,7 @@ class RecvOp : public framework::OperatorBase {
       VLOG(10) << "recved grad: " << grad_var_name
                << " updating param: " << param_var_name;
       if (trainer_count > 1) {
-        auto *var = recv_scope.FindVar(grad_var_name);
-        if (var != nullptr) {
-          // must rename the var to different names to merge gradient.
-          grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
-        }
+        grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
       }
       auto *var = recv_scope.Var(grad_var_name);
diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 739b47cd2..4919dce20 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -183,11 +183,20 @@ class DistributeTranspiler:
                     persistable=var.persistable,
                     dtype=var.dtype,
                     shape=var.shape)
-            optimize_sub_program.global_block().append_op(
-                type=opt_op.type,
-                inputs=opt_op.inputs,
-                outputs=opt_op.outputs,
-                attrs=opt_op.attrs)
+
+            if opt_op.inputs.has_key("Grad"):
+                if opt_op.inputs["Grad"].name in grad_var_names:
+                    optimize_sub_program.global_block().append_op(
+                        type=opt_op.type,
+                        inputs=opt_op.inputs,
+                        outputs=opt_op.outputs,
+                        attrs=opt_op.attrs)
+            else:
+                optimize_sub_program.global_block().append_op(
+                    type=opt_op.type,
+                    inputs=opt_op.inputs,
+                    outputs=opt_op.outputs,
+                    attrs=opt_op.attrs)
         pserver_program.global_block().append_op(
             type="recv",
             inputs={"RX":
-- 
GitLab
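
Context on the C++ change: the counter-initialization check is corrected so a gradient's counter is created only when the name is *not* yet in grads_counter_ ('== end()', not '!='), and with trainer_count > 1 every incoming gradient is now unconditionally renamed so copies from different trainers do not overwrite each other before merging. Below is a minimal standalone sketch of that renaming scheme; the GradRenamer class, the ".trainer_N" suffix format, and the example variable name are illustrative assumptions, since the hunk above cuts off before the real name-formatting code in recv_op.cc.

// Sketch (not the actual RecvOp code) of per-trainer gradient renaming:
// each copy of the same gradient gets a distinct suffixed name so the
// parameter server can keep all copies and merge them later.
#include <iostream>
#include <string>
#include <unordered_map>

class GradRenamer {
 public:
  std::string GetGradVarNameForTrainer(const std::string &varname) {
    // Create the counter only the first time this gradient name is seen.
    // The patch fixes exactly this check: the old '!= end()' condition
    // reset the counter to 0 on every receive after the first, so the
    // generated names would repeat instead of staying unique.
    if (grads_counter_.find(varname) == grads_counter_.end()) {
      grads_counter_[varname] = 0;
    }
    // Suffix with a running index so each trainer's copy of the same
    // gradient lands in a distinct scope variable.
    return varname + ".trainer_" + std::to_string(grads_counter_[varname]++);
  }

 private:
  std::unordered_map<std::string, int> grads_counter_;
};

int main() {
  GradRenamer renamer;
  // Two trainers send the same gradient; each copy keeps a unique name,
  // so the server can sum them rather than letting the second overwrite
  // the first.
  std::cout << renamer.GetGradVarNameForTrainer("fc_0.w@GRAD") << "\n";
  std::cout << renamer.GetGradVarNameForTrainer("fc_0.w@GRAD") << "\n";
}

The Python-side change complements this: an optimize op that consumes a "Grad" input is appended to the pserver sub-program only when that gradient is actually handled by this pserver (its name appears in grad_var_names), while ops without a "Grad" input are appended unconditionally.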