提交 5e6276ed 编写于 作者: T typhoonzero

fix transpiler bug

上级 1eec9261
......@@ -68,7 +68,7 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
VLOG(2) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
......@@ -77,20 +77,20 @@ class SendOp : public framework::OperatorBase {
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep;
VLOG(2) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
if (outs.size() > 0) {
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
PADDLE_ENFORCE(rpc_client->Wait());
// tell pservers that current trainer have called fetch
for (auto& ep : endpoints) {
VLOG(3) << "send fetch barrier, ep: " << ep;
VLOG(2) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
......
......@@ -563,6 +563,8 @@ class DistributeTranspiler:
orig_var_name = ""
if suff_idx >= 0:
orig_var_name = varname[:suff_idx]
else:
orig_var_name = varname
return orig_var_name
def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
......@@ -577,7 +579,8 @@ class DistributeTranspiler:
grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if same_or_split_var(
self._orig_varname(g.name), opt_op.input(key)[0]):
self._orig_varname(g.name),
self._orig_varname(opt_op.input(key)[0])):
grad_block = g
break
if not grad_block:
......@@ -748,7 +751,7 @@ class DistributeTranspiler:
param_names = [
p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
]
if op.input("Param") in param_names:
if op.input("Param")[0] in param_names:
return True
else:
for n in param_names:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册