提交 f8f80db1 编写于 作者: T typhoonzero

Update RecvOp and DistributeTranspiler for multi-trainer training

上级 e13e15d8
...@@ -63,7 +63,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -63,7 +63,7 @@ class RecvOp : public framework::OperatorBase {
} }
std::string GetGradVarNameForTrainer(const std::string &varname) const { std::string GetGradVarNameForTrainer(const std::string &varname) const {
if (grads_counter_.find(varname) != grads_counter_.end()) { if (grads_counter_.find(varname) == grads_counter_.end()) {
grads_counter_[varname] = 0; grads_counter_[varname] = 0;
} }
char ret[256]; char ret[256];
...@@ -96,11 +96,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -96,11 +96,7 @@ class RecvOp : public framework::OperatorBase {
VLOG(10) << "recved grad: " << grad_var_name VLOG(10) << "recved grad: " << grad_var_name
<< " updating param: " << param_var_name; << " updating param: " << param_var_name;
if (trainer_count > 1) { if (trainer_count > 1) {
auto *var = recv_scope.FindVar(grad_var_name); grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
if (var != nullptr) {
// must rename the var to different names to merge gradient.
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
} }
auto *var = recv_scope.Var(grad_var_name); auto *var = recv_scope.Var(grad_var_name);
......
...@@ -183,11 +183,20 @@ class DistributeTranspiler: ...@@ -183,11 +183,20 @@ class DistributeTranspiler:
persistable=var.persistable, persistable=var.persistable,
dtype=var.dtype, dtype=var.dtype,
shape=var.shape) shape=var.shape)
optimize_sub_program.global_block().append_op(
type=opt_op.type, if opt_op.inputs.has_key("Grad"):
inputs=opt_op.inputs, if opt_op.inputs["Grad"].name in grad_var_names:
outputs=opt_op.outputs, optimize_sub_program.global_block().append_op(
attrs=opt_op.attrs) type=opt_op.type,
inputs=opt_op.inputs,
outputs=opt_op.outputs,
attrs=opt_op.attrs)
else:
optimize_sub_program.global_block().append_op(
type=opt_op.type,
inputs=opt_op.inputs,
outputs=opt_op.outputs,
attrs=opt_op.attrs)
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
type="recv", type="recv",
inputs={"RX": inputs={"RX":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册