未验证 提交 1bc285a5 编写于 作者: 1 123malin 提交者: GitHub

add retry function to try to solve grpc error code 14 (#19661)

* rpc retry for asyncsend/get/prefetch

* test=develop, change retry vlog level to 3

* test=develop, set default rpc_retry_times to 3
上级 5eb381a3
...@@ -73,36 +73,53 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, ...@@ -73,36 +73,53 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
const std::string method = kSendRPC; const std::string method = kSendRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] { int retry_times_ = 0;
auto* var = p_scope->FindVar(var_name_val);
while (true) {
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
::grpc::ByteBuffer req; framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_); auto* var = p_scope->FindVar(var_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; ::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
// stub context VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
s->response_call_back_ = nullptr;
platform::RecordRPCEvent record_event(method); // stub context
s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall( platform::RecordRPCEvent record_event(method);
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) { auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait(); h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
} }
});
req_count_++;
return h; return h;
}
} }
void ProcGetResponse(const VarHandle& var_h, void ProcGetResponse(const VarHandle& var_h,
...@@ -169,42 +186,57 @@ VarHandlePtr GRPCClient::_AsyncGetVar( ...@@ -169,42 +186,57 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
const std::string table_name_val = table_name; const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); int retry_times_ = 0;
s->Prepare(h, time_out);
while (true) {
GetProcessor* s = new GetProcessor(ch);
framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method, VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
p_ctx, h, rpc_path, this] { s->Prepare(h, time_out);
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
req.set_table_name(table_name_val);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s,
method, p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
req.set_table_name(table_name_val);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
// stub context VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
s->response_call_back_ = ProcGetResponse;
platform::RecordRPCEvent record_event(method); // stub context
s->response_call_back_ = ProcGetResponse;
auto call = platform::RecordRPCEvent record_event(method);
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall(); auto call =
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) { if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait(); h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
} }
});
req_count_++;
return h; return h;
}
} }
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
...@@ -221,41 +253,55 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -221,41 +253,55 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const std::string table_name_val = table_name; const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
const std::string method = kPrefetchRPC; const std::string method = kPrefetchRPC;
int retry_times_ = 0;
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope)); while (true) {
s->Prepare(h, time_out); GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
s, method, h, table_name_val, this] { p_ctx, s, method, h, table_name_val, this] {
auto* var = p_scope->FindVar(in_var_name_val); auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val, SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req,
0, table_name_val); out_var_name_val, 0, table_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context // stub context
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
platform::RecordRPCEvent record_event(method); platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req, s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_); &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s)); call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) { if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait(); h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
} }
});
req_count_++; return h;
return h; }
} }
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
...@@ -420,6 +466,14 @@ void GRPCClient::Proceed() { ...@@ -420,6 +466,14 @@ void GRPCClient::Proceed() {
ok_ = false; ok_ = false;
} }
c->Finish(false); c->Finish(false);
} else if (c->status_.error_code() == grpc::StatusCode::UNAVAILABLE) {
VLOG(3) << c->GetVarHandlePtr()->String()
<< " meets grpc error, error_code:" << c->status_.error_code()
<< " error_message:" << c->status_.error_message()
<< " error_details:" << c->status_.error_details()
<< " should retry!";
c->GetVarHandlePtr()->should_retry = true;
c->Finish(false);
} else { } else {
LOG(FATAL) << c->GetVarHandlePtr()->String() LOG(FATAL) << c->GetVarHandlePtr()->String()
<< " meets grpc error, error_code:" << c->status_.error_code() << " meets grpc error, error_code:" << c->status_.error_code()
......
...@@ -85,6 +85,8 @@ class VarHandle { ...@@ -85,6 +85,8 @@ class VarHandle {
virtual ~VarHandle() {} virtual ~VarHandle() {}
public: public:
bool should_retry = false;
bool Wait() { bool Wait() {
int ret = kDefaultState; int ret = kDefaultState;
{ {
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
// default to 3min to avoid temporary network failures. // default to 3min to avoid temporary network failures.
DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc"); DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
DEFINE_int32(rpc_retry_times, 3, "retry times for rpc");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
DECLARE_int32(rpc_deadline); DECLARE_int32(rpc_deadline);
DECLARE_int32(rpc_retry_times);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -177,6 +177,7 @@ def __bootstrap__(): ...@@ -177,6 +177,7 @@ def __bootstrap__():
if core.is_compiled_with_dist(): if core.is_compiled_with_dist():
#env for rpc #env for rpc
read_env_flags.append('rpc_deadline') read_env_flags.append('rpc_deadline')
read_env_flags.append('rpc_retry_times')
read_env_flags.append('rpc_server_profile_path') read_env_flags.append('rpc_server_profile_path')
read_env_flags.append('enable_rpc_profiler') read_env_flags.append('enable_rpc_profiler')
read_env_flags.append('rpc_send_thread_num') read_env_flags.append('rpc_send_thread_num')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册