未验证 提交 19dac67e 编写于 作者: T tangwei12 提交者: GitHub

fix distribute transpiler GRPC error code 4, RPC Deadline (#18984)

* fix sync mode hang in transpiler
* remove sync mode in send/recv
* replace PADDLE_ENFORCE with PADDLE_ENFORCE_NE
上级 5d1575cf
...@@ -40,13 +40,15 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -40,13 +40,15 @@ class FetchBarrierOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>( distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id")); Attr<int>("trainer_id"));
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); std::vector<distributed::VarHandlePtr> rets;
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep; VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep); rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
} }
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
}; };
......
...@@ -44,7 +44,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -44,7 +44,7 @@ class RecvOp : public framework::OperatorBase {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames = std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames"); Attr<std::vector<std::string>>("varnames");
int sync_mode = Attr<int>("sync_mode");
auto outs = Outputs("Out"); auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier"); bool with_barrier = Attr<bool>("with_barrier");
...@@ -64,8 +64,8 @@ class RecvOp : public framework::OperatorBase { ...@@ -64,8 +64,8 @@ class RecvOp : public framework::OperatorBase {
trainer_id); trainer_id);
recv_functor(rpc_ctx, scope); recv_functor(rpc_ctx, scope);
} else { } else {
std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) { if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
...@@ -73,13 +73,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -73,13 +73,7 @@ class RecvOp : public framework::OperatorBase {
rets.push_back( rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i])); rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
} }
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else { } else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
...@@ -87,9 +81,11 @@ class RecvOp : public framework::OperatorBase { ...@@ -87,9 +81,11 @@ class RecvOp : public framework::OperatorBase {
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope, rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i])); varname, outs[i]));
} }
for (size_t i = 0; i < rets.size(); i++) { }
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
} VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
} }
} }
} }
...@@ -112,10 +108,6 @@ This operator can get variables from server side. ...@@ -112,10 +108,6 @@ This operator can get variables from server side.
"variables for mapping") "variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
AddAttr<bool>("with_barrier", AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use " "(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately") "AsyncGetVarNoBarrier get variable from pserver immediately")
......
...@@ -44,13 +44,16 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -44,13 +44,16 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG(3) << "SendBarrierOp sync"; VLOG(3) << "SendBarrierOp sync";
// need to wait before sending send_barrier message std::vector<distributed::VarHandlePtr> rets;
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep; VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
} }
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
}; };
......
...@@ -41,7 +41,6 @@ class SendOp : public framework::OperatorBase { ...@@ -41,7 +41,6 @@ class SendOp : public framework::OperatorBase {
auto ins = Inputs("X"); auto ins = Inputs("X");
auto epmap = Attr<std::vector<std::string>>("epmap"); auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode");
auto trainer_id = Attr<int>("trainer_id"); auto trainer_id = Attr<int>("trainer_id");
auto send_varnames = Attr<std::vector<std::string>>("send_varnames"); auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
...@@ -75,12 +74,10 @@ class SendOp : public framework::OperatorBase { ...@@ -75,12 +74,10 @@ class SendOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
if (sync_send) { for (size_t i = 0; i < rets.size(); i++) {
for (size_t i = 0; i < rets.size(); i++) { VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i]; PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
} }
} }
} }
...@@ -98,10 +95,6 @@ Send operator ...@@ -98,10 +95,6 @@ Send operator
This operator will send variables to listen_and_serve op at the parameter server. This operator will send variables to listen_and_serve op at the parameter server.
)DOC"); )DOC");
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("epmap", AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
......
...@@ -574,8 +574,7 @@ class DistributeTranspiler(object): ...@@ -574,8 +574,7 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME: [ OP_ROLE_VAR_ATTR_NAME: [
self.grad_name_to_param_name[grad_varname], self.grad_name_to_param_name[grad_varname],
splited_grad_varname splited_grad_varname
], ]
"sync_mode": not self.sync_mode,
}) })
for _, var in enumerate(splited_vars): for _, var in enumerate(splited_vars):
send_vars.append(var) send_vars.append(var)
...@@ -595,7 +594,6 @@ class DistributeTranspiler(object): ...@@ -595,7 +594,6 @@ class DistributeTranspiler(object):
outputs={"Out": send_barrier_out}, outputs={"Out": send_barrier_out},
attrs={ attrs={
"endpoints": pserver_endpoints, "endpoints": pserver_endpoints,
"sync_mode": self.sync_mode,
"trainer_id": self.trainer_id, "trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
...@@ -669,8 +667,7 @@ class DistributeTranspiler(object): ...@@ -669,8 +667,7 @@ class DistributeTranspiler(object):
"trainer_id": self.trainer_id, "trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name], [param_varname, recv_op_role_var_name]
"sync_mode": not self.sync_mode
}) })
if self.sync_mode: if self.sync_mode:
...@@ -1548,7 +1545,6 @@ class DistributeTranspiler(object): ...@@ -1548,7 +1545,6 @@ class DistributeTranspiler(object):
if self.sync_mode else [] if self.sync_mode else []
}, },
attrs={ attrs={
"sync_mode": not self.sync_mode,
"epmap": pserver_endpoints, "epmap": pserver_endpoints,
"trainer_id": self.trainer_id, "trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册