From 1bc285a53a979ac07767701210c6afbaeadf711d Mon Sep 17 00:00:00 2001 From: 123malin Date: Wed, 18 Sep 2019 21:34:11 +0800 Subject: [PATCH] add retry function to try to solve grpc error code 14 (#19661) * rpc retry for asycsend/get/prefetch * test=develop, change retry vlog level to 3 * test=develop, set default grpc_retry_times is 3 --- .../operators/distributed/grpc/grpc_client.cc | 188 +++++++++++------- .../operators/distributed/request_handler.h | 2 + .../fluid/operators/distributed/rpc_client.cc | 1 + .../fluid/operators/distributed/rpc_client.h | 1 + python/paddle/fluid/__init__.py | 1 + 5 files changed, 126 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index d06d4b63b6..053fe202fe 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -73,36 +73,53 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, const std::string var_name_val = var_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - SendProcessor* s = new SendProcessor(ch); const std::string method = kSendRPC; - VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); - s->Prepare(h, time_out); - framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] { - auto* var = p_scope->FindVar(var_name_val); + int retry_times_ = 0; + + while (true) { + SendProcessor* s = new SendProcessor(ch); + VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); + s->Prepare(h, time_out); - ::grpc::ByteBuffer req; - SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_); + framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] { + auto* var = p_scope->FindVar(var_name_val); - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; + ::grpc::ByteBuffer req; + SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_); - // stub context - s->response_call_back_ = nullptr; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - platform::RecordRPCEvent record_event(method); + // stub context + s->response_call_back_ = nullptr; - auto call = s->stub_g_.PrepareUnaryCall( - s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + platform::RecordRPCEvent record_event(method); - if (UNLIKELY(platform::IsProfileEnabled())) { + auto call = s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, + &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + }); + req_count_++; + + if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) { h->Wait(); + if (h->should_retry) { + VLOG(3) << "rpc call failed, retry times " << retry_times_; + retry_times_++; + std::random_device rd; + std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5)); + continue; + } } - }); - req_count_++; - return h; + return h; + } } void ProcGetResponse(const VarHandle& var_h, @@ -169,42 +186,57 @@ VarHandlePtr GRPCClient::_AsyncGetVar( const std::string table_name_val = table_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - GetProcessor* s = new GetProcessor(ch); - VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); - s->Prepare(h, time_out); + int retry_times_ = 0; + + while (true) { + GetProcessor* s = new GetProcessor(ch); - framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method, - p_ctx, h, rpc_path, this] { - // prepare input - sendrecv::VariableMessage req; - req.set_varname(var_name_val); - req.set_out_varname(out_varname_val); - req.set_trainer_id(trainer_id_); - req.set_table_name(table_name_val); - ::grpc::ByteBuffer buf; - RequestToByteBuffer(req, &buf); + VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); + s->Prepare(h, time_out); - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; + framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, + method, p_ctx, h, rpc_path, this] { + // prepare input + sendrecv::VariableMessage req; + req.set_varname(var_name_val); + req.set_out_varname(out_varname_val); + req.set_trainer_id(trainer_id_); + req.set_table_name(table_name_val); + ::grpc::ByteBuffer buf; + RequestToByteBuffer(req, &buf); - // stub context - s->response_call_back_ = ProcGetResponse; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - platform::RecordRPCEvent record_event(method); + // stub context + s->response_call_back_ = ProcGetResponse; - auto call = - s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + platform::RecordRPCEvent record_event(method); + + auto call = + s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + }); + req_count_++; - if (UNLIKELY(platform::IsProfileEnabled())) { + if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) { h->Wait(); + if (h->should_retry) { + VLOG(3) << "rpc call failed, retry times " << retry_times_; + retry_times_++; + std::random_device rd; + std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5)); + continue; + } } - }); - - req_count_++; - return h; + return h; + } } VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, @@ -221,41 +253,55 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, const std::string table_name_val = table_name; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); - GetProcessor* s = new GetProcessor(ch); const std::string method = kPrefetchRPC; + int retry_times_ = 0; - VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope)); - s->Prepare(h, time_out); + while (true) { + GetProcessor* s = new GetProcessor(ch); + VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope)); + s->Prepare(h, time_out); - framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - s, method, h, table_name_val, this] { - auto* var = p_scope->FindVar(in_var_name_val); + framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, + p_ctx, s, method, h, table_name_val, this] { + auto* var = p_scope->FindVar(in_var_name_val); - ::grpc::ByteBuffer req; - SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val, - 0, table_name_val); + ::grpc::ByteBuffer req; + SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, + out_var_name_val, 0, table_name_val); - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - // stub context - s->response_call_back_ = ProcGetResponse; + // stub context + s->response_call_back_ = ProcGetResponse; - platform::RecordRPCEvent record_event(method); + platform::RecordRPCEvent record_event(method); - auto call = s->stub_g_.PrepareUnaryCall( - s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req, - &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, static_cast(s)); + auto call = s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req, + &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, static_cast(s)); - if (UNLIKELY(platform::IsProfileEnabled())) { + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + }); + req_count_++; + + if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) { h->Wait(); + if (h->should_retry) { + VLOG(3) << "rpc call failed, retry times " << retry_times_; + retry_times_++; + std::random_device rd; + std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5)); + continue; + } } - }); - req_count_++; - return h; + return h; + } } VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, @@ -420,6 +466,14 @@ void GRPCClient::Proceed() { ok_ = false; } c->Finish(false); + } else if (c->status_.error_code() == grpc::StatusCode::UNAVAILABLE) { + VLOG(3) << c->GetVarHandlePtr()->String() + << " meets grpc error, error_code:" << c->status_.error_code() + << " error_message:" << c->status_.error_message() + << " error_details:" << c->status_.error_details() + << " should retry!"; + c->GetVarHandlePtr()->should_retry = true; + c->Finish(false); } else { LOG(FATAL) << c->GetVarHandlePtr()->String() << " meets grpc error, error_code:" << c->status_.error_code() diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index de8f301846..22083d92ed 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -85,6 +85,8 @@ class VarHandle { virtual ~VarHandle() {} public: + bool should_retry = false; + bool Wait() { int ret = kDefaultState; { diff --git a/paddle/fluid/operators/distributed/rpc_client.cc b/paddle/fluid/operators/distributed/rpc_client.cc index 390e9af0f3..57ce54870d 100644 --- a/paddle/fluid/operators/distributed/rpc_client.cc +++ b/paddle/fluid/operators/distributed/rpc_client.cc @@ -17,6 +17,7 @@ // default to 3min to avoid temprary network failures. DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc"); +DEFINE_int32(rpc_retry_times, 3, "retry times for rpc"); namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index d4be2c28fd..d0b971e0cb 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -25,6 +25,7 @@ #include "paddle/fluid/operators/distributed/request_handler.h" DECLARE_int32(rpc_deadline); +DECLARE_int32(rpc_retry_times); namespace paddle { namespace operators { diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index f2d139b2e6..5817bbf9a1 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -177,6 +177,7 @@ def __bootstrap__(): if core.is_compiled_with_dist(): #env for rpc read_env_flags.append('rpc_deadline') + read_env_flags.append('rpc_retry_times') read_env_flags.append('rpc_server_profile_path') read_env_flags.append('enable_rpc_profiler') read_env_flags.append('rpc_send_thread_num') -- GitLab