未验证 提交 d9de6b86 编写于 作者: G gongweibao 提交者: GitHub

Add brpc surpport. (#11263)

上级 bfa3fd6f
...@@ -55,12 +55,13 @@ option(WITH_FLUID_ONLY "Compile PaddlePaddle fluid only" OFF) ...@@ -55,12 +55,13 @@ option(WITH_FLUID_ONLY "Compile PaddlePaddle fluid only" OFF)
option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(GLIDE_INSTALL "Download and install go dependencies " ON) option(GLIDE_INSTALL "Download and install go dependencies " ON)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
option(WITH_DISTRIBUTE "Compile with grpc distributed support" OFF) option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF) option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF) option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF)
option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF) option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" OFF) option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
# CMAKE_BUILD_TYPE # CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
...@@ -147,7 +148,16 @@ include(external/any) # download libn::any ...@@ -147,7 +148,16 @@ include(external/any) # download libn::any
include(external/eigen) # download eigen3 include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11 include(external/pybind11) # download pybind11
include(external/cares) include(external/cares)
include(external/grpc)
if(WITH_DISTRIBUTE)
if(WITH_GRPC)
include(external/grpc)
else()
include(external/leveldb)
include(external/brpc)
endif()
endif()
include(external/snappy) # download snappy include(external/snappy) # download snappy
include(external/snappystream) include(external/snappystream)
include(external/threadpool) include(external/threadpool)
......
...@@ -166,3 +166,7 @@ if(WITH_GOLANG) ...@@ -166,3 +166,7 @@ if(WITH_GOLANG)
endif() endif()
endif(WITH_GOLANG) endif(WITH_GOLANG)
if(WITH_GRPC)
add_definitions(-DPADDLE_WITH_GRPC)
endif(WITH_GRPC)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
SET(BRPC_INCLUDE_DIR "${BRPC_INSTALL_DIR}/include" CACHE PATH "brpc include directory." FORCE)
SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc library." FORCE)
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf")
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add(
extern_brpc
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/brpc/brpc"
GIT_TAG "6d153dd7ff00f960ae6895c9c5fff0ce9f07aff2"
PREFIX ${BRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_INSTALL_PREFIX=${BRPC_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${BRPC_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_PREFIX_PATH=${prefix_path}
-DBRPC_WITH_GLOG=ON
${EXTERNAL_OPTIONAL_ARGS}
LIST_SEPARATOR |
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${BRPC_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${BRPC_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
ADD_DEPENDENCIES(extern_brpc protobuf leveldb gflags glog gtest snappy)
ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
ADD_DEPENDENCIES(brpc extern_brpc)
LIST(APPEND external_project_dependencies brpc)
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(LEVELDB_SOURCES_DIR ${THIRD_PARTY_PATH}/leveldb)
SET(LEVELDB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/leveldb)
SET(LEVELDB_INCLUDE_DIR "${LEVELDB_INSTALL_DIR}/include" CACHE PATH "leveldb include directory." FORCE)
SET(LEVELDB_LIBRARIES "${LEVELDB_INSTALL_DIR}/lib/libleveldb.a" CACHE FILEPATH "leveldb library." FORCE)
INCLUDE_DIRECTORIES(${LEVELDB_INCLUDE_DIR})
ExternalProject_Add(
extern_leveldb
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LEVELDB_SOURCES_DIR}
URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
CONFIGURE_COMMAND ""
BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
&& cp ${LEVELDB_SOURCES_DIR}/src/extern_leveldb/libleveldb.a ${LEVELDB_LIBRARIES}
&& cp -r ${LEVELDB_SOURCES_DIR}/src/extern_leveldb/include ${LEVELDB_INSTALL_DIR}/
BUILD_IN_SOURCE 1
)
ADD_DEPENDENCIES(extern_leveldb snappy)
ADD_LIBRARY(leveldb STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET leveldb PROPERTY IMPORTED_LOCATION ${LEVELDB_LIBRARIES})
ADD_DEPENDENCIES(leveldb extern_leveldb)
LIST(APPEND external_project_dependencies leveldb)
...@@ -610,3 +610,21 @@ function(grpc_library TARGET_NAME) ...@@ -610,3 +610,21 @@ function(grpc_library TARGET_NAME)
COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
cc_library("${TARGET_NAME}" SRCS "${grpc_library_SRCS}" DEPS "${TARGET_NAME}_grpc" "${TARGET_NAME}_proto" "${grpc_library_DEPS}") cc_library("${TARGET_NAME}" SRCS "${grpc_library_SRCS}" DEPS "${TARGET_NAME}_grpc" "${TARGET_NAME}_proto" "${grpc_library_DEPS}")
endfunction() endfunction()
function(brpc_library TARGET_NAME)
set(oneValueArgs PROTO)
set(multiValueArgs SRCS DEPS)
set(options "")
cmake_parse_arguments(brpc_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
message(STATUS "generating brpc ${brpc_library_PROTO}")
get_filename_component(ABS_PROTO ${brpc_library_PROTO} ABSOLUTE)
get_filename_component(PROTO_WE ${brpc_library_PROTO} NAME_WE)
get_filename_component(PROTO_PATH ${ABS_PROTO} PATH)
protobuf_generate_cpp(brpc_proto_srcs brpc_proto_hdrs "${ABS_PROTO}")
cc_library("${TARGET_NAME}_proto" SRCS "${brpc_proto_srcs}")
cc_library("${TARGET_NAME}" SRCS "${brpc_library_SRCS}" DEPS "${TARGET_NAME}_proto" "${brpc_library_DEPS}")
endfunction()
...@@ -64,7 +64,8 @@ class TRTConvertValidation { ...@@ -64,7 +64,8 @@ class TRTConvertValidation {
TRTConvertValidation(int batch_size, TRTConvertValidation(int batch_size,
const std::unordered_set<std::string>& parameters, const std::unordered_set<std::string>& parameters,
framework::Scope& scope, int workspace_size = 1 << 10) framework::Scope& scope, // NOLINT
int workspace_size = 1 << 10)
: parameters_(parameters), scope_(scope) { : parameters_(parameters), scope_(scope) {
// create engine. // create engine.
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_)); engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
......
...@@ -186,8 +186,14 @@ endif() ...@@ -186,8 +186,14 @@ endif()
add_subdirectory(detail) add_subdirectory(detail)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
else()
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib)
endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS}) op_library(prefetch_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(prefetch_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
...@@ -207,7 +213,11 @@ if(WITH_DISTRIBUTE) ...@@ -207,7 +213,11 @@ if(WITH_DISTRIBUTE)
if(WITH_GPU) if(WITH_GPU)
set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op executor SERIAL) cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op executor SERIAL)
op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc) if(WITH_GRPC)
op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc)
else()
op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_brpc)
endif()
set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
......
if(WITH_DISTRIBUTE) if(NOT WITH_DISTRIBUTE)
return()
endif()
if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory) selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc SERIAL) cares zlib protobuf sendrecvop_grpc SERIAL)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
proto_desc lookup_table_op SERIAL) proto_desc lookup_table_op SERIAL)
return()
endif() endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
find_library(OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so)
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY_STATIC})
find_library(OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so)
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY_STATIC})
cc_test(brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc
brpc protobuf leveldb gflags glog
protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
namespace paddle {
namespace operators {
namespace detail {
DEFINE_int32(brpc_channel_num, 24,
"Number of channels to send requests connected to one server");
DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");
BRPCClient::~BRPCClient() { Wait(); }
void HandleSendResponse(brpc::Controller* cntl,
sendrecv::VoidMessage* response) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VoidMessage> response_guard(response);
if (cntl->Failed()) {
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
return;
}
LOG(INFO) << "Received response from " << cntl->remote_side()
<< " latency=" << cntl->latency_us() << "us";
}
bool BRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
framework::AsyncIO(
[var_name_val, p_ctx, ep_val, p_scope, time_out, ch_ptr, this] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
google::protobuf::Closure* done =
brpc::NewCallback(&HandleSendResponse, cntl, response);
sendrecv::VariableMessage request;
ch_ctx->stub->SendVariable(cntl, &request, response, done);
});
req_count_++;
return true;
}
void HandleGetResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
if (cntl->Failed()) {
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
return;
}
LOG(INFO) << "Received response from " << cntl->remote_side()
<< " latency=" << cntl->latency_us() << "us";
// framework::Variable* outvar = nullptr;
// DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
}
bool BRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
framework::AsyncIO(
[var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {});
req_count_++;
return true;
}
bool BRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] {});
req_count_++;
return true;
}
void BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
req_count_++;
}
void BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
req_count_++;
}
void BRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
}
ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
{
std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
return it->second;
}
}
ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connection_type = "pooled";
options.connect_timeout_ms = 100;
options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
options.max_retry = FLAGS_max_retry;
for (int i = 0; i < FLAGS_brpc_channel_num; ++i) {
std::shared_ptr<ChannelContext> c(new ChannelContext());
if (c->channel.Init(ep.c_str(), &options) != 0) {
LOG(ERROR) << "Fail to initialize channel";
return nullptr;
}
c->stub.reset(new sendrecv::SendRecvService_Stub(
static_cast<google::protobuf::RpcChannel*>(&c->channel)));
q->Push(c);
}
{
std::lock_guard<std::mutex> guard(chan_mutex_);
channels_[ep] = q;
}
return q;
}
} // namespace detail
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <time.h>
#include <chrono> // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace operators {
namespace detail {
struct ChannelContext {
brpc::Channel channel;
std::shared_ptr<sendrecv::SendRecvService_Stub> stub;
};
typedef std::shared_ptr<ChannelContext> ChannelContextPtr;
typedef std::shared_ptr<framework::BlockingQueue<ChannelContextPtr>>
ChannelQueuePtr;
class BRPCClient : public RPCClient {
public:
BRPCClient() {}
virtual ~BRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendBatchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendFetchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void Wait() override;
private:
void Proceed();
ChannelQueuePtr GetChannel(const std::string& ep);
private:
std::unordered_map<std::string, ChannelQueuePtr> channels_;
// mutex for Wait client sync
std::mutex sync_mutex_;
std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0};
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
DISABLE_COPY_AND_ASSIGN(BRPCClient);
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/brpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h"
namespace sendrecv {
typedef std::unordered_map<std::string,
paddle::operators::detail::RequestHandler*>
HandlerMap;
class BRPCServiceImpl : public SendRecvService {
public:
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map)
: request_send_h_(nullptr),
request_get_h_(nullptr),
request_prefetch_h_(nullptr) {
auto it = rpc_call_map.find(paddle::operators::detail::kRequestSend);
if (it != rpc_call_map.end()) {
request_send_h_ = it->second;
}
it = rpc_call_map.find(paddle::operators::detail::kRequestSend);
if (it != rpc_call_map.end()) {
request_get_h_ = it->second;
}
it = rpc_call_map.find(paddle::operators::detail::kRequestPrefetch);
if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second;
}
}
virtual ~BRPCServiceImpl() {}
void SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(request_send_h_ != nullptr,
"RequestSend handler should be registed first!");
brpc::ClosureGuard done_guard(done);
paddle::framework::Scope* local_scope = request_send_h_->scope();
paddle::framework::Variable* outvar = nullptr;
paddle::framework::Variable* invar = nullptr;
std::string varname = request->varname();
if (!request_send_h_->sync_mode()) {
local_scope = &request_send_h_->scope()->NewScope();
invar = local_scope->Var(varname);
} else {
invar = local_scope->FindVar(varname);
}
request_send_h_->Handle(varname, local_scope, invar, &outvar);
if (!request_send_h_->sync_mode()) {
request_send_h_->scope()->DeleteScope(local_scope);
}
}
void GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(request_get_h_ != nullptr,
"RequestGet handler should be registed first!");
}
void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(request_prefetch_h_ != nullptr,
"kRequestPrefetch handler should be registed first!");
}
private:
paddle::operators::detail::RequestHandler* request_send_h_;
paddle::operators::detail::RequestHandler* request_get_h_;
paddle::operators::detail::RequestHandler* request_prefetch_h_;
};
} // namespace sendrecv
namespace paddle {
namespace operators {
namespace detail {
void AsyncBRPCServer::StartServer() {
// Instance of your service.
sendrecv::BRPCServiceImpl service_impl(rpc_call_map_);
// Add the service into server. Notice the second parameter, because the
// service is put on stack, we don't want server to delete it, otherwise
// use brpc::SERVER_OWNS_SERVICE.
if (server_.AddService(&service_impl, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
LOG(FATAL) << "Fail to add service";
return;
}
brpc::ServerOptions options;
options.idle_timeout_sec = idle_timeout_s_;
options.max_concurrency = max_concurrency_;
if (server_.Start(bind_address_.c_str(), &options) != 0) {
LOG(FATAL) << "Fail to start EchoServer" << bind_address_;
return;
}
butil::EndPoint ep = server_.listen_address();
selected_port_ = ep.port;
{
std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1;
}
condition_ready_.notify_all();
server_.Join();
}
void AsyncBRPCServer::ShutDownImpl() { server_.Stop(1000); }
void AsyncBRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
}
}; // namespace detail
}; // namespace operators
}; // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <string>
#include "brpc/server.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace detail {
class AsyncBRPCServer final : public RPCServer {
public:
explicit AsyncBRPCServer(const std::string& address, int client_num)
: RPCServer(address, client_num), ready_(0) {}
virtual ~AsyncBRPCServer() {}
void StartServer() override;
void WaitServerReady() override;
private:
void ShutDownImpl() override;
brpc::Server server_;
static constexpr int idle_timeout_s_ = -1;
static constexpr int max_concurrency_ = 0;
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
};
}; // namespace detail
}; // namespace operators
}; // namespace paddle
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#define RPCSERVER_T detail::AsyncGRPCServer
#define RPCCLIENT_T detail::GRPCClient
#else
#include "paddle/fluid/operators/detail/brpc_client.h"
#include "paddle/fluid/operators/detail/brpc_server.h"
#define RPCSERVER_T detail::AsyncBRPCServer
#define RPCCLIENT_T detail::BRPCClient
#endif
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -38,6 +37,10 @@ constexpr char kRequestSend[] = "RequestSend"; ...@@ -38,6 +37,10 @@ constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch"; constexpr char kRequestPrefetch[] = "RequestPrefetch";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
class RPCServer; class RPCServer;
class RequestHandler { class RequestHandler {
......
...@@ -16,15 +16,12 @@ ...@@ -16,15 +16,12 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h" #include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h" #include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -26,6 +26,8 @@ namespace detail { ...@@ -26,6 +26,8 @@ namespace detail {
class RPCClient { class RPCClient {
public: public:
RPCClient() {}
virtual ~RPCClient() {}
virtual bool AsyncSendVar(const std::string& ep, virtual bool AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
......
...@@ -17,15 +17,14 @@ limitations under the License. */ ...@@ -17,15 +17,14 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
...@@ -33,7 +32,7 @@ namespace detail = paddle::operators::detail; ...@@ -33,7 +32,7 @@ namespace detail = paddle::operators::detail;
USE_OP(lookup_table); USE_OP(lookup_table);
std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service; std::unique_ptr<detail::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler; std::unique_ptr<detail::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
...@@ -112,20 +111,19 @@ void StartServer() { ...@@ -112,20 +111,19 @@ void StartServer() {
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); std::bind(&detail::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join(); server_thread.join();
} }
TEST(PREFETCH, CPU) { TEST(PREFETCH, CPU) {
g_req_handler.reset(new detail::RequestPrefetchHandler(true)); g_req_handler.reset(new detail::RequestPrefetchHandler(true));
g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>();
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
detail::RPCClient* client =
detail::RPCClient::GetInstance<detail::GRPCClient>();
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
......
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
syntax = "proto3"; syntax = "proto3";
package sendrecv; package sendrecv;
// option cc_generic_services = true;
service SendRecvService { service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors. // For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor // Send and recv only one tensor
......
...@@ -32,16 +32,6 @@ namespace paddle { ...@@ -32,16 +32,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
typedef void (*DestroyCallback)(void*); typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
......
...@@ -19,9 +19,7 @@ limitations under the License. */ ...@@ -19,9 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -45,7 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -45,7 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase {
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
rpc_client->Wait(); rpc_client->Wait();
......
...@@ -21,8 +21,7 @@ limitations under the License. */ ...@@ -21,8 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
...@@ -61,8 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -61,8 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::vector<std::string> endpoint_list = std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("endpoint_list"); Attr<std::vector<std::string>>("endpoint_list");
detail::RPCClient* client = detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>();
detail::RPCClient::GetInstance<detail::GRPCClient>();
for (auto& ep : endpoint_list) { for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep; VLOG(3) << "sending nccl id to " << ep;
client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME); client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
...@@ -78,9 +77,11 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -78,9 +77,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
// deleter will call GRPC Server's base class's dtor and // deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash. // that will cause a wired crash.
detail::RequestSendHandler rpc_h(true); detail::RequestSendHandler rpc_h(true);
detail::AsyncGRPCServer rpc_service(endpoint, 1); std::unique_ptr<detail::RPCServer> rpc_service(
rpc_service.RegisterRPC(detail::kRequestSend, &rpc_h); new RPCSERVER_T(endpoint, 1));
rpc_h.SetRPCServer(&rpc_service);
rpc_service->RegisterRPC(detail::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program; framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace()); framework::Executor executor(dev_ctx.GetPlace());
...@@ -90,12 +91,13 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -90,12 +91,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
rpc_h.SetExecutor(&executor); rpc_h.SetExecutor(&executor);
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, &rpc_service)); std::bind(&detail::RPCServer::StartServer, rpc_service.get()));
rpc_service.SetCond(detail::kRequestSend);
rpc_service->SetCond(detail::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0..."; VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service.WaitBarrier(detail::kRequestSend); rpc_service->WaitBarrier(detail::kRequestSend);
VLOG(3) << "got nccl id and stop server..."; VLOG(3) << "got nccl id and stop server...";
rpc_service.ShutDown(); rpc_service->ShutDown();
VLOG(3) << "rpc server stopped"; VLOG(3) << "rpc server stopped";
server_thread.join(); server_thread.join();
} }
......
...@@ -19,7 +19,8 @@ limitations under the License. */ ...@@ -19,7 +19,8 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -89,6 +90,12 @@ void ListenAndServOp::SavePort() const { ...@@ -89,6 +90,12 @@ void ListenAndServOp::SavePort() const {
rpc_service_->SavePort(); rpc_service_->SavePort();
} }
static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
void ListenAndServOp::RunSyncLoop(framework::Executor *executor, void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
framework::Scope *recv_scope, framework::Scope *recv_scope,
...@@ -127,7 +134,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -127,7 +134,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
int32_t last_parent_blkid = program->Block(1).Parent(); int32_t last_parent_blkid = program->Block(1).Parent();
std::vector<size_t> parallel_blkids; std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(1); parallel_blkids.push_back(1);
double ts = detail::GetTimestamp(); double ts = GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) { for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) { if (blkid != static_cast<size_t>(prefetch_block->ID())) {
if (program->Block(blkid).Parent() != last_parent_blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) {
...@@ -141,7 +148,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -141,7 +148,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
} }
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
rpc_service_->SetCond(detail::kRequestGet); rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->WaitBarrier(detail::kRequestGet);
...@@ -235,8 +242,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -235,8 +242,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint; << ", end_point:" << endpoint;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, fan_in));
request_send_handler_.reset(new detail::RequestSendHandler(sync_mode)); request_send_handler_.reset(new detail::RequestSendHandler(sync_mode));
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode)); request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset( request_prefetch_handler_.reset(
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/send_recv_util.h"
namespace paddle { namespace paddle {
...@@ -42,7 +42,7 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -42,7 +42,7 @@ class PrefetchOp : public framework::OperatorBase {
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
......
...@@ -19,8 +19,7 @@ limitations under the License. */ ...@@ -19,8 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -45,7 +44,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -45,7 +44,7 @@ class RecvOp : public framework::OperatorBase {
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
......
...@@ -19,8 +19,8 @@ limitations under the License. */ ...@@ -19,8 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -45,7 +45,7 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -45,7 +45,7 @@ class SendBarrierOp : public framework::OperatorBase {
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode; VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -46,7 +46,7 @@ class SendOp : public framework::OperatorBase { ...@@ -46,7 +46,7 @@ class SendOp : public framework::OperatorBase {
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
detail::RPCClient* rpc_client = detail::RPCClient* rpc_client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
......
...@@ -20,8 +20,7 @@ limitations under the License. */ ...@@ -20,8 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h" #include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -29,6 +28,10 @@ limitations under the License. */ ...@@ -29,6 +28,10 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/send_recv_util.h"
#endif
USE_NO_KERNEL_OP(listen_and_serv); USE_NO_KERNEL_OP(listen_and_serv);
namespace f = paddle::framework; namespace f = paddle::framework;
...@@ -37,7 +40,7 @@ namespace m = paddle::operators::math; ...@@ -37,7 +40,7 @@ namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail; namespace detail = paddle::operators::detail;
namespace string = paddle::string; namespace string = paddle::string;
std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service; std::unique_ptr<detail::RPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler; std::unique_ptr<detail::RequestHandler> g_req_handler;
void StartServer() { void StartServer() {
...@@ -58,7 +61,7 @@ void StartServer() { ...@@ -58,7 +61,7 @@ void StartServer() {
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get())); std::bind(&detail::RPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(detail::kRequestSend); g_rpc_service->SetCond(detail::kRequestSend);
g_rpc_service->WaitBarrier(detail::kRequestSend); g_rpc_service->WaitBarrier(detail::kRequestSend);
...@@ -68,9 +71,9 @@ void StartServer() { ...@@ -68,9 +71,9 @@ void StartServer() {
server_thread.join(); server_thread.join();
} }
TEST(SendNcclId, GrpcServer) { TEST(SendNcclId, RPCServer) {
g_req_handler.reset(new detail::RequestSendHandler(true)); g_req_handler.reset(new detail::RequestSendHandler(true));
g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
...@@ -87,8 +90,9 @@ TEST(SendNcclId, GrpcServer) { ...@@ -87,8 +90,9 @@ TEST(SendNcclId, GrpcServer) {
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port); std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient* client =
detail::RPCClient::GetInstance<detail::GRPCClient>(); detail::RPCClient* client = detail::RPCClient::GetInstance<RPCCLIENT_T>();
LOG(INFO) << "connect to server" << ep; LOG(INFO) << "connect to server" << ep;
client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME); client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client->Wait(); client->Wait();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册