未验证 提交 1bc285a5 编写于 作者: 1 123malin 提交者: GitHub

add retry function to try to solve grpc error code 14 (#19661)

* rpc retry for asycsend/get/prefetch

* test=develop, change retry vlog level to 3

* test=develop, set default grpc_retry_times is 3
上级 5eb381a3
......@@ -73,36 +73,53 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
const std::string method = kSendRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
int retry_times_ = 0;
while (true) {
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
// stub context
s->response_call_back_ = nullptr;
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
platform::RecordRPCEvent record_event(method);
// stub context
s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
platform::RecordRPCEvent record_event(method);
if (UNLIKELY(platform::IsProfileEnabled())) {
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
}
});
req_count_++;
return h;
return h;
}
}
void ProcGetResponse(const VarHandle& var_h,
......@@ -169,42 +186,57 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out);
int retry_times_ = 0;
while (true) {
GetProcessor* s = new GetProcessor(ch);
framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method,
p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
req.set_table_name(table_name_val);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s,
method, p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
req.set_table_name(table_name_val);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
// stub context
s->response_call_back_ = ProcGetResponse;
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
platform::RecordRPCEvent record_event(method);
// stub context
s->response_call_back_ = ProcGetResponse;
auto call =
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
platform::RecordRPCEvent record_event(method);
auto call =
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
}
});
req_count_++;
return h;
return h;
}
}
VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
......@@ -221,41 +253,55 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
const std::string method = kPrefetchRPC;
int retry_times_ = 0;
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
while (true) {
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
s, method, h, table_name_val, this] {
auto* var = p_scope->FindVar(in_var_name_val);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
p_ctx, s, method, h, table_name_val, this] {
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
0, table_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req,
out_var_name_val, 0, table_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = ProcGetResponse;
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordRPCEvent record_event(method);
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
}
});
req_count_++;
return h;
return h;
}
}
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
......@@ -420,6 +466,14 @@ void GRPCClient::Proceed() {
ok_ = false;
}
c->Finish(false);
} else if (c->status_.error_code() == grpc::StatusCode::UNAVAILABLE) {
VLOG(3) << c->GetVarHandlePtr()->String()
<< " meets grpc error, error_code:" << c->status_.error_code()
<< " error_message:" << c->status_.error_message()
<< " error_details:" << c->status_.error_details()
<< " should retry!";
c->GetVarHandlePtr()->should_retry = true;
c->Finish(false);
} else {
LOG(FATAL) << c->GetVarHandlePtr()->String()
<< " meets grpc error, error_code:" << c->status_.error_code()
......
......@@ -85,6 +85,8 @@ class VarHandle {
virtual ~VarHandle() {}
public:
bool should_retry = false;
bool Wait() {
int ret = kDefaultState;
{
......
......@@ -17,6 +17,7 @@
// default to 3min to avoid temprary network failures.
DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
DEFINE_int32(rpc_retry_times, 3, "retry times for rpc");
namespace paddle {
namespace operators {
......
......@@ -25,6 +25,7 @@
#include "paddle/fluid/operators/distributed/request_handler.h"
DECLARE_int32(rpc_deadline);
DECLARE_int32(rpc_retry_times);
namespace paddle {
namespace operators {
......
......@@ -177,6 +177,7 @@ def __bootstrap__():
if core.is_compiled_with_dist():
#env for rpc
read_env_flags.append('rpc_deadline')
read_env_flags.append('rpc_retry_times')
read_env_flags.append('rpc_server_profile_path')
read_env_flags.append('enable_rpc_profiler')
read_env_flags.append('rpc_send_thread_num')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册