未验证 提交 0b1c7d83 编写于 作者: G gongweibao 提交者: GitHub

Add brpc serialization support. (#11430)

上级 37c2e245
......@@ -81,9 +81,11 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")
config = distribute_transpiler.DistributeTranspilerConfig()
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = not args.no_split_var
config.min_block_size = 1048576
t = distribute_transpiler.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
# NOTE: *MUST* use train_prog, for we are using with guard to
......
......@@ -14,14 +14,16 @@
INCLUDE(ExternalProject)
find_library(SSL_LIBRARY NAMES ssl)
find_package(OpenSSL REQUIRED)
message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
message(STATUS "crypto:" ${OPENSSL_CRYPTO_LIBRARY})
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${SSL_LIBRARY})
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY})
find_library(CRYPTO_LIBRARY NAMES crypto)
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${CRYPTO_LIBRARY})
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY})
SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
......@@ -31,14 +33,15 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib")
set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add(
extern_brpc
${EXTERNAL_PROJECT_LOG_ARGS}
# TODO(gongwb): change to de newst repo when they changed.
GIT_REPOSITORY "https://github.com/gongweibao/brpc"
GIT_TAG "7dc04defad1fd4173aae170c3fcbde131b65155a"
GIT_TAG "e9b67ec1b7458f2af5fae76451afe1e27e01b4b4"
PREFIX ${BRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......@@ -50,7 +53,7 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_PREFIX_PATH=${prefix_path}
-DBRPC_WITH_GLOG=ON
-DWITH_GLOG=ON
-DIOBUF_WITH_HUGE_BLOCK=ON
-DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
${EXTERNAL_OPTIONAL_ARGS}
......@@ -65,5 +68,6 @@ ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
ADD_DEPENDENCIES(brpc extern_brpc)
add_definitions(-DBRPC_WITH_GLOG)
LIST(APPEND external_project_dependencies brpc)
......@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
IF(WITH_TESTING)
ENABLE_TESTING()
#FIXME:(gongwb) Move brpc's gtest dependency.
IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
IF(WITH_TESTING)
ENABLE_TESTING()
ENDIF(WITH_TESTING)
INCLUDE(ExternalProject)
SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest)
......@@ -76,4 +80,4 @@ IF(WITH_TESTING)
ADD_DEPENDENCIES(gtest_main extern_gtest)
LIST(APPEND external_project_dependencies gtest gtest_main)
ENDIF(WITH_TESTING)
ENDIF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
......@@ -24,8 +24,8 @@ ExternalProject_Add(
extern_leveldb
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LEVELDB_SOURCES_DIR}
URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
GIT_REPOSITORY "https://github.com/google/leveldb"
GIT_TAG v1.18
CONFIGURE_COMMAND ""
BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
......
......@@ -169,9 +169,12 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
if(WITH_NGRAPH)
if(NOT WIN32)
......
......@@ -12,12 +12,19 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_DISTRIBUTE)
if(NOT WITH_GRPC)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endif()
endif()
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
if(WITH_DISTRIBUTE)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor sendrecvop_grpc)
ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
else()
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor)
......@@ -30,7 +37,7 @@ else()
variable_visitor)
if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_grpc)
ddim selected_rows_functor sendrecvop_rpc)
else()
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor)
......
......@@ -157,9 +157,9 @@ void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
// TODO(typhoonzero): complete message will need to use real trainer_id,
// except 0.
::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>(0)
->SendComplete();
auto client =
paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
client->SendComplete();
#endif
}
......
......@@ -12,7 +12,7 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_library(sendrecvop_rpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory)
......@@ -20,36 +20,43 @@ if(WITH_GRPC)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_rpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
selected_rows_functor scope math_function SERIAL)
endif()
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
else()
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_server.cc parameter_prefetch.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc collective_server.cc collective_server_test.cc
collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
brpc_library(sendrecvop_rpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc collective_client.cc collective_server.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
set(brpc_test_depends sendrecvop_rpc brpc ssl crypto protobuf leveldb gflags glog executor
proto_desc lookup_sparse_table_op snappystream snappy zlib)
cc_test(brpc_server_test SRCS rpc_server_test.cc
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${brpc_test_depends} SERIAL)
cc_test(brpc_serde_test SRCS brpc_serde_test.cc
DEPS ${brpc_test_depends} SERIAL)
if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS ${brpc_test_depends} selected_rows_functor scope math_function SERIAL)
endif()
endif()
......@@ -14,135 +14,316 @@
#include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
DEFINE_int32(brpc_channel_num, 24,
"Number of channels to send requests connected to one server");
DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");
BRPCClient::~BRPCClient() { Wait(); }
void HandleSendResponse(brpc::Controller* cntl,
sendrecv::VoidMessage* response) {
void HandleSendResponse(brpc::Controller* cntl, sendrecv::VoidMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VoidMessage> response_guard(response);
// this channel can be used by other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
LOG(FATAL) << "Fail to send SendVar: " << var_h->name()
<< ", error text: " << cntl->ErrorText();
var_h->Finish(false);
cls->DecreaseReqCount();
return;
}
LOG(INFO) << "Received response from " << cntl->remote_side()
<< " latency=" << cntl->latency_us() << "us";
var_h->Finish(true);
cls->DecreaseReqCount();
VLOG(4) << "HandleSendResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
VLOG(4) << "Finish HandleSendResponse";
}
bool BRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
const std::string method = "SendRPC";
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
framework::AsyncIO(
[var_name_val, p_ctx, ep_val, p_scope, time_out, ch_ptr, this] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
auto* var = p_scope->FindVar(var_name_val);
sendrecv::VariableMessage request;
distributed::SerializeToIOBuf(var_name_val, var, *p_ctx, &request,
&cntl->request_attachment(), "", false,
trainer_id_);
google::protobuf::Closure* done =
brpc::NewCallback(&HandleSendResponse, cntl, response);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
sendrecv::VariableMessage request;
ch_ctx->stub->SendVariable(cntl, &request, response, done);
});
platform::RecordRPCEvent record_event(method, p_ctx);
ch_ctx->stub->SendVariable(cntl, &request, response, done);
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
req_count_++;
return true;
return var_h;
}
void HandleFetchBarrierResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
// this channel can be used other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
LOG(FATAL) << "Fail to get HandleFetchBarrierResponse: " << var_h->name()
<< ", error text: " << cntl->ErrorText();
var_h->Finish(false);
cls->DecreaseReqCount();
return;
}
var_h->Finish(true);
cls->DecreaseReqCount();
VLOG(4) << "HandleFetchBarrierResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
VLOG(4) << "Finish HandleFetchBarrierResponse";
}
void HandleGetResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response) {
sendrecv::VariableMessage* response, VarHandlePtr var_h,
ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx,
BRPCClient* cls) {
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std::unique_ptr<brpc::Controller> cntl_guard(cntl);
std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
// this channel can be used other now.
ch_ptr->Push(ch_ctx);
if (cntl->Failed()) {
LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
LOG(FATAL) << "Fail to GetVar: " << var_h->name()
<< ", error text: " << cntl->ErrorText();
cls->DecreaseReqCount();
var_h->Finish(false);
return;
}
LOG(INFO) << "Received response from " << cntl->remote_side()
<< " latency=" << cntl->latency_us() << "us";
// framework::Variable* outvar = nullptr;
// DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
VLOG(4) << "HandleGetResponse from: " << cntl->remote_side()
<< ", varname: " << var_h->name()
<< ", latency: " << cntl->latency_us() << "us";
framework::Variable* outvar = nullptr;
int trainer_id;
distributed::DeserializeFromIOBuf(*response, cntl->response_attachment(),
*var_h->ctx(), var_h->scope(), &outvar,
&trainer_id);
VLOG(4) << "Finish HandleGetResponse";
cls->DecreaseReqCount();
var_h->Finish(true);
}
bool BRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, int64_t time_out) {
VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& method_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const auto ch_ptr = GetChannel(ep_val);
const std::string method = "GetRPC";
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
framework::AsyncIO(
[var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {});
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_trainer_id(trainer_id_);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
platform::RecordRPCEvent record_event(method, p_ctx);
if (method_name == "GetMonomerVariable") {
ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
} else {
ch_ctx->stub->GetVariable(cntl, &req, response, done);
}
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
req_count_++;
return true;
return var_h;
}
VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, "GetMonomerVariable", time_out);
}
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const std::string& var_name,
int64_t time_out) {
return AsyncSendMessage(ep, "GetMonomerBarrier", var_name, time_out);
}
bool BRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out) {
VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, "GetVariable", time_out);
}
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const auto ch_ptr = GetChannel(ep_val);
const std::string method = "PrefetchRPC";
VarHandlePtr var_h(
new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
auto* var = p_scope->FindVar(in_var_name_val);
sendrecv::VariableMessage req;
distributed::SerializeToIOBuf(in_var_name_val, var, *p_ctx, &req,
&cntl->request_attachment(), out_var_name_val,
false, 0, table_name_val);
platform::RecordRPCEvent record_event(method, p_ctx);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] {});
ch_ctx->stub->PrefetchVariable(cntl, &req, response, done);
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
});
req_count_++;
return true;
return var_h;
}
void BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
req_count_++;
VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
return AsyncSendMessage(ep, "BatchBarrierRPC", BATCH_BARRIER_MESSAGE,
time_out);
}
void BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
auto ch_ptr = GetChannel(ep);
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
cntl->set_timeout_ms(time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
const std::string method = "FetchBarrierRPC";
// var handle
VarHandlePtr var_h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
platform::RecordRPCEvent record_event(method, nullptr);
google::protobuf::Closure* done = brpc::NewCallback(
&HandleFetchBarrierResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
ch_ctx->stub->GetVariable(cntl, &req, response, done);
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
return var_h;
}
void BRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
bool BRPCClient::Wait() {
VLOG(9) << "begin to brpcclient wait";
{
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
}
VLOG(9) << "end to brpcclient wait";
return true;
}
ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
VLOG(4) << "begin to GetChannel:" << ep;
{
std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
VLOG(4) << "end to GetChannel:" << ep;
return it->second;
}
}
......@@ -150,12 +331,20 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());
brpc::ChannelOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
options.use_rdma = true;
#endif
options.protocol = "baidu_std";
options.connection_type = "pooled";
options.connect_timeout_ms = 100;
// don't use pooled type. the server can't afford that.
options.connection_type = "single";
options.connect_timeout_ms = 1000;
options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
options.max_retry = FLAGS_max_retry;
for (int i = 0; i < FLAGS_brpc_channel_num; ++i) {
VLOG(1) << "create " << brpc_channel_num_per_server_
<< " brpc channels to pserver:" << ep;
for (int i = 0; i < brpc_channel_num_per_server_; ++i) {
std::shared_ptr<ChannelContext> c(new ChannelContext());
if (c->channel.Init(ep.c_str(), &options) != 0) {
LOG(FATAL) << "Fail to initialize channel";
......@@ -172,9 +361,75 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
channels_[ep] = q;
}
VLOG(4) << "end to GetChannel:" << ep;
return q;
}
VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
return AsyncSendMessage(ep, "SendCompleteRPC", COMPLETE_MESSAGE, time_out);
}
void BRPCClient::SendComplete() {
for (auto& kv : channels_) {
AsyncSendComplete(kv.first);
}
}
VarHandlePtr BRPCClient::AsyncSendVarMessage(
const std::string& ep, const std::string& method_name,
const sendrecv::VariableMessage& req, int64_t time_out) {
auto ch_ptr = GetChannel(ep);
auto ch_ctx = ch_ptr->Pop();
brpc::Controller* cntl = new brpc::Controller();
sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
cntl->set_timeout_ms(time_out);
platform::RecordRPCEvent record_event(method_name, nullptr);
VarHandlePtr var_h(
new VarHandle(ep, method_name, req.varname(), nullptr, nullptr));
google::protobuf::Closure* done = brpc::NewCallback(
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
if (method_name == "CheckPointNotifyRPC") {
ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
} else if (method_name == "GetMonomerBarrier") {
ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
} else {
ch_ctx->stub->SendVariable(cntl, &req, response, done);
}
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
var_h->Wait();
}
return var_h;
}
VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
const std::string& method_name,
const std::string& message,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(message);
return AsyncSendVarMessage(ep, method_name, req, time_out);
}
VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
}
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -31,6 +31,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
......@@ -53,33 +55,94 @@ class BRPCClient : public RPCClient {
BRPCClient() {}
virtual ~BRPCClient();
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerBarrier(
const std::string& ep, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
void Wait() override;
VarHandlePtr AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
void SendComplete() override;
private:
VarHandlePtr _AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& method_name,
int64_t time_out = FLAGS_rpc_deadline);
void Proceed();
ChannelQueuePtr GetChannel(const std::string& ep);
VarHandlePtr AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline);
VarHandlePtr AsyncSendMessage(const std::string& ep,
const std::string& method_name,
const std::string& message, int64_t time_out);
VarHandlePtr AsyncSendVarMessage(const std::string& ep,
const std::string& method_name,
const sendrecv::VariableMessage& req,
int64_t time_out);
friend void HandleSendResponse(brpc::Controller* cntl,
sendrecv::VoidMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls);
friend void HandleGetResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx, BRPCClient* cls);
friend void HandleFetchBarrierResponse(brpc::Controller* cntl,
sendrecv::VariableMessage* response,
VarHandlePtr var_h,
ChannelQueuePtr ch_ptr,
ChannelContextPtr ch_ctx,
BRPCClient* cls);
void DecreaseReqCount() {
if (--req_count_ <= 0) {
sync_cond_.notify_all();
}
}
private:
std::unordered_map<std::string, ChannelQueuePtr> channels_;
......@@ -88,6 +151,8 @@ class BRPCClient : public RPCClient {
std::condition_variable sync_cond_;
std::atomic<int64_t> req_count_{0};
static constexpr int brpc_channel_num_per_server_ = 4;
// mutex for GetChannel thread safety
std::mutex chan_mutex_;
DISABLE_COPY_AND_ASSIGN(BRPCClient);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BRPC_RDMA
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "brpc/channel.h"
#include "brpc/rdma/rdma_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
RdmaMemPool& RdmaMemPool::Instance() {
static RdmaMemPool* g_rdma_mem_pool = new RdmaMemPool();
return *g_rdma_mem_pool;
}
void* RdmaMemPool::Find(const std::string& varname, int64_t size) {
pthread_rwlock_rdlock(&access_);
auto it = pool_.find(varname);
if (it == pool_.end()) {
pthread_rwlock_unlock(&access_);
return nullptr;
}
auto info = it->second;
if (info.data_size != size) {
pthread_rwlock_unlock(&access_);
PADDLE_ENFORCE(false, "var:%s size:%ld != %ld", varname, size,
info.data_size);
return nullptr;
}
pthread_rwlock_unlock(&access_);
return info.data;
}
void RdmaMemPool::Register(const std::string& varname, void* data,
int64_t data_size) {
void* old = Find(varname, data_size);
if (old != nullptr) {
if (data != old) {
PADDLE_ENFORCE(false, "var:%s data:%ld != %ld", varname, data, old);
}
VLOG(7) << "Find on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
return;
}
VarInfo info;
info.data = data;
info.data_size = data_size;
pthread_rwlock_wrlock(&access_);
pool_[varname] = info;
pthread_rwlock_unlock(&access_);
if (brpc::rdma::RegisterMemoryForRdma(data, data_size)) {
LOG(FATAL) << "register " << varname << " data:" << data
<< " data_size:" << data_size << " error";
}
VLOG(4) << "register on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_BRPC_RDMA
#include <pthread.h> // NOLINT
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
namespace distributed {
/*
* This class is used to avoid duplicated registion of brpc::rdma.
*/
class RdmaMemPool {
public:
static RdmaMemPool& Instance();
RdmaMemPool() : access_(PTHREAD_RWLOCK_INITIALIZER) {}
virtual ~RdmaMemPool() { pthread_rwlock_destroy(&access_); }
void Register(const std::string& varname, void* data, int64_t size);
void* Find(const std::string& varname, int64_t size);
private:
struct VarInfo {
void* data;
int64_t data_size;
VarInfo() : data(nullptr), data_size(0) {}
};
private:
std::unordered_map<std::string, VarInfo> pool_;
pthread_rwlock_t access_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
class IOBufWriter {
public:
static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
iobuf->append(v, vlen);
}
static void AppendTCPZeroCopy(butil::IOBuf* iobuf, int k, const char* v,
int64_t vlen, bool in_cuda_pinned,
void (*destroy)(void*), void* user_data) {
VLOG(7) << "AppendTCPZeroCopy "
<< " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
// FIXME(gongwb): use append_zerocopy
/*
if (in_cuda_pinned) {
iobuf->append_zerocopy(v, vlen, IOBufWriter::FreeMemory);
} else {
iobuf->append_zerocopy(v, vlen, nullptr);
}
*/
iobuf->append(v, vlen);
destroy(user_data);
}
#ifdef PADDLE_WITH_BRPC_RDMA
static void AppendRdmaZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
VLOG(7) << "AppendRdmaZeroCopy varname:" << varname << " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
RdmaMemPool::Instance().Register(
varname, static_cast<void*>(const_cast<char*>(v)), vlen);
// FIXME(gongwb): use append_zerocopy
// iobuf->append_zerocopy(v, vlen, nullptr);
iobuf->append(v, vlen);
destroy(user_data);
return;
}
#endif
static void AppendZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
#ifdef PADDLE_WITH_BRPC_RDMA
IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
destroy, user_data);
#else
IOBufWriter::AppendTCPZeroCopy(iobuf, k, v, vlen, in_cuda_pinned, destroy,
user_data);
#endif
}
};
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, int trainer_id,
const std::string& table_name) {
std::unique_ptr<TensorPayload> payload;
request->set_varname(name);
request->set_trainer_id(trainer_id);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request->set_profile(platform::kEnableProfiler);
} else {
request->set_profile(platform::kDisableProfiler);
}
}
if (!out_varname.empty()) {
request->set_out_varname(out_varname);
}
if (!table_name.empty()) {
request->set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) {
request->set_type(::sendrecv::LOD_TENSOR);
payload.reset(new TensorPayload(GetTensorPayload(var, ctx, request)));
} else if (var->IsType<framework::SelectedRows>()) {
request->set_type(::sendrecv::SELECTED_ROWS);
payload.reset(new TensorPayload(GetSelectedRowsPayload(var, ctx, request)));
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request->set_type(::sendrecv::NCCL_ID);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
// TODO(gongwb): use append_zero to avoid data copy.
IOBufWriter::Append(iobuf,
sendrecv::VariableMessage::kSerializedFieldNumber,
uid.internal, NCCL_UNIQUE_ID_BYTES);
return;
#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
PADDLE_ENFORCE_NOT_NULL(payload);
// FIXME(gongwb): it seems that can use zero copy.
if (var_is_not_stable) {
IOBufWriter::Append(
iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size());
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
true, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
#endif
} else {
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
false, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
}
}
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
reinterpret_cast<const char*>(slr->rows().data()),
static_cast<int64_t>(rows_memory_size));
}
}
void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta,
const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id) {
operators::distributed::BRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(iobuf, meta) == 0, "parse iobuf to tensor error!");
*var = resp.GetVar();
*trainer_id = resp.GetTrainerId();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace distributed {
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromIOBuf(const VarMsg& meta, const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "brpc/channel.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 564 * 128;
// serialize var to IOBuf
{
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 32.7);
for (int i = 0; i < 564; ++i) rows->push_back(i);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// desrialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto* slr2 = var2->GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2->data<float>());
}
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
}
EXPECT_EQ(slr2->height(), 1000);
}
}
void RunTestLodTensor(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 512 * 8 * 4 * 2;
{
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// check sendrecv::VariableMessage meta data
{
EXPECT_EQ(msg.varname(), "myvar");
EXPECT_EQ(msg.type(), 0);
EXPECT_EQ(msg.dims()[0], 512);
EXPECT_EQ(msg.dims()[1], 8);
EXPECT_EQ(msg.dims()[2], 4);
EXPECT_EQ(msg.dims()[3], 2);
EXPECT_EQ(msg.lod_level(), 1);
EXPECT_EQ(msg.lod(0).lod_data(0), 1);
EXPECT_EQ(msg.lod(0).lod_data(1), 3);
EXPECT_EQ(msg.lod(0).lod_data(2), 8);
}
// deserialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto tensor2 = var2->Get<framework::LoDTensor>();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2.data<float>());
}
for (int i = 0; i < tensor_numel; ++i)
EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
}
}
TEST(LodTensor, Run) {
platform::CPUPlace place;
RunTestLodTensor(place);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu(0);
RunTestLodTensor(gpu);
#endif
}
TEST(SelectedRows, Run) {
platform::CPUPlace place;
RunSerdeTestSelectedRows(place);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu;
RunSerdeTestSelectedRows(gpu);
#endif
}
......@@ -13,84 +13,287 @@
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv {
typedef std::unordered_map<std::string,
paddle::operators::distributed::RequestHandler*>
namespace distributed = paddle::operators::distributed;
typedef std::unordered_map<std::string, distributed::RequestHandler*>
HandlerMap;
class BRPCServiceImpl : public SendRecvService {
public:
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map)
: request_send_h_(nullptr),
request_get_h_(nullptr),
request_prefetch_h_(nullptr) {
auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map,
distributed::RPCServer* rpc_server)
: rpc_server_(rpc_server) {
VLOG(3) << "BRPCServiceImpl size: " << rpc_call_map.size();
auto it = rpc_call_map.find(distributed::kRequestSend);
if (it != rpc_call_map.end()) {
request_send_h_ = it->second;
send_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestSend)));
}
it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
it = rpc_call_map.find(distributed::kRequestGet);
if (it != rpc_call_map.end()) {
request_get_h_ = it->second;
get_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestGet)));
}
it = rpc_call_map.find(paddle::operators::distributed::kRequestPrefetch);
it = rpc_call_map.find(distributed::kRequestPrefetch);
if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second;
prefetch_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
}
it = rpc_call_map.find(distributed::kRequestCheckpoint);
if (it != rpc_call_map.end()) {
request_checkpoint_h_ = it->second;
checkpoint_notify_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
}
it = rpc_call_map.find(distributed::kRequestGetMonomerVariable);
if (it != rpc_call_map.end()) {
request_get_monomer_handler_h_ = it->second;
}
it = rpc_call_map.find(distributed::kRequestGetMonomerBarrier);
if (it != rpc_call_map.end()) {
request_get_monomer_barrier_handler_h_ = it->second;
}
}
virtual ~BRPCServiceImpl() {}
void SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
send_threads_->Run(
[=] { _SendVariable(cntl_butil, request, response, done); });
}
void _SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_send_h_ != nullptr,
"RequestSend handler should be registed first!");
brpc::ClosureGuard done_guard(done);
paddle::framework::Scope* local_scope = request_send_h_->scope();
paddle::framework::Variable* outvar = nullptr;
paddle::framework::Variable* invar = nullptr;
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestSend var_name:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
if (!request_send_h_->sync_mode()) {
local_scope = &request_send_h_->scope()->NewScope();
invar = local_scope->Var(varname);
} else {
invar = local_scope->FindVar(varname);
}
distributed::BRPCVariableResponse resp(request_send_h_->scope(),
request_send_h_->dev_ctx(),
!request_send_h_->sync_mode());
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
"parse iobuf to tensor error!");
request_send_h_->Handle(varname, local_scope, invar, &outvar);
auto scope = resp.GetMutableLocalScope();
auto invar = resp.GetVar();
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
if (!request_send_h_->sync_mode()) {
request_send_h_->scope()->DeleteScope(local_scope);
}
request_send_h_->Handle(varname, scope, invar, &outvar, trainer_id);
}
void GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) override {
get_threads_->Run(
[=] { _GetVariable(cntl_butil, request, response, done); });
}
void _GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_get_h_ != nullptr,
"RequestGet handler should be registed first!");
}
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestGet varname:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
auto scope = request_get_h_->scope();
auto invar = scope->FindVar(varname);
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id);
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *request_get_h_->dev_ctx(),
response, &cntl->response_attachment(), "",
false);
}
}
void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
prefetch_threads_->Run(
[=] { _PrefetchVariable(cntl_butil, request, response, done); });
}
void _PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_prefetch_h_ != nullptr,
"kRequestPrefetch handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// prefetch process...
std::string in_var_name = request->varname();
std::string out_var_name = request->out_varname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
<< ", out_var_name: " << out_var_name
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
distributed::BRPCVariableResponse resp(
request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true);
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
"parse iobuf to tensor error!");
auto scope = resp.GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
std::string table_name = request->table_name();
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = scope->Var(out_var_name);
request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name, table_name);
distributed::SerializeToIOBuf(out_var_name, outvar,
*request_prefetch_h_->dev_ctx(), response,
&cntl->response_attachment(), "", true);
}
void CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
checkpoint_notify_threads_->Run(
[=] { _CheckpointNotify(cntl_butil, request, response, done); });
}
void _CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(
request_checkpoint_h_ != nullptr,
"kRequestCheckpointNotify handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(),
request_checkpoint_h_->dev_ctx());
auto scope = resp.GetMutableLocalScope();
std::string checkpoint_notify = request->varname();
std::string checkpoint_dir = request->out_varname();
int trainer_id = request->trainer_id();
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr,
trainer_id, checkpoint_dir);
}
void GetMonomerVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(
request_get_monomer_handler_h_ != nullptr,
"kRequestGetMonomerVariable handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// proc request.
std::string varname = request->varname();
VLOG(3) << "GetMonomerVariable " << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
auto scope = h.scope_;
auto invar = scope->FindVar(varname);
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar,
request->trainer_id());
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response,
&cntl->response_attachment(), "", false);
}
}
void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(
request_get_monomer_barrier_handler_h_ != nullptr,
"RequestGetMonomerBarrier handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestGetMonomerBarrier var_name:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
paddle::framework::Scope* scope = nullptr;
paddle::framework::Variable* invar = nullptr;
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_barrier_handler_h_->Handle(
varname, scope, invar, &outvar, request->trainer_id());
}
private:
paddle::operators::distributed::RequestHandler* request_send_h_;
paddle::operators::distributed::RequestHandler* request_get_h_;
paddle::operators::distributed::RequestHandler* request_prefetch_h_;
distributed::RequestHandler* request_send_h_{nullptr};
distributed::RequestHandler* request_get_h_{nullptr};
distributed::RequestHandler* request_prefetch_h_{nullptr};
distributed::RequestHandler* request_checkpoint_h_{nullptr};
distributed::RequestHandler* request_get_monomer_handler_h_{nullptr};
distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr};
distributed::RPCServer* rpc_server_{nullptr};
// FIXME(gongwb): brpc should support process one rpce use one threadpool.
std::unique_ptr<paddle::framework::ThreadPool> send_threads_;
std::unique_ptr<paddle::framework::ThreadPool> get_threads_;
std::unique_ptr<paddle::framework::ThreadPool> prefetch_threads_;
std::unique_ptr<paddle::framework::ThreadPool> checkpoint_notify_threads_;
};
} // namespace sendrecv
......@@ -100,7 +303,7 @@ namespace distributed {
void AsyncBRPCServer::StartServer() {
// Instance of your service.
sendrecv::BRPCServiceImpl service_impl(rpc_call_map_);
sendrecv::BRPCServiceImpl service_impl(rpc_call_map_, this);
// Add the service into server. Notice the second parameter, because the
// service is put on stack, we don't want server to delete it, otherwise
......@@ -111,6 +314,9 @@ void AsyncBRPCServer::StartServer() {
}
brpc::ServerOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
options.use_rdma = true;
#endif
options.idle_timeout_sec = idle_timeout_s_;
options.max_concurrency = max_concurrency_;
if (server_.Start(bind_address_.c_str(), &options) != 0) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace distributed {
namespace pb = ::google::protobuf;
using vr = ::sendrecv::VariableMessage;
int BRPCVariableResponse::Parse(Source* source) {
pb::io::ZeroCopyInputStream* input_stream = source->contents();
pb::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (1) {
unsigned int tag = 0;
if (!input.ReadLittleEndian32(&tag)) {
break;
}
uint64_t num_bytes = 0;
if (!input.ReadLittleEndian64(&num_bytes)) {
break;
}
int field = static_cast<int>(tag);
int ret = field == 0 ? -1 : field;
switch (field) {
case vr::kSerializedFieldNumber: {
if (!ProcSerializedField(field, &input, num_bytes)) {
return ret;
}
break;
}
case vr::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return ret;
}
break;
}
default: {
PADDLE_ENFORCE(false, "not surpported %u fieldnumber", field);
return ret;
}
}
}
return 0;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace paddle {
namespace operators {
namespace distributed {
class BRPCSourceWrapper : public Source {
public:
explicit BRPCSourceWrapper(const butil::IOBuf& iobuf) : source_(iobuf) {}
::google::protobuf::io::ZeroCopyInputStream* contents() override {
return &source_;
}
private:
butil::IOBufAsZeroCopyInputStream source_;
};
class BRPCVariableResponse : public VariableResponse {
public:
BRPCVariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
bool create_scope = false)
: VariableResponse(scope, dev_ctx, create_scope) {}
virtual ~BRPCVariableResponse() {}
// parse attachment from iobuf
int Parse(Source* source) override;
int Parse(const butil::IOBuf& iobuf, const sendrecv::VariableMessage& meta) {
BRPCSourceWrapper wrapper(iobuf);
return VariableResponse::Parse(&wrapper, meta);
}
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -293,8 +293,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendMonomerFetchBarrierRPC";
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
s->Prepare(h, time_out);
VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
......
......@@ -32,13 +32,6 @@ namespace paddle {
namespace operators {
namespace distributed {
static void SerializeDestroyCallback(void* payload) {
if (payload != nullptr) {
auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
delete shared_payload;
}
}
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, const std::string& out_name,
......
......@@ -75,6 +75,10 @@ class RPCServer {
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
int thread_num = 5);
int GetThreadNum(const std::string& rpc_name) {
return rpc_thread_num_[rpc_name];
}
// Wait util all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/port.h"
......@@ -45,7 +46,6 @@ static TensorPayload GetCommunicationAllocationFromTensor(
memory::Copy(cuda_pinned, result->ptr(),
boost::get<platform::CUDAPlace>(tensor.place()),
tensor.data<void>(), copy_size, gpu_dev_ctx.stream());
ctx.Wait();
return TensorPayload(result);
#else
......
......@@ -50,6 +50,13 @@ class TensorPayload final {
size_t memory_size_;
};
inline void SerializeDestroyCallback(void* payload) {
if (payload != nullptr) {
auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
delete shared_payload;
}
}
TensorPayload GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx,
VarMsg* request);
......
......@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
set(DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
set(DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
......@@ -26,10 +26,11 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get");
DEFINE_int32(rpc_prefetch_thread_num, 5, "number of threads for rpc prefetch");
DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 12, "number of threads for rpc get");
DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");
namespace paddle {
namespace operators {
......
......@@ -58,7 +58,9 @@ class SendOp : public framework::OperatorBase {
}
if (sync_send) {
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
}
}
......
......@@ -81,6 +81,14 @@ bool IsCompiledWithCUDA() {
#endif
}
bool IsCompiledWithBrpc() {
#if defined(PADDLE_WITH_BRPC) || defined(PADDLE_WITH_BRPC_RDMA)
return true;
#else
return false;
#endif
}
bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DISTRIBUTE
return true;
......@@ -631,6 +639,7 @@ All parameter, weight, gradient are variables in Paddle.
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
m.def("is_compiled_with_dist", IsCompiledWithDIST);
#ifdef PADDLE_WITH_CUDA
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
......
......@@ -152,6 +152,7 @@ def __bootstrap__():
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'selected_gpus'
]
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册