提交 5e6276ed 编写于 作者: T typhoonzero

fix transpiler bug

上级 1eec9261
...@@ -68,7 +68,7 @@ class SendOp : public framework::OperatorBase { ...@@ -68,7 +68,7 @@ class SendOp : public framework::OperatorBase {
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; VLOG(2) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
...@@ -77,20 +77,20 @@ class SendOp : public framework::OperatorBase { ...@@ -77,20 +77,20 @@ class SendOp : public framework::OperatorBase {
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : endpoints) { for (auto& ep : endpoints) {
VLOG(3) << "batch barrier, ep: " << ep; VLOG(2) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
if (outs.size() > 0) { if (outs.size() > 0) {
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
// tell pservers that current trainer have called fetch // tell pservers that current trainer have called fetch
for (auto& ep : endpoints) { for (auto& ep : endpoints) {
VLOG(3) << "send fetch barrier, ep: " << ep; VLOG(2) << "send fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep); rpc_client->AsyncSendFetchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
......
...@@ -563,6 +563,8 @@ class DistributeTranspiler: ...@@ -563,6 +563,8 @@ class DistributeTranspiler:
orig_var_name = "" orig_var_name = ""
if suff_idx >= 0: if suff_idx >= 0:
orig_var_name = varname[:suff_idx] orig_var_name = varname[:suff_idx]
else:
orig_var_name = varname
return orig_var_name return orig_var_name
def _append_pserver_ops(self, optimize_block, opt_op, endpoint, def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
...@@ -577,7 +579,8 @@ class DistributeTranspiler: ...@@ -577,7 +579,8 @@ class DistributeTranspiler:
grad_block = None grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]: for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if same_or_split_var( if same_or_split_var(
self._orig_varname(g.name), opt_op.input(key)[0]): self._orig_varname(g.name),
self._orig_varname(opt_op.input(key)[0])):
grad_block = g grad_block = g
break break
if not grad_block: if not grad_block:
...@@ -748,7 +751,7 @@ class DistributeTranspiler: ...@@ -748,7 +751,7 @@ class DistributeTranspiler:
param_names = [ param_names = [
p.name for p in self.param_grad_ep_mapping[endpoint]["params"] p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
] ]
if op.input("Param") in param_names: if op.input("Param")[0] in param_names:
return True return True
else: else:
for n in param_names: for n in param_names:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册