未验证 提交 19dac67e 编写于 作者: T tangwei12 提交者: GitHub

fix distribute transpiler GRPC error code 4, RPC Deadline (#18984)

* fix sync mode hang in transpiler
* remove sync mode in send/recv
* replace PADDLE_ENFORCE with PADDLE_ENFORCE_NE
上级 5d1575cf
......@@ -40,13 +40,15 @@ class FetchBarrierOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
std::vector<distributed::VarHandlePtr> rets;
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
}
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
};
......
......@@ -44,7 +44,7 @@ class RecvOp : public framework::OperatorBase {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
int sync_mode = Attr<int>("sync_mode");
auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier");
......@@ -64,8 +64,8 @@ class RecvOp : public framework::OperatorBase {
trainer_id);
recv_functor(rpc_ctx, scope);
} else {
std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
......@@ -73,13 +73,7 @@ class RecvOp : public framework::OperatorBase {
rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
}
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
......@@ -87,9 +81,11 @@ class RecvOp : public framework::OperatorBase {
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
}
}
}
......@@ -112,10 +108,6 @@ This operator can get variables from server side.
"variables for mapping")
.SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately")
......
......@@ -44,13 +44,16 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG(3) << "SendBarrierOp sync";
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
std::vector<distributed::VarHandlePtr> rets;
for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
}
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
};
......
......@@ -41,7 +41,6 @@ class SendOp : public framework::OperatorBase {
auto ins = Inputs("X");
auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode");
auto trainer_id = Attr<int>("trainer_id");
auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
......@@ -75,12 +74,10 @@ class SendOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
if (sync_send) {
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
}
}
......@@ -98,10 +95,6 @@ Send operator
This operator will send variables to listen_and_serve op at the parameter server.
)DOC");
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
......
......@@ -574,8 +574,7 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME: [
self.grad_name_to_param_name[grad_varname],
splited_grad_varname
],
"sync_mode": not self.sync_mode,
]
})
for _, var in enumerate(splited_vars):
send_vars.append(var)
......@@ -595,7 +594,6 @@ class DistributeTranspiler(object):
outputs={"Out": send_barrier_out},
attrs={
"endpoints": pserver_endpoints,
"sync_mode": self.sync_mode,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
......@@ -669,8 +667,7 @@ class DistributeTranspiler(object):
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name],
"sync_mode": not self.sync_mode
[param_varname, recv_op_role_var_name]
})
if self.sync_mode:
......@@ -1548,7 +1545,6 @@ class DistributeTranspiler(object):
if self.sync_mode else []
},
attrs={
"sync_mode": not self.sync_mode,
"epmap": pserver_endpoints,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册