Commit f8f80db1 authored by typhoonzero

update for multi trainer

Parent: e13e15d8
@@ -63,7 +63,7 @@ class RecvOp : public framework::OperatorBase {
   }

   std::string GetGradVarNameForTrainer(const std::string &varname) const {
-    if (grads_counter_.find(varname) != grads_counter_.end()) {
+    if (grads_counter_.find(varname) == grads_counter_.end()) {
       grads_counter_[varname] = 0;
     }
     char ret[256];
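The flipped comparison is the whole fix here. The old `!=` test fired exactly when the name was already present, so the counter for an already-seen gradient was reset to zero on every call and successive trainers would receive identical suffixed names; the corrected `==` test initializes a missing entry once and otherwise leaves the running count alone. A minimal standalone sketch of the intended behavior, assuming (the hunk does not show it) that the elided rest of the function formats the count into a ".trainer_N" suffix and post-increments it:

#include <cstdio>
#include <map>
#include <string>

// Sketch only: grads_counter_ is a mutable member of RecvOp in the real
// code; the ".trainer_%d" suffix format and the post-increment are
// assumptions, not shown in this hunk.
static std::map<std::string, int> grads_counter_;

std::string GetGradVarNameForTrainer(const std::string &varname) {
  if (grads_counter_.find(varname) == grads_counter_.end()) {
    grads_counter_[varname] = 0;  // initialize once, on first sight
  }
  char ret[256];
  snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
           grads_counter_[varname]++);  // unique suffix per arrival
  return std::string(ret);
}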
@@ -96,11 +96,7 @@ class RecvOp : public framework::OperatorBase {
     VLOG(10) << "recved grad: " << grad_var_name
              << " updating param: " << param_var_name;
     if (trainer_count > 1) {
-      auto *var = recv_scope.FindVar(grad_var_name);
-      if (var != nullptr) {
-        // must rename the var to different names to merge gradient.
-        grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
-      }
+      grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
     }
     auto *var = recv_scope.Var(grad_var_name);
     ...
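This hunk makes the rename unconditional in the multi-trainer case. Previously a gradient was renamed only when a variable of the same name already existed in recv_scope, so the first trainer's gradient kept its original name while later arrivals were suffixed, leaving the merge step with a mixed naming scheme. Now every received gradient gets a per-trainer name whenever trainer_count > 1, and recv_scope.Var() creates each copy under its unique name. A hypothetical walk-through, reusing the sketch above (the "w@GRAD" gradient name is illustrative):

#include <iostream>

int main() {
  // Three trainers each send a gradient named "w@GRAD"; after the
  // unconditional rename every copy is distinct and can be merged.
  for (int i = 0; i < 3; ++i) {
    std::cout << GetGradVarNameForTrainer("w@GRAD") << "\n";
  }
  // Prints (under the sketch's assumed suffix format):
  //   w@GRAD.trainer_0
  //   w@GRAD.trainer_1
  //   w@GRAD.trainer_2
}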
@@ -183,11 +183,20 @@ class DistributeTranspiler:
                 persistable=var.persistable,
                 dtype=var.dtype,
                 shape=var.shape)
-            optimize_sub_program.global_block().append_op(
-                type=opt_op.type,
-                inputs=opt_op.inputs,
-                outputs=opt_op.outputs,
-                attrs=opt_op.attrs)
+            if opt_op.inputs.has_key("Grad"):
+                if opt_op.inputs["Grad"].name in grad_var_names:
+                    optimize_sub_program.global_block().append_op(
+                        type=opt_op.type,
+                        inputs=opt_op.inputs,
+                        outputs=opt_op.outputs,
+                        attrs=opt_op.attrs)
+            else:
+                optimize_sub_program.global_block().append_op(
+                    type=opt_op.type,
+                    inputs=opt_op.inputs,
+                    outputs=opt_op.outputs,
+                    attrs=opt_op.attrs)
         pserver_program.global_block().append_op(
             type="recv",
             inputs={"RX":
         ...
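On the Python side, DistributeTranspiler now copies an optimize op into the per-server optimize_sub_program only when the op's "Grad" input names a gradient this parameter server actually receives (i.e., it appears in grad_var_names); optimize ops whose "Grad" input belongs to another server are skipped. Ops with no "Grad" input at all take the else branch and are copied unconditionally. (dict.has_key is the Python 2 idiom for what would today be written as "Grad" in opt_op.inputs.)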