未验证 提交 2028a8ef 编写于 作者: G gongweibao 提交者: GitHub

Add rpc_client interface. (#11154)

上级 ca2d6d3c
if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
......@@ -25,29 +25,15 @@ namespace paddle {
namespace operators {
namespace detail {
std::once_flag RPCClient::init_flag_;
void GRPCClient::InitImpl() { InitEventLoop(); }
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
RPCClient* RPCClient::GetInstance() {
std::call_once(init_flag_, &RPCClient::Init);
return rpc_client_.get();
}
void RPCClient::Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new RPCClient());
}
rpc_client_->InitEventLoop();
}
void RPCClient::InitEventLoop() {
void GRPCClient::InitEventLoop() {
// start the client process thread
// TODO(wuyi): can make this in a threadpool
client_thread_.reset(new std::thread(std::bind(&RPCClient::Proceed, this)));
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
}
RPCClient::~RPCClient() {
GRPCClient::~GRPCClient() {
Wait();
cq_.Shutdown();
{
......@@ -59,11 +45,10 @@ RPCClient::~RPCClient() {
client_thread_->join();
}
bool RPCClient::AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
bool GRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
......@@ -113,11 +98,10 @@ void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
result->Swap(&tmp);
}
bool RPCClient::AsyncGetVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
bool GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
......@@ -155,12 +139,12 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}
bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out) {
bool GRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
......@@ -198,7 +182,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
return true;
}
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
void GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
......@@ -211,7 +196,8 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
req_count_++;
}
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
......@@ -223,12 +209,12 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
req_count_++;
}
void RPCClient::Wait() {
void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
}
void RPCClient::Proceed() {
void GRPCClient::Proceed() {
void* tag = nullptr;
bool ok = false;
......@@ -251,7 +237,7 @@ void RPCClient::Proceed() {
}
}
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe
std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep);
......
......@@ -38,6 +38,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
......@@ -164,47 +165,46 @@ class FetchBarrierProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class RPCClient {
class GRPCClient : public RPCClient {
public:
RPCClient() {}
~RPCClient();
GRPCClient() {}
virtual ~GRPCClient();
static RPCClient* GetInstance();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
bool AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
bool AsyncGetVariable(const std::string& ep,
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
bool AsyncPrefetchVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = 600 * 1000);
void AsyncSendBatchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
void AsyncSendFetchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);
void Wait() override;
void Wait();
protected:
void InitImpl() override;
private:
// InitEventLoop should only be called by Init()
void InitEventLoop();
private:
void Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
// Init is called by GetInstance.
static void Init();
private:
grpc::CompletionQueue cq_;
......@@ -218,9 +218,7 @@ class RPCClient {
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
static std::unique_ptr<RPCClient> rpc_client_;
static std::once_flag init_flag_;
DISABLE_COPY_AND_ASSIGN(RPCClient);
DISABLE_COPY_AND_ASSIGN(GRPCClient);
};
} // namespace detail
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -123,7 +124,8 @@ TEST(PREFETCH, CPU) {
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
detail::RPCClient* client = detail::RPCClient::GetInstance();
detail::RPCClient* client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
......@@ -137,7 +139,7 @@ TEST(PREFETCH, CPU) {
std::string in_var_name("ids");
std::string out_var_name("out");
client->AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/rpc_client.h"
namespace paddle {
namespace operators {
namespace detail {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace operators {
namespace detail {
class RPCClient {
public:
virtual bool AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = rpc_time_out) = 0;
virtual bool AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = rpc_time_out) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = rpc_time_out) = 0;
virtual void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0;
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0;
virtual void Wait() = 0;
static constexpr int64_t rpc_time_out = 120 * 1000;
template <typename T>
static RPCClient* GetInstance() {
std::call_once(init_flag_, &RPCClient::Init<T>);
return rpc_client_.get();
}
// Init is called by GetInstance.
template <typename T>
static void Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new T());
rpc_client_->InitImpl();
}
}
protected:
virtual void InitImpl() {}
private:
static std::once_flag init_flag_;
static std::unique_ptr<RPCClient> rpc_client_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......@@ -43,7 +44,8 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
rpc_client->Wait();
......
......@@ -61,12 +61,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient client;
detail::RPCClient* client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
}
client.Wait();
client->Wait();
VLOG(3) << "sending completed...";
}
......
......@@ -41,14 +41,14 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
auto rpc_client = detail::RPCClient::GetInstance();
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get "
<< outs[i] << " back";
rpc_client->AsyncPrefetchVariable(epmap[i], ctx, scope, ins[i],
outs[i]);
rpc_client->AsyncPrefetchVar(epmap[i], ctx, scope, ins[i], outs[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
......
......@@ -44,11 +44,12 @@ class RecvOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
}
if (sync_mode) {
rpc_client->Wait();
......
......@@ -44,7 +44,8 @@ class SendBarrierOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
......
......@@ -49,12 +49,13 @@ class SendOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
......@@ -72,7 +73,7 @@ class SendOp : public framework::OperatorBase {
if (outs.size() > 0) {
for (size_t i = 0; i < outs.size(); i++) {
VLOG(2) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]);
}
rpc_client->Wait();
// tell pservers that current trainer have called fetch
......
......@@ -45,14 +45,15 @@ class SendVarsOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
......
......@@ -87,9 +87,10 @@ TEST(SendNcclId, GrpcServer) {
int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient* client = detail::RPCClient::GetInstance();
LOG(INFO) << "connect to server " << ep;
client->AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
detail::RPCClient* client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
LOG(INFO) << "connect to server" << ep;
client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client->Wait();
client->AsyncSendBatchBarrier(ep);
client->Wait();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册