From 2028a8ef6d077df1decdbeaf3b8b9794a6593b2a Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 6 Jun 2018 21:04:27 -0500 Subject: [PATCH] Add rpc_client interface. (#11154) --- paddle/fluid/operators/detail/CMakeLists.txt | 2 +- paddle/fluid/operators/detail/grpc_client.cc | 64 ++++++--------- paddle/fluid/operators/detail/grpc_client.h | 56 ++++++------- .../operators/detail/grpc_server_test.cc | 6 +- paddle/fluid/operators/detail/rpc_client.cc | 26 ++++++ paddle/fluid/operators/detail/rpc_client.h | 82 +++++++++++++++++++ paddle/fluid/operators/fetch_barrier_op.cc | 4 +- paddle/fluid/operators/gen_nccl_id_op.cc | 7 +- paddle/fluid/operators/prefetch_op.cc | 6 +- paddle/fluid/operators/recv_op.cc | 5 +- paddle/fluid/operators/send_barrier_op.cc | 3 +- paddle/fluid/operators/send_op.cc | 7 +- paddle/fluid/operators/send_vars_op.cc | 5 +- paddle/fluid/operators/test_send_nccl_id.cc | 7 +- 14 files changed, 191 insertions(+), 89 deletions(-) create mode 100644 paddle/fluid/operators/detail/rpc_client.cc create mode 100644 paddle/fluid/operators/detail/rpc_client.h diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index cf2053051..c29dc5d7e 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_DISTRIBUTE) grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc - request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor + request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows memory) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index f4d83e86e..fae39418b 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -25,29 +25,15 @@ namespace paddle { namespace operators { namespace detail { -std::once_flag RPCClient::init_flag_; +void GRPCClient::InitImpl() { InitEventLoop(); } -std::unique_ptr RPCClient::rpc_client_(nullptr); - -RPCClient* RPCClient::GetInstance() { - std::call_once(init_flag_, &RPCClient::Init); - return rpc_client_.get(); -} - -void RPCClient::Init() { - if (rpc_client_.get() == nullptr) { - rpc_client_.reset(new RPCClient()); - } - rpc_client_->InitEventLoop(); -} - -void RPCClient::InitEventLoop() { +void GRPCClient::InitEventLoop() { // start the client process thread // TODO(wuyi): can make this in a threadpool - client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this))); + client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this))); } -RPCClient::~RPCClient() { +GRPCClient::~GRPCClient() { Wait(); cq_.Shutdown(); { @@ -59,11 +45,10 @@ RPCClient::~RPCClient() { client_thread_->join(); } -bool RPCClient::AsyncSendVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out) { +bool GRPCClient::AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; @@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { result->Swap(&tmp); } -bool RPCClient::AsyncGetVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out) { +bool GRPCClient::AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; @@ -155,12 +139,12 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } -bool RPCClient::AsyncPrefetchVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out) { +bool GRPCClient::AsyncPrefetchVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string in_var_name_val = in_var_name; @@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, return true; } -void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { +void GRPCClient::AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); @@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { req_count_++; } -void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { +void GRPCClient::AsyncSendFetchBarrier(const std::string& ep, + int64_t time_out) { const auto ch = GetChannel(ep); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); s->Prepare(time_out); @@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { req_count_++; } -void RPCClient::Wait() { +void GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); sync_cond_.wait(lk, [this] { return req_count_ == 0; }); } -void RPCClient::Proceed() { +void GRPCClient::Proceed() { void* tag = nullptr; bool ok = false; @@ -251,7 +237,7 @@ void RPCClient::Proceed() { } } -std::shared_ptr RPCClient::GetChannel(const std::string& ep) { +std::shared_ptr GRPCClient::GetChannel(const std::string& ep) { // TODO(Yancey1989): make grpc client completely thread-safe std::lock_guard guard(chan_mutex_); auto it = channels_.find(ep); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index bb3813efc..8db73f875 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -38,6 +38,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN @@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor { std::unique_ptr stub_; }; -class RPCClient { +class GRPCClient : public RPCClient { public: - RPCClient() {} - ~RPCClient(); + GRPCClient() {} + virtual ~GRPCClient(); - static RPCClient* GetInstance(); + bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + int64_t time_out = RPCClient::rpc_time_out) override; - bool AsyncSendVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = 600 * 1000); + bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + int64_t time_out = RPCClient::rpc_time_out) override; - bool AsyncGetVariable(const std::string& ep, + bool AsyncPrefetchVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, - const std::string& var_name, - int64_t time_out = 600 * 1000); + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out = RPCClient::rpc_time_out) override; - bool AsyncPrefetchVariable(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& in_var_name, - const std::string& out_var_name, - int64_t time_out = 600 * 1000); + void AsyncSendBatchBarrier( + const std::string& ep, + int64_t time_out = RPCClient::rpc_time_out) override; - void AsyncSendBatchBarrier(const std::string& ep, - int64_t time_out = 600 * 1000); + void AsyncSendFetchBarrier( + const std::string& ep, + int64_t time_out = RPCClient::rpc_time_out) override; - void AsyncSendFetchBarrier(const std::string& ep, - int64_t time_out = 600 * 1000); + void Wait() override; - void Wait(); + protected: + void InitImpl() override; + + private: // InitEventLoop should only be called by Init() void InitEventLoop(); - private: void Proceed(); + std::shared_ptr GetChannel(const std::string& ep); - // Init is called by GetInstance. - static void Init(); private: grpc::CompletionQueue cq_; @@ -218,9 +218,7 @@ class RPCClient { // mutex for GetChannel thread safety std::mutex chan_mutex_; - static std::unique_ptr rpc_client_; - static std::once_flag init_flag_; - DISABLE_COPY_AND_ASSIGN(RPCClient); + DISABLE_COPY_AND_ASSIGN(GRPCClient); }; } // namespace detail diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 22a3a8135..a1f9ba15e 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_registry.h" @@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) { std::thread server_thread(StartServer); g_rpc_service->WaitServerReady(); - detail::RPCClient* client = detail::RPCClient::GetInstance(); + detail::RPCClient* client = + detail::RPCClient::GetInstance(); int port = g_rpc_service->GetSelectedPort(); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); @@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) { std::string in_var_name("ids"); std::string out_var_name("out"); - client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name); + client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name); client->Wait(); auto var = scope.Var(out_var_name); auto value = var->GetMutable()->value(); diff --git a/paddle/fluid/operators/detail/rpc_client.cc b/paddle/fluid/operators/detail/rpc_client.cc new file mode 100644 index 000000000..9a791403e --- /dev/null +++ b/paddle/fluid/operators/detail/rpc_client.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/detail/rpc_client.h" + +namespace paddle { +namespace operators { +namespace detail { + +std::once_flag RPCClient::init_flag_; +std::unique_ptr RPCClient::rpc_client_(nullptr); + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/rpc_client.h b/paddle/fluid/operators/detail/rpc_client.h new file mode 100644 index 000000000..8e3717f07 --- /dev/null +++ b/paddle/fluid/operators/detail/rpc_client.h @@ -0,0 +1,82 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace operators { +namespace detail { + +class RPCClient { + public: + virtual bool AsyncSendVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = rpc_time_out) = 0; + + virtual bool AsyncGetVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& var_name, + int64_t time_out = rpc_time_out) = 0; + + virtual bool AsyncPrefetchVar(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& in_var_name, + const std::string& out_var_name, + int64_t time_out = rpc_time_out) = 0; + + virtual void AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out = rpc_time_out) = 0; + + virtual void AsyncSendFetchBarrier(const std::string& ep, + int64_t time_out = rpc_time_out) = 0; + + virtual void Wait() = 0; + + static constexpr int64_t rpc_time_out = 120 * 1000; + + template + static RPCClient* GetInstance() { + std::call_once(init_flag_, &RPCClient::Init); + return rpc_client_.get(); + } + + // Init is called by GetInstance. + template + static void Init() { + if (rpc_client_.get() == nullptr) { + rpc_client_.reset(new T()); + rpc_client_->InitImpl(); + } + } + + protected: + virtual void InitImpl() {} + + private: + static std::once_flag init_flag_; + static std::unique_ptr rpc_client_; +}; +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc index 1e2c93335..ad67585e4 100644 --- a/paddle/fluid/operators/fetch_barrier_op.cc +++ b/paddle/fluid/operators/fetch_barrier_op.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/detail/rpc_client.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { @@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); rpc_client->Wait(); diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc index 4bce2d322..547de4fa4 100644 --- a/paddle/fluid/operators/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/gen_nccl_id_op.cc @@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase { std::vector endpoint_list = Attr>("endpoint_list"); - detail::RPCClient client; + detail::RPCClient* client = + detail::RPCClient::GetInstance(); for (auto& ep : endpoint_list) { VLOG(3) << "sending nccl id to " << ep; - client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME); + client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME); } - client.Wait(); + client->Wait(); VLOG(3) << "sending completed..."; } diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc index 167a06e09..d96359d6b 100644 --- a/paddle/fluid/operators/prefetch_op.cc +++ b/paddle/fluid/operators/prefetch_op.cc @@ -41,14 +41,14 @@ class PrefetchOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); - auto rpc_client = detail::RPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " << outs[i] << " back"; - rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i], - outs[i]); + rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 49b480948..1ea1cc458 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -44,11 +44,12 @@ class RecvOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; - rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); } if (sync_mode) { rpc_client->Wait(); diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 2bc38ff4e..511ad7538 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -44,7 +44,8 @@ class SendBarrierOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index a91b14538..969757970 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -49,12 +49,13 @@ class SendOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; - rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); + rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } @@ -72,7 +73,7 @@ class SendOp : public framework::OperatorBase { if (outs.size() > 0) { for (size_t i = 0; i < outs.size(); i++) { VLOG(2) << "getting " << outs[i] << " from " << epmap[i]; - rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]); } rpc_client->Wait(); // tell pservers that current trainer have called fetch diff --git a/paddle/fluid/operators/send_vars_op.cc b/paddle/fluid/operators/send_vars_op.cc index fe839dab6..564e40461 100644 --- a/paddle/fluid/operators/send_vars_op.cc +++ b/paddle/fluid/operators/send_vars_op.cc @@ -45,14 +45,15 @@ class SendVarsOp : public framework::OperatorBase { // For profiling platform::RecordEvent record_event(Type(), &ctx); - auto rpc_client = detail::RPCClient::GetInstance(); + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); for (size_t i = 0; i < ins.size(); i++) { if (NeedSend(scope, ins[i])) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; // TODO(Yancey1989): we need to use an IO threadpool which has // a larger number of threads than the computing threadpool. - rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); + rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]); } else { VLOG(3) << "don't send no-initialied variable: " << ins[i]; } diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc index eb01ac9b9..a01a75cbb 100644 --- a/paddle/fluid/operators/test_send_nccl_id.cc +++ b/paddle/fluid/operators/test_send_nccl_id.cc @@ -87,9 +87,10 @@ TEST(SendNcclId, GrpcServer) { int port = g_rpc_service->GetSelectedPort(); std::string ep = string::Sprintf("127.0.0.1:%d", port); - detail::RPCClient* client = detail::RPCClient::GetInstance(); - LOG(INFO) << "connect to server " << ep; - client->AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); + detail::RPCClient* client = + detail::RPCClient::GetInstance(); + LOG(INFO) << "connect to server" << ep; + client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME); client->Wait(); client->AsyncSendBatchBarrier(ep); client->Wait(); -- GitLab