未验证 提交 d2d6e8fd 编写于 作者: W Wu Yi 提交者: GitHub

cherrypick grpc fixes (#11692)

上级 57780401
...@@ -40,12 +40,12 @@ ExternalProject_Add( ...@@ -40,12 +40,12 @@ ExternalProject_Add(
# NOTE(wuyi): # NOTE(wuyi):
# this package is generated by following steps: # this package is generated by following steps:
# 1. git clone -b v1.8.x https://github.com/grpc/grpc.git # 1. git clone -b v1.8.x https://github.com/grpc/grpc.git
# 2. submodule update --init # 2. git submodule update --init
# 3. keep only zlib, cares, protobuf, boringssl under "third_party", # 3. keep only zlib, cares, protobuf, boringssl under "third_party",
# checkout and clean other dirs under third_party # checkout and clean other dirs under third_party
# 4. remove .git, and package the directory. # 4. remove .git, and package the directory.
URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz" URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.10.x.tar.gz"
URL_MD5 "c9c58ee7d0e8929a63155af6a2ecdbd0" URL_MD5 "1f268a2aff6759839dccd256adcc91cf"
PREFIX ${GRPC_SOURCES_DIR} PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -258,14 +258,15 @@ void GRPCClient::Proceed() { ...@@ -258,14 +258,15 @@ void GRPCClient::Proceed() {
} }
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) { std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe
std::lock_guard<std::mutex> guard(chan_mutex_); std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep); auto it = channels_.find(ep);
if (it != channels_.end()) { if (it != channels_.end()) {
return it->second; return it->second;
} }
// Channel configurations:
grpc::ChannelArguments args; grpc::ChannelArguments args;
args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max()); args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
......
...@@ -72,6 +72,7 @@ class BaseProcessor { ...@@ -72,6 +72,7 @@ class BaseProcessor {
virtual void Prepare(const VarHandle& var_info, int64_t time_out) { virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
context_.reset(new grpc::ClientContext()); context_.reset(new grpc::ClientContext());
var_h_ = var_info; var_h_ = var_info;
context_->set_wait_for_ready(true);
std::chrono::system_clock::time_point deadline = std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
...@@ -81,6 +82,7 @@ class BaseProcessor { ...@@ -81,6 +82,7 @@ class BaseProcessor {
virtual void Prepare(int64_t time_out) { virtual void Prepare(int64_t time_out) {
context_.reset(new grpc::ClientContext()); context_.reset(new grpc::ClientContext());
context_->set_wait_for_ready(true);
std::chrono::system_clock::time_point deadline = std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
...@@ -172,26 +174,24 @@ class GRPCClient : public RPCClient { ...@@ -172,26 +174,24 @@ class GRPCClient : public RPCClient {
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name, const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override; int64_t time_out = FLAGS_grpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name, const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override; int64_t time_out = FLAGS_grpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep, bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
int64_t time_out = RPCClient::rpc_time_out) override; int64_t time_out = FLAGS_grpc_deadline) override;
void AsyncSendBatchBarrier( void AsyncSendBatchBarrier(const std::string& ep,
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) override;
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendFetchBarrier( void AsyncSendFetchBarrier(const std::string& ep,
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) override;
int64_t time_out = RPCClient::rpc_time_out) override;
void Wait() override; void Wait() override;
...@@ -207,7 +207,7 @@ class GRPCClient : public RPCClient { ...@@ -207,7 +207,7 @@ class GRPCClient : public RPCClient {
void Proceed(); void Proceed();
void AsyncSendComplete(const std::string& ep, void AsyncSendComplete(const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out); int64_t time_out = FLAGS_grpc_deadline);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep); std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
......
...@@ -84,7 +84,7 @@ class RequestSend final : public RequestBase { ...@@ -84,7 +84,7 @@ class RequestSend final : public RequestBase {
void Process() override { void Process() override {
std::string varname = GetReqName(); std::string varname = GetReqName();
VLOG(3) << "RequestSend var_name:" << varname; VLOG(4) << "RequestSend var_name:" << varname;
auto scope = request_->GetMutableLocalScope(); auto scope = request_->GetMutableLocalScope();
auto invar = request_->GetVar(); auto invar = request_->GetVar();
...@@ -119,7 +119,7 @@ class RequestGet final : public RequestBase { ...@@ -119,7 +119,7 @@ class RequestGet final : public RequestBase {
void Process() override { void Process() override {
// proc request. // proc request.
std::string varname = request_.varname(); std::string varname = request_.varname();
VLOG(3) << "RequestGet " << varname; VLOG(4) << "RequestGet " << varname;
auto scope = request_handler_->scope(); auto scope = request_handler_->scope();
auto invar = scope->FindVar(varname); auto invar = scope->FindVar(varname);
...@@ -165,7 +165,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -165,7 +165,7 @@ class RequestPrefetch final : public RequestBase {
// prefetch process... // prefetch process...
std::string in_var_name = request_->Varname(); std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname(); std::string out_var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name; << " out_var_name: " << out_var_name;
auto scope = request_->GetMutableLocalScope(); auto scope = request_->GetMutableLocalScope();
...@@ -188,10 +188,10 @@ class RequestPrefetch final : public RequestBase { ...@@ -188,10 +188,10 @@ class RequestPrefetch final : public RequestBase {
}; };
void AsyncGRPCServer::WaitServerReady() { void AsyncGRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer is wait server ready"; VLOG(4) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_); std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
VLOG(3) << "AsyncGRPCServer WaitSeverReady"; VLOG(4) << "AsyncGRPCServer WaitSeverReady";
} }
void AsyncGRPCServer::StartServer() { void AsyncGRPCServer::StartServer() {
...@@ -230,7 +230,7 @@ void AsyncGRPCServer::StartServer() { ...@@ -230,7 +230,7 @@ void AsyncGRPCServer::StartServer() {
for (int i = 0; i < threadnum; i++) { for (int i = 0; i < threadnum; i++) {
rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind( rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
&AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f))); &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
VLOG(3) << t.first << " creates threads!"; VLOG(4) << t.first << " creates threads!";
} }
} }
...@@ -247,7 +247,7 @@ void AsyncGRPCServer::StartServer() { ...@@ -247,7 +247,7 @@ void AsyncGRPCServer::StartServer() {
auto& threads = t.second; auto& threads = t.second;
for (size_t i = 0; i < threads.size(); ++i) { for (size_t i = 0; i < threads.size(); ++i) {
threads[i]->join(); threads[i]->join();
VLOG(3) << t.first << " threads ends!"; VLOG(4) << t.first << " threads ends!";
} }
} }
} }
...@@ -255,7 +255,7 @@ void AsyncGRPCServer::StartServer() { ...@@ -255,7 +255,7 @@ void AsyncGRPCServer::StartServer() {
void AsyncGRPCServer::ShutdownQueue() { void AsyncGRPCServer::ShutdownQueue() {
for (auto& t : rpc_cq_) { for (auto& t : rpc_cq_) {
t.second->Shutdown(); t.second->Shutdown();
VLOG(3) << t.first << " shutdown!"; VLOG(4) << t.first << " queue shutdown!";
} }
} }
...@@ -264,7 +264,7 @@ void AsyncGRPCServer::ShutDownImpl() { ...@@ -264,7 +264,7 @@ void AsyncGRPCServer::ShutDownImpl() {
is_shut_down_ = true; is_shut_down_ = true;
ShutdownQueue(); ShutdownQueue();
VLOG(3) << "server_ shutdown!"; VLOG(4) << "server_ shutdown!";
server_->Shutdown(); server_->Shutdown();
} }
...@@ -272,7 +272,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, ...@@ -272,7 +272,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
int req_id) { int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
......
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "gflags/gflags.h"
// default to 3min to avoid temprary network failures.
DEFINE_int32(grpc_deadline, 180000, "deadline timeouts for grpc");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -15,11 +15,14 @@ ...@@ -15,11 +15,14 @@
#pragma once #pragma once
#include <string> #include <string>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
DECLARE_int32(grpc_deadline);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
...@@ -32,26 +35,26 @@ class RPCClient { ...@@ -32,26 +35,26 @@ class RPCClient {
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = rpc_time_out) = 0; int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual bool AsyncGetVar(const std::string& ep, virtual bool AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
int64_t time_out = rpc_time_out) = 0; int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep, virtual bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& in_var_name, const std::string& in_var_name,
const std::string& out_var_name, const std::string& out_var_name,
int64_t time_out = rpc_time_out) = 0; int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual void AsyncSendBatchBarrier(const std::string& ep, virtual void AsyncSendBatchBarrier(
int64_t time_out = rpc_time_out) = 0; const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual void AsyncSendFetchBarrier(const std::string& ep, virtual void AsyncSendFetchBarrier(
int64_t time_out = rpc_time_out) = 0; const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
// SendComplete tells all the server that current trainer have no more data // SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue // to train, so that the pserver can reduce it's barrier count, and continue
...@@ -60,8 +63,6 @@ class RPCClient { ...@@ -60,8 +63,6 @@ class RPCClient {
virtual void Wait() = 0; virtual void Wait() = 0;
static constexpr int64_t rpc_time_out = 120 * 1000;
template <typename T> template <typename T>
static RPCClient* GetInstance() { static RPCClient* GetInstance() {
std::call_once(init_flag_, &RPCClient::Init<T>); std::call_once(init_flag_, &RPCClient::Init<T>);
......
...@@ -47,11 +47,12 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) { ...@@ -47,11 +47,12 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load()); return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
}); });
VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name]; VLOG(3) << "batch_barrier_: " << rpc_name << " "
<< barrier_counter_[rpc_name];
} }
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; VLOG(4) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
int b = 0; int b = 0;
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name]; b = ++barrier_counter_[rpc_name];
...@@ -100,7 +101,7 @@ void RPCServer::SetCond(const std::string& rpc_name) { ...@@ -100,7 +101,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
} }
void RPCServer::WaitCond(const std::string& rpc_name) { void RPCServer::WaitCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer WaitCond " << rpc_name; VLOG(4) << "RPCServer WaitCond " << rpc_name;
int cond = 0; int cond = 0;
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
......
...@@ -165,7 +165,6 @@ void ListenAndServOp::RunSyncLoop( ...@@ -165,7 +165,6 @@ void ListenAndServOp::RunSyncLoop(
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program) const { framework::ProgramDesc *program) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id // grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id; std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad; std::unordered_map<int32_t, std::string> id_to_grad;
...@@ -203,7 +202,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -203,7 +202,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
VLOG(3) << "RunAsyncLoop into while";
while (true) { while (true) {
if (rpc_service_->IsExit()) { if (rpc_service_->IsExit()) {
LOG(INFO) << "get exit!rpc_processor break!"; LOG(INFO) << "get exit!rpc_processor break!";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册