From 036a90f125c3bd5f7caa15241413136f0a768f76 Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Tue, 5 Jun 2018 16:46:42 +0800 Subject: [PATCH] Refine rpc client wait sync (#11132) --- paddle/fluid/operators/detail/grpc_client.cc | 96 ++++++++----------- paddle/fluid/operators/detail/grpc_client.h | 23 ++++- paddle/fluid/operators/detail/grpc_server.cc | 7 +- .../operators/detail/grpc_server_test.cc | 11 +-- paddle/fluid/operators/fetch_barrier_op.cc | 4 +- paddle/fluid/operators/prefetch_op.cc | 2 +- paddle/fluid/operators/recv_op.cc | 2 +- paddle/fluid/operators/send_barrier_op.cc | 4 +- paddle/fluid/operators/send_op.cc | 8 +- paddle/fluid/operators/test_send_nccl_id.cc | 13 ++- 10 files changed, 80 insertions(+), 90 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index da9ca1a0c1d..f4d83e86ecb 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -38,6 +38,25 @@ void RPCClient::Init() { if (rpc_client_.get() == nullptr) { rpc_client_.reset(new RPCClient()); } + rpc_client_->InitEventLoop(); +} + +void RPCClient::InitEventLoop() { + // start the client process thread + // TODO(wuyi): can make this in a threadpool + client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this))); +} + +RPCClient::~RPCClient() { + Wait(); + cq_.Shutdown(); + { + std::lock_guard guard(chan_mutex_); + for (auto& it : channels_) { + it.second.reset(); + } + } + client_thread_->join(); } bool RPCClient::AsyncSendVariable(const std::string& ep, @@ -204,70 +223,37 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { req_count_++; } -bool RPCClient::Wait() { - VLOG(3) << "RPCClient begin Wait()" - << " req_count_:" << req_count_; - if (req_count_ <= 0) { - return true; - } - const size_t kReqCnt = req_count_; - bool a[kReqCnt]; - std::vector> waits(req_count_); - std::mutex mu; - - for (int i = 0; i < req_count_; i++) { - waits[i] = framework::AsyncIO([i, &a, &mu, this] { - bool ret = Proceed(); - std::lock_guard l(mu); - a[i] = ret; - }); - } - - for (int i = 0; i < req_count_; i++) { - waits[i].wait(); - } - - int last_req_count = req_count_; - req_count_ = 0; - - for (int i = 0; i < last_req_count; i++) { - if (!a[i]) { - return false; - } - } - - return true; +void RPCClient::Wait() { + std::unique_lock lk(sync_mutex_); + sync_cond_.wait(lk, [this] { return req_count_ == 0; }); } -bool RPCClient::Proceed() { - void* tag = NULL; +void RPCClient::Proceed() { + void* tag = nullptr; bool ok = false; - // request counts. - if (!cq_.Next(&tag, &ok)) { - LOG(ERROR) << "Get meets CompletionQueue error"; - return false; - } - - GPR_ASSERT(ok); - PADDLE_ENFORCE(tag); - - // TODO(gongwb): add more retries. - BaseProcessor* c = static_cast(tag); - if (!c->status_.ok()) { - LOG(ERROR) << "proc param error:" << c->var_h_.String() - << " grpc error:" << c->status_.error_message(); + while (cq_.Next(&tag, &ok)) { + BaseProcessor* c = static_cast(tag); + GPR_ASSERT(ok); + PADDLE_ENFORCE(c); + if (c->status_.ok()) { + c->Process(); + } else { + LOG(ERROR) << "var: " << c->var_h_.String() + << " grpc error:" << c->status_.error_message(); + } delete c; - return false; + { + std::lock_guard lk(sync_mutex_); + req_count_--; + } + sync_cond_.notify_all(); } - - c->Process(); - delete c; - return true; } + std::shared_ptr RPCClient::GetChannel(const std::string& ep) { // TODO(Yancey1989): make grpc client completely thread-safe - std::unique_lock lock(mutex_); + std::lock_guard guard(chan_mutex_); auto it = channels_.find(ep); if (it != channels_.end()) { return it->second; diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 449d5105afb..bb3813efcf4 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -16,15 +16,18 @@ limitations under the License. */ #include -#include // NOLINT +#include // NOLINT +#include // NOLINT #include #include #include #include #include // NOLINT #include +#include // NOLINT #include +#include "grpc++/channel.h" #include "grpc++/generic/generic_stub.h" #include "grpc++/grpc++.h" #include "grpc++/support/byte_buffer.h" @@ -164,6 +167,7 @@ class FetchBarrierProcessor : public BaseProcessor { class RPCClient { public: RPCClient() {} + ~RPCClient(); static RPCClient* GetInstance(); @@ -192,19 +196,28 @@ class RPCClient { void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = 600 * 1000); - bool Wait(); + void Wait(); + // InitEventLoop should only be called by Init() + void InitEventLoop(); private: - bool Proceed(); + void Proceed(); std::shared_ptr GetChannel(const std::string& ep); // Init is called by GetInstance. static void Init(); private: grpc::CompletionQueue cq_; - std::map> channels_; + std::unordered_map> channels_; + std::unique_ptr client_thread_; + + // mutex for Wait client sync + std::mutex sync_mutex_; + std::condition_variable sync_cond_; std::atomic req_count_{0}; - std::mutex mutex_; + + // mutex for GetChannel thread safety + std::mutex chan_mutex_; static std::unique_ptr rpc_client_; static std::once_flag init_flag_; DISABLE_COPY_AND_ASSIGN(RPCClient); diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index e73756d8900..57867aad4d6 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -68,9 +68,7 @@ class RequestSend final : public RequestBase { method_id, &ctx_, request_.get(), &responder_, cq_, cq_, reinterpret_cast(static_cast(req_id))); } - virtual ~RequestSend() {} - std::string GetReqName() override { return request_->Varname(); } void Process() override { @@ -82,7 +80,6 @@ class RequestSend final : public RequestBase { framework::Variable* outvar = nullptr; request_handler_->Handle(varname, scope, invar, &outvar); - status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, reinterpret_cast(static_cast(req_id_))); @@ -125,7 +122,6 @@ class RequestGet final : public RequestBase { SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), &reply_); } - status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, reinterpret_cast(static_cast(req_id_))); @@ -170,10 +166,9 @@ class RequestPrefetch final : public RequestBase { SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), &reply_); - - status_ = FINISH; responder_.Finish(reply_, ::grpc::Status::OK, reinterpret_cast(static_cast(req_id_))); + status_ = FINISH; } protected: diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index f97f638701c..22a3a813575 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -113,10 +113,6 @@ void StartServer() { std::thread server_thread( std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); - // FIXME(gongwb): don't use hard time. - sleep(10); - LOG(INFO) << "got nccl id and stop server..."; - g_rpc_service->ShutDown(); server_thread.join(); } @@ -127,7 +123,7 @@ TEST(PREFETCH, CPU) { std::thread server_thread(StartServer); g_rpc_service->WaitServerReady(); - detail::RPCClient client; + detail::RPCClient* client = detail::RPCClient::GetInstance(); int port = g_rpc_service->GetSelectedPort(); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); @@ -141,8 +137,8 @@ TEST(PREFETCH, CPU) { std::string in_var_name("ids"); std::string out_var_name("out"); - client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name); - client.Wait(); + client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name); + client->Wait(); auto var = scope.Var(out_var_name); auto value = var->GetMutable()->value(); auto ptr = value.mutable_data(place); @@ -152,6 +148,7 @@ TEST(PREFETCH, CPU) { } } + g_rpc_service->ShutDown(); server_thread.join(); LOG(INFO) << "begin reset"; g_rpc_service.reset(nullptr); diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index 79ec02f5209..1e2c93335fb 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase { auto rpc_client = detail::RPCClient::GetInstance(); - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); for (auto& ep : eps) { VLOG(3) << "fetch barrier, ep: " << ep; rpc_client->AsyncSendFetchBarrier(ep); } - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); } }; diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index e0a9b24ac89..167a06e090c 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } } - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); } }; diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index d8ddb7b4489..49b480948a7 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase { rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } if (sync_mode) { - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); } } }; diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index bcd8e81609a..2bc38ff4e3e 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -49,13 +49,13 @@ class SendBarrierOp : public framework::OperatorBase { VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; // need to wait before sending send_barrier message - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); if (sync_mode) { for (auto& ep : eps) { VLOG(3) << "send barrier, ep: " << ep; rpc_client->AsyncSendBatchBarrier(ep); } - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); } } }; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index a5150f242ca..a91b1453896 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -59,14 +59,14 @@ class SendOp : public framework::OperatorBase { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } } - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); if (sync_mode) { for (auto& ep : endpoints) { VLOG(3) << "batch barrier, ep: " << ep; rpc_client->AsyncSendBatchBarrier(ep); } - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); } if (outs.size() > 0) { @@ -74,13 +74,13 @@ class SendOp : public framework::OperatorBase { VLOG(2) << "getting " << outs[i] << " from " << epmap[i]; rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); // tell pservers that current trainer have called fetch for (auto& ep : endpoints) { VLOG(2) << "send fetch barrier, ep: " << ep; rpc_client->AsyncSendFetchBarrier(ep); } - PADDLE_ENFORCE(rpc_client->Wait()); + rpc_client->Wait(); } } }; diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc index a845ba2eb03..eb01ac9b907 100644 --- a/paddle/fluid/operators/test_send_nccl_id.cc +++ b/paddle/fluid/operators/test_send_nccl_id.cc @@ -61,7 +61,6 @@ void StartServer() { std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); g_rpc_service->SetCond(detail::kRequestSend); - std::cout << "before WaitFanInOfSend" << std::endl; g_rpc_service->WaitBarrier(detail::kRequestSend); LOG(INFO) << "got nccl id and stop server..."; @@ -88,12 +87,12 @@ TEST(SendNcclId, GrpcServer) { int port = g_rpc_service->GetSelectedPort(); std::string ep = string::Sprintf("127.0.0.1:%d", port); - detail::RPCClient client; - LOG(INFO) << "connect to server" << ep; - client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); - client.Wait(); - client.AsyncSendBatchBarrier(ep); - client.Wait(); + detail::RPCClient* client = detail::RPCClient::GetInstance(); + LOG(INFO) << "connect to server " << ep; + client->AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); + client->Wait(); + client->AsyncSendBatchBarrier(ep); + client->Wait(); server_thread.join(); g_rpc_service.reset(nullptr); -- GitLab