From d2d6e8fdee56209e191bd6ba1acccda033a405f6 Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Mon, 25 Jun 2018 20:21:07 +0800 Subject: [PATCH] cherrypick grpc fixes (#11692) --- cmake/external/grpc.cmake | 6 +++--- .../operators/distributed/grpc_client.cc | 3 ++- .../fluid/operators/distributed/grpc_client.h | 20 +++++++++---------- .../operators/distributed/grpc_server.cc | 20 +++++++++---------- .../fluid/operators/distributed/rpc_client.cc | 4 ++++ .../fluid/operators/distributed/rpc_client.h | 19 +++++++++--------- .../fluid/operators/distributed/rpc_server.cc | 7 ++++--- paddle/fluid/operators/listen_and_serv_op.cc | 2 -- 8 files changed, 43 insertions(+), 38 deletions(-) diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake index ffdf91a354b..85f40585da2 100644 --- a/cmake/external/grpc.cmake +++ b/cmake/external/grpc.cmake @@ -40,12 +40,12 @@ ExternalProject_Add( # NOTE(wuyi): # this package is generated by following steps: # 1. git clone -b v1.8.x https://github.com/grpc/grpc.git - # 2. submodule update --init + # 2. git submodule update --init # 3. keep only zlib, cares, protobuf, boringssl under "third_party", # checkout and clean other dirs under third_party # 4. remove .git, and package the directory. - URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz" - URL_MD5 "c9c58ee7d0e8929a63155af6a2ecdbd0" + URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.10.x.tar.gz" + URL_MD5 "1f268a2aff6759839dccd256adcc91cf" PREFIX ${GRPC_SOURCES_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 65d63784c14..b7e63392be4 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -258,14 +258,15 @@ void GRPCClient::Proceed() { } std::shared_ptr GRPCClient::GetChannel(const std::string& ep) { - // TODO(Yancey1989): make grpc client completely thread-safe std::lock_guard guard(chan_mutex_); auto it = channels_.find(ep); if (it != channels_.end()) { return it->second; } + // Channel configurations: grpc::ChannelArguments args; + args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000); args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); args.SetMaxSendMessageSize(std::numeric_limits::max()); args.SetMaxReceiveMessageSize(std::numeric_limits::max()); diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index a6efa7dfd11..2edf0c9f8df 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -72,6 +72,7 @@ class BaseProcessor { virtual void Prepare(const VarHandle& var_info, int64_t time_out) { context_.reset(new grpc::ClientContext()); var_h_ = var_info; + context_->set_wait_for_ready(true); std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); @@ -81,6 +82,7 @@ class BaseProcessor { virtual void Prepare(int64_t time_out) { context_.reset(new grpc::ClientContext()); + context_->set_wait_for_ready(true); std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); @@ -172,26 +174,24 @@ class GRPCClient : public RPCClient { bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, - int64_t time_out = RPCClient::rpc_time_out) override; + int64_t time_out = FLAGS_grpc_deadline) override; bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, - int64_t time_out = RPCClient::rpc_time_out) override; + int64_t time_out = FLAGS_grpc_deadline) override; bool AsyncPrefetchVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& in_var_name, const std::string& out_var_name, - int64_t time_out = RPCClient::rpc_time_out) override; + int64_t time_out = FLAGS_grpc_deadline) override; - void AsyncSendBatchBarrier( - const std::string& ep, - int64_t time_out = RPCClient::rpc_time_out) override; + void AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out = FLAGS_grpc_deadline) override; - void AsyncSendFetchBarrier( - const std::string& ep, - int64_t time_out = RPCClient::rpc_time_out) override; + void AsyncSendFetchBarrier(const std::string& ep, + int64_t time_out = FLAGS_grpc_deadline) override; void Wait() override; @@ -207,7 +207,7 @@ class GRPCClient : public RPCClient { void Proceed(); void AsyncSendComplete(const std::string& ep, - int64_t time_out = RPCClient::rpc_time_out); + int64_t time_out = FLAGS_grpc_deadline); std::shared_ptr GetChannel(const std::string& ep); diff --git a/paddle/fluid/operators/distributed/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc index 707f665a29d..0b51160c6f2 100644 --- a/paddle/fluid/operators/distributed/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc_server.cc @@ -84,7 +84,7 @@ class RequestSend final : public RequestBase { void Process() override { std::string varname = GetReqName(); - VLOG(3) << "RequestSend var_name:" << varname; + VLOG(4) << "RequestSend var_name:" << varname; auto scope = request_->GetMutableLocalScope(); auto invar = request_->GetVar(); @@ -119,7 +119,7 @@ class RequestGet final : public RequestBase { void Process() override { // proc request. std::string varname = request_.varname(); - VLOG(3) << "RequestGet " << varname; + VLOG(4) << "RequestGet " << varname; auto scope = request_handler_->scope(); auto invar = scope->FindVar(varname); @@ -165,7 +165,7 @@ class RequestPrefetch final : public RequestBase { // prefetch process... std::string in_var_name = request_->Varname(); std::string out_var_name = request_->OutVarname(); - VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name + VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name << " out_var_name: " << out_var_name; auto scope = request_->GetMutableLocalScope(); @@ -188,10 +188,10 @@ class RequestPrefetch final : public RequestBase { }; void AsyncGRPCServer::WaitServerReady() { - VLOG(3) << "AsyncGRPCServer is wait server ready"; + VLOG(4) << "AsyncGRPCServer is wait server ready"; std::unique_lock lock(this->mutex_ready_); condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); - VLOG(3) << "AsyncGRPCServer WaitSeverReady"; + VLOG(4) << "AsyncGRPCServer WaitSeverReady"; } void AsyncGRPCServer::StartServer() { @@ -230,7 +230,7 @@ void AsyncGRPCServer::StartServer() { for (int i = 0; i < threadnum; i++) { rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind( &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f))); - VLOG(3) << t.first << " creates threads!"; + VLOG(4) << t.first << " creates threads!"; } } @@ -247,7 +247,7 @@ void AsyncGRPCServer::StartServer() { auto& threads = t.second; for (size_t i = 0; i < threads.size(); ++i) { threads[i]->join(); - VLOG(3) << t.first << " threads ends!"; + VLOG(4) << t.first << " threads ends!"; } } } @@ -255,7 +255,7 @@ void AsyncGRPCServer::StartServer() { void AsyncGRPCServer::ShutdownQueue() { for (auto& t : rpc_cq_) { t.second->Shutdown(); - VLOG(3) << t.first << " shutdown!"; + VLOG(4) << t.first << " queue shutdown!"; } } @@ -264,7 +264,7 @@ void AsyncGRPCServer::ShutDownImpl() { is_shut_down_ = true; ShutdownQueue(); - VLOG(3) << "server_ shutdown!"; + VLOG(4) << "server_ shutdown!"; server_->Shutdown(); } @@ -272,7 +272,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, int req_id) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { - VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; + VLOG(4) << "shutdown, do not TryToRegisterNewSendOne"; return; } diff --git a/paddle/fluid/operators/distributed/rpc_client.cc b/paddle/fluid/operators/distributed/rpc_client.cc index c71edf977c1..2cf87faaab3 100644 --- a/paddle/fluid/operators/distributed/rpc_client.cc +++ b/paddle/fluid/operators/distributed/rpc_client.cc @@ -13,6 +13,10 @@ // limitations under the License. #include "paddle/fluid/operators/distributed/rpc_client.h" +#include "gflags/gflags.h" + +// default to 3min to avoid temprary network failures. +DEFINE_int32(grpc_deadline, 180000, "deadline timeouts for grpc"); namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 72fa6d94088..db437a7f1ec 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -15,11 +15,14 @@ #pragma once #include +#include "gflags/gflags.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" +DECLARE_int32(grpc_deadline); + namespace paddle { namespace operators { namespace distributed { @@ -32,26 +35,26 @@ class RPCClient { const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, - int64_t time_out = rpc_time_out) = 0; + int64_t time_out = FLAGS_grpc_deadline) = 0; virtual bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, - int64_t time_out = rpc_time_out) = 0; + int64_t time_out = FLAGS_grpc_deadline) = 0; virtual bool AsyncPrefetchVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& in_var_name, const std::string& out_var_name, - int64_t time_out = rpc_time_out) = 0; + int64_t time_out = FLAGS_grpc_deadline) = 0; - virtual void AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out = rpc_time_out) = 0; + virtual void AsyncSendBatchBarrier( + const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0; - virtual void AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out = rpc_time_out) = 0; + virtual void AsyncSendFetchBarrier( + const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0; // SendComplete tells all the server that current trainer have no more data // to train, so that the pserver can reduce it's barrier count, and continue @@ -60,8 +63,6 @@ class RPCClient { virtual void Wait() = 0; - static constexpr int64_t rpc_time_out = 120 * 1000; - template static RPCClient* GetInstance() { std::call_once(init_flag_, &RPCClient::Init); diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc index fa0cb71b305..c0520e248d4 100644 --- a/paddle/fluid/operators/distributed/rpc_server.cc +++ b/paddle/fluid/operators/distributed/rpc_server.cc @@ -47,11 +47,12 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) { return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load()); }); - VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name]; + VLOG(3) << "batch_barrier_: " << rpc_name << " " + << barrier_counter_[rpc_name]; } void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { - VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; + VLOG(4) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; int b = 0; std::unique_lock lock(mutex_); b = ++barrier_counter_[rpc_name]; @@ -100,7 +101,7 @@ void RPCServer::SetCond(const std::string& rpc_name) { } void RPCServer::WaitCond(const std::string& rpc_name) { - VLOG(3) << "RPCServer WaitCond " << rpc_name; + VLOG(4) << "RPCServer WaitCond " << rpc_name; int cond = 0; { std::unique_lock lock(mutex_); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index f840064ecac..6086c31722c 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -165,7 +165,6 @@ void ListenAndServOp::RunSyncLoop( void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, framework::ProgramDesc *program) const { - VLOG(3) << "RunAsyncLoop in"; // grad name to block id std::unordered_map grad_to_block_id; std::unordered_map id_to_grad; @@ -203,7 +202,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); - VLOG(3) << "RunAsyncLoop into while"; while (true) { if (rpc_service_->IsExit()) { LOG(INFO) << "get exit!rpc_processor break!"; -- GitLab