diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 2ff6f42c94c41bc5fe11dac477eb7f54a3b9d56a..07e66492e1411c75c9086a6c6ba88a834f1b840f 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -63,7 +63,7 @@ class RecvOp : public framework::OperatorBase { } std::string GetGradVarNameForTrainer(const std::string &varname) const { - if (grads_counter_.find(varname) != grads_counter_.end()) { + if (grads_counter_.find(varname) == grads_counter_.end()) { grads_counter_[varname] = 0; } char ret[256]; @@ -96,11 +96,7 @@ class RecvOp : public framework::OperatorBase { VLOG(10) << "recved grad: " << grad_var_name << " updating param: " << param_var_name; if (trainer_count > 1) { - auto *var = recv_scope.FindVar(grad_var_name); - if (var != nullptr) { - // must rename the var to different names to merge gradient. - grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); - } + grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); } auto *var = recv_scope.Var(grad_var_name); diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 739b47cd281fafdb2ab20bbbfd5a8c6d2ff4eda0..4919dce20d560e46358e59c716d305b8356e6168 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -183,11 +183,20 @@ class DistributeTranspiler: persistable=var.persistable, dtype=var.dtype, shape=var.shape) - optimize_sub_program.global_block().append_op( - type=opt_op.type, - inputs=opt_op.inputs, - outputs=opt_op.outputs, - attrs=opt_op.attrs) + + if opt_op.inputs.has_key("Grad"): + if opt_op.inputs["Grad"].name in grad_var_names: + optimize_sub_program.global_block().append_op( + type=opt_op.type, + inputs=opt_op.inputs, + outputs=opt_op.outputs, + attrs=opt_op.attrs) + else: + optimize_sub_program.global_block().append_op( + type=opt_op.type, + inputs=opt_op.inputs, + outputs=opt_op.outputs, + attrs=opt_op.attrs) pserver_program.global_block().append_op( type="recv", inputs={"RX":