Unverified commit 0b1c7d83, authored by gongweibao, committed by GitHub

Add brpc serialization support. (#11430)

Parent 37c2e245
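At a high level, the BRPC path added here serializes a variable's metadata into a sendrecv::VariableMessage, appends the raw tensor bytes to the request attachment, and completes a VarHandle from a response callback. The sketch below mirrors only that asynchronous request/handle shape using plain standard-library types; every name in it is illustrative and nothing is taken verbatim from this commit.

// Illustrative sketch (not from this commit): an async "send" that fills a
// request, attaches payload bytes, and finishes a handle from a callback,
// the same shape as BRPCClient::AsyncSendVar + HandleSendResponse below.
#include <functional>
#include <future>
#include <memory>
#include <string>
#include <thread>
#include <utility>

struct VarHandle {                      // stand-in for distributed::VarHandle
  std::string name;
  std::promise<bool> done;
  void Finish(bool ok) { done.set_value(ok); }
  bool Wait() { return done.get_future().get(); }
};

struct Request {                        // stand-in for sendrecv::VariableMessage
  std::string varname;
  int trainer_id = 0;
  std::string attachment;               // raw tensor bytes (an IOBuf in brpc)
};

// Fire the "RPC" on another thread and run the completion callback there,
// as the real client does with brpc::NewCallback + HandleSendResponse.
void AsyncSend(Request req, std::shared_ptr<VarHandle> h,
               std::function<void(bool, std::shared_ptr<VarHandle>)> on_done) {
  std::thread([req = std::move(req), h, on_done] {
    bool ok = !req.attachment.empty();  // pretend this is the RPC result
    on_done(ok, h);                     // the callback finishes the handle
  }).detach();
}

int main() {
  auto h = std::make_shared<VarHandle>();
  h->name = "w@GRAD";
  AsyncSend({"w@GRAD", 0, std::string(16, 'x')}, h,
            [](bool ok, std::shared_ptr<VarHandle> hh) { hh->Finish(ok); });
  return h->Wait() ? 0 : 1;             // analogous to RPCClient::Wait()
}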
@@ -81,9 +81,11 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
     # the role, should be either PSERVER or TRAINER
     training_role = os.getenv("PADDLE_TRAINING_ROLE")
-    config = distribute_transpiler.DistributeTranspilerConfig()
+    config = fluid.DistributeTranspilerConfig()
     config.slice_var_up = not args.no_split_var
+    config.min_block_size = 1048576
     t = distribute_transpiler.DistributeTranspiler(config=config)
     t.transpile(
         trainer_id,
         # NOTE: *MUST* use train_prog, for we are using with guard to
...
@@ -14,14 +14,16 @@
 INCLUDE(ExternalProject)
-find_library(SSL_LIBRARY NAMES ssl)
+find_package(OpenSSL REQUIRED)
+message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
+message(STATUS "crypto:" ${OPENSSL_CRYPTO_LIBRARY})
 ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${SSL_LIBRARY})
+SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY})
-find_library(CRYPTO_LIBRARY NAMES crypto)
 ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${CRYPTO_LIBRARY})
+SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY})
 SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
 SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
@@ -31,14 +33,15 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
 INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
 # Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
-set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib")
+set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
 # If minimal .a is need, you can set  WITH_DEBUG_SYMBOLS=OFF
 ExternalProject_Add(
     extern_brpc
     ${EXTERNAL_PROJECT_LOG_ARGS}
+    # TODO(gongwb): change to de newst repo when they changed.
     GIT_REPOSITORY "https://github.com/gongweibao/brpc"
-    GIT_TAG "7dc04defad1fd4173aae170c3fcbde131b65155a"
+    GIT_TAG "e9b67ec1b7458f2af5fae76451afe1e27e01b4b4"
     PREFIX ${BRPC_SOURCES_DIR}
     UPDATE_COMMAND ""
     CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -50,7 +53,7 @@ ExternalProject_Add(
     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
     -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
     -DCMAKE_PREFIX_PATH=${prefix_path}
-    -DBRPC_WITH_GLOG=ON
+    -DWITH_GLOG=ON
     -DIOBUF_WITH_HUGE_BLOCK=ON
    -DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
     ${EXTERNAL_OPTIONAL_ARGS}
@@ -65,5 +68,6 @@ ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
 SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
 ADD_DEPENDENCIES(brpc extern_brpc)
+add_definitions(-DBRPC_WITH_GLOG)
 LIST(APPEND external_project_dependencies brpc)
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-IF(WITH_TESTING)
-  ENABLE_TESTING()
+#FIXME:(gongwb) Move brpc's gtest dependency.
+IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
+  IF(WITH_TESTING)
+    ENABLE_TESTING()
+  ENDIF(WITH_TESTING)
 INCLUDE(ExternalProject)
 SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest)
@@ -76,4 +80,4 @@ IF(WITH_TESTING)
 ADD_DEPENDENCIES(gtest_main extern_gtest)
 LIST(APPEND external_project_dependencies gtest gtest_main)
-ENDIF(WITH_TESTING)
+ENDIF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
@@ -24,8 +24,8 @@ ExternalProject_Add(
     extern_leveldb
     ${EXTERNAL_PROJECT_LOG_ARGS}
     PREFIX ${LEVELDB_SOURCES_DIR}
-    URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
-    URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
+    GIT_REPOSITORY "https://github.com/google/leveldb"
+    GIT_TAG v1.18
     CONFIGURE_COMMAND ""
     BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
     INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
...
@@ -169,9 +169,12 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
 cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
 if(WITH_DISTRIBUTE)
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
-  set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
-  set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
+    lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper)
+  set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
+  set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
   if(WITH_NGRAPH)
     if(NOT WIN32)
...
@@ -12,12 +12,19 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
+if(WITH_DISTRIBUTE)
+  if(NOT WITH_GRPC)
+    set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
+    set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  endif()
+endif()
 if(WITH_GPU)
   nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
     dynload_cuda variable_visitor)
   if(WITH_DISTRIBUTE)
     nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
-      ddim dynload_cuda selected_rows_functor sendrecvop_grpc)
+      ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
   else()
     nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
       ddim dynload_cuda selected_rows_functor)
@@ -30,7 +37,7 @@ else()
     variable_visitor)
   if(WITH_DISTRIBUTE)
     cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
-      ddim selected_rows_functor sendrecvop_grpc)
+      ddim selected_rows_functor sendrecvop_rpc)
   else()
     cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
       ddim selected_rows_functor)
...
@@ -157,9 +157,9 @@ void Executor::Close() {
 #ifdef PADDLE_WITH_DISTRIBUTE
   // TODO(typhoonzero): complete message will need to use real trainer_id,
   // except 0.
-  ::paddle::operators::distributed::RPCClient::GetInstance<
-      ::paddle::operators::distributed::GRPCClient>(0)
-      ->SendComplete();
+  auto client =
+      paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
+  client->SendComplete();
 #endif
 }
...
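The executor now reaches the transport through RPCClient::GetInstance<RPCCLIENT_T>(0) instead of naming GRPCClient directly, so the same call site works for both gRPC and the new BRPC client. One plausible way such a compile-time switch can be wired up is sketched below; the guard macro and the exact definition are assumptions for illustration, not taken from this commit.

// Hypothetical selection of the RPC client type (names assumed):
// call sites only ever mention RPCCLIENT_T, so swapping transports
// requires no change to executor code.
#ifdef PADDLE_WITH_GRPC
#define RPCCLIENT_T paddle::operators::distributed::GRPCClient
#else
#define RPCCLIENT_T paddle::operators::distributed::BRPCClient
#endif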
@@ -12,7 +12,7 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
 set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
 if(WITH_GRPC)
-  grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
+  grpc_library(sendrecvop_rpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
       request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
       PROTO send_recv.proto
       DEPS lod_tensor selected_rows_functor memory)
@@ -20,36 +20,43 @@ if(WITH_GRPC)
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   cc_test(grpc_serde_test SRCS grpc_serde_test.cc
-    DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
+    DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_rpc scope profiler math_function SERIAL)
   cc_test(rpc_server_test SRCS rpc_server_test.cc
-    DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
+    DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
   cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
   if(WITH_GPU)
     cc_test(collective_server_test SRCS collective_server_test.cc
-      DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
+      DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
       selected_rows_functor scope math_function SERIAL)
   endif()
-  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
+  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 else()
-  set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
-    brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(brpc_server.cc parameter_prefetch.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
+    brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc collective_server.cc collective_server_test.cc
+    collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-  brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
-    brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
+  brpc_library(sendrecvop_rpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
+    brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc collective_client.cc collective_server.cc
     PROTO send_recv.proto
     DEPS lod_tensor selected_rows memory)
-  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory)
+  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
-  set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
+  set(brpc_test_depends sendrecvop_rpc brpc ssl crypto protobuf leveldb gflags glog executor
+    proto_desc lookup_sparse_table_op snappystream snappy zlib)
-  cc_test(brpc_server_test SRCS rpc_server_test.cc
+  cc_test(rpc_server_test SRCS rpc_server_test.cc
     DEPS ${brpc_test_depends} SERIAL)
   cc_test(brpc_serde_test SRCS brpc_serde_test.cc
     DEPS ${brpc_test_depends} SERIAL)
+  if(WITH_GPU)
+    cc_test(collective_server_test SRCS collective_server_test.cc
+      DEPS ${brpc_test_depends} selected_rows_functor scope math_function SERIAL)
+  endif()
 endif()
@@ -14,135 +14,316 @@
 #include "paddle/fluid/operators/distributed/brpc_client.h"
 #include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
+#include "paddle/fluid/platform/profiler.h"
 namespace paddle {
 namespace operators {
 namespace distributed {
-DEFINE_int32(brpc_channel_num, 24,
-             "Number of channels to send requests connected to one server");
 DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
 DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");
 BRPCClient::~BRPCClient() { Wait(); }
-void HandleSendResponse(brpc::Controller* cntl,
-                        sendrecv::VoidMessage* response) {
+void HandleSendResponse(brpc::Controller* cntl, sendrecv::VoidMessage* response,
+                        VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
+                        ChannelContextPtr ch_ctx, BRPCClient* cls) {
   // std::unique_ptr makes sure cntl/response will be deleted before returning.
   std::unique_ptr<brpc::Controller> cntl_guard(cntl);
   std::unique_ptr<sendrecv::VoidMessage> response_guard(response);
+  // this channel can be used by other now.
+  ch_ptr->Push(ch_ctx);
   if (cntl->Failed()) {
-    LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
+    LOG(FATAL) << "Fail to send SendVar: " << var_h->name()
+               << ", error text: " << cntl->ErrorText();
+    var_h->Finish(false);
+    cls->DecreaseReqCount();
     return;
   }
-  LOG(INFO) << "Received response from " << cntl->remote_side()
-            << " latency=" << cntl->latency_us() << "us";
+  var_h->Finish(true);
+  cls->DecreaseReqCount();
+  VLOG(4) << "HandleSendResponse from: " << cntl->remote_side()
+          << ", varname: " << var_h->name()
+          << ", latency: " << cntl->latency_us() << "us";
+  VLOG(4) << "Finish HandleSendResponse";
 }
-bool BRPCClient::AsyncSendVar(const std::string& ep,
-                              const platform::DeviceContext& ctx,
-                              const framework::Scope& scope,
-                              const std::string& var_name, int64_t time_out) {
+VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
+                                      const platform::DeviceContext& ctx,
+                                      const framework::Scope& scope,
+                                      const std::string& var_name,
+                                      int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
   const auto ch_ptr = GetChannel(ep_val);
-  framework::AsyncIO(
-      [var_name_val, p_ctx, ep_val, p_scope, time_out, ch_ptr, this] {
-        auto ch_ctx = ch_ptr->Pop();
-        brpc::Controller* cntl = new brpc::Controller();
-        sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
-        cntl->set_timeout_ms(time_out);
-        google::protobuf::Closure* done =
-            brpc::NewCallback(&HandleSendResponse, cntl, response);
-        sendrecv::VariableMessage request;
-        ch_ctx->stub->SendVariable(cntl, &request, response, done);
-      });
+  const std::string method = "SendRPC";
+  VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
+  framework::AsyncIO([=] {
+    auto ch_ctx = ch_ptr->Pop();
+    brpc::Controller* cntl = new brpc::Controller();
+    sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
+    cntl->set_timeout_ms(time_out);
+    auto* var = p_scope->FindVar(var_name_val);
+    sendrecv::VariableMessage request;
+    distributed::SerializeToIOBuf(var_name_val, var, *p_ctx, &request,
+                                  &cntl->request_attachment(), "", false,
+                                  trainer_id_);
+    google::protobuf::Closure* done = brpc::NewCallback(
+        &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
+    platform::RecordRPCEvent record_event(method, p_ctx);
+    ch_ctx->stub->SendVariable(cntl, &request, response, done);
+    if (UNLIKELY(platform::IsProfileEnabled())) {
+      var_h->Wait();
+    }
+  });
   req_count_++;
-  return true;
+  return var_h;
 }
+void HandleFetchBarrierResponse(brpc::Controller* cntl,
+                                sendrecv::VariableMessage* response,
+                                VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
+                                ChannelContextPtr ch_ctx, BRPCClient* cls) {
+  // std::unique_ptr makes sure cntl/response will be deleted before returning.
+  std::unique_ptr<brpc::Controller> cntl_guard(cntl);
+  std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
+  // this channel can be used other now.
+  ch_ptr->Push(ch_ctx);
+  if (cntl->Failed()) {
+    LOG(FATAL) << "Fail to get HandleFetchBarrierResponse: " << var_h->name()
+               << ", error text: " << cntl->ErrorText();
+    var_h->Finish(false);
+    cls->DecreaseReqCount();
+    return;
+  }
+  var_h->Finish(true);
+  cls->DecreaseReqCount();
+  VLOG(4) << "HandleFetchBarrierResponse from: " << cntl->remote_side()
+          << ", varname: " << var_h->name()
+          << ", latency: " << cntl->latency_us() << "us";
+  VLOG(4) << "Finish HandleFetchBarrierResponse";
+}
 void HandleGetResponse(brpc::Controller* cntl,
-                       sendrecv::VariableMessage* response) {
+                       sendrecv::VariableMessage* response, VarHandlePtr var_h,
+                       ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx,
+                       BRPCClient* cls) {
   // std::unique_ptr makes sure cntl/response will be deleted before returning.
   std::unique_ptr<brpc::Controller> cntl_guard(cntl);
   std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
+  // this channel can be used other now.
+  ch_ptr->Push(ch_ctx);
   if (cntl->Failed()) {
-    LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
+    LOG(FATAL) << "Fail to GetVar: " << var_h->name()
+               << ", error text: " << cntl->ErrorText();
+    cls->DecreaseReqCount();
+    var_h->Finish(false);
     return;
   }
-  LOG(INFO) << "Received response from " << cntl->remote_side()
-            << " latency=" << cntl->latency_us() << "us";
-  // framework::Variable* outvar = nullptr;
-  // DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
+  VLOG(4) << "HandleGetResponse from: " << cntl->remote_side()
+          << ", varname: " << var_h->name()
+          << ", latency: " << cntl->latency_us() << "us";
+  framework::Variable* outvar = nullptr;
+  int trainer_id;
+  distributed::DeserializeFromIOBuf(*response, cntl->response_attachment(),
+                                    *var_h->ctx(), var_h->scope(), &outvar,
+                                    &trainer_id);
+  VLOG(4) << "Finish HandleGetResponse";
+  cls->DecreaseReqCount();
+  var_h->Finish(true);
 }
-bool BRPCClient::AsyncGetVar(const std::string& ep,
-                             const platform::DeviceContext& ctx,
-                             const framework::Scope& scope,
-                             const std::string& var_name, int64_t time_out) {
+VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
+                                      const platform::DeviceContext& ctx,
+                                      const framework::Scope& scope,
+                                      const std::string& var_name,
+                                      const std::string& method_name,
+                                      int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
-  framework::AsyncIO(
-      [var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {});
+  const auto ch_ptr = GetChannel(ep_val);
+  const std::string method = "GetRPC";
+  VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
+  framework::AsyncIO([=] {
+    auto ch_ctx = ch_ptr->Pop();
+    brpc::Controller* cntl = new brpc::Controller();
+    sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
+    cntl->set_timeout_ms(time_out);
+    sendrecv::VariableMessage req;
+    req.set_varname(var_name_val);
+    req.set_trainer_id(trainer_id_);
+    google::protobuf::Closure* done = brpc::NewCallback(
+        &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
+    platform::RecordRPCEvent record_event(method, p_ctx);
+    if (method_name == "GetMonomerVariable") {
+      ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
+    } else {
+      ch_ctx->stub->GetVariable(cntl, &req, response, done);
+    }
+    if (UNLIKELY(platform::IsProfileEnabled())) {
+      var_h->Wait();
+    }
+  });
   req_count_++;
-  return true;
+  return var_h;
+}
+VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
+    const std::string& ep, const platform::DeviceContext& ctx,
+    const framework::Scope& scope, const std::string& var_name,
+    int64_t time_out) {
+  return _AsyncGetVar(ep, ctx, scope, var_name, "GetMonomerVariable", time_out);
+}
+VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
+                                                const std::string& var_name,
+                                                int64_t time_out) {
+  return AsyncSendMessage(ep, "GetMonomerBarrier", var_name, time_out);
 }
-bool BRPCClient::AsyncPrefetchVar(const std::string& ep,
-                                  const platform::DeviceContext& ctx,
-                                  const framework::Scope& scope,
-                                  const std::string& in_var_name,
-                                  const std::string& out_var_name,
-                                  int64_t time_out) {
+VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
                                      const platform::DeviceContext& ctx,
                                      const framework::Scope& scope,
+                                     const std::string& var_name,
+                                     int64_t time_out) {
+  return _AsyncGetVar(ep, ctx, scope, var_name, "GetVariable", time_out);
+}
+VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
+                                          const platform::DeviceContext& ctx,
+                                          const framework::Scope& scope,
+                                          const std::string& in_var_name,
+                                          const std::string& out_var_name,
+                                          const std::string& table_name,
+                                          int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string in_var_name_val = in_var_name;
   const std::string out_var_name_val = out_var_name;
+  const std::string table_name_val = table_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
-  framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
-                      time_out, ch, this] {});
+  const auto ch_ptr = GetChannel(ep_val);
+  const std::string method = "PrefetchRPC";
+  VarHandlePtr var_h(
+      new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
+  framework::AsyncIO([=] {
+    auto ch_ctx = ch_ptr->Pop();
+    brpc::Controller* cntl = new brpc::Controller();
+    sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
+    cntl->set_timeout_ms(time_out);
+    auto* var = p_scope->FindVar(in_var_name_val);
+    sendrecv::VariableMessage req;
+    distributed::SerializeToIOBuf(in_var_name_val, var, *p_ctx, &req,
+                                  &cntl->request_attachment(), out_var_name_val,
+                                  false, 0, table_name_val);
+    platform::RecordRPCEvent record_event(method, p_ctx);
+    google::protobuf::Closure* done = brpc::NewCallback(
+        &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
+    ch_ctx->stub->PrefetchVariable(cntl, &req, response, done);
+    if (UNLIKELY(platform::IsProfileEnabled())) {
+      var_h->Wait();
+    }
+  });
   req_count_++;
-  return true;
+  return var_h;
 }
-void BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
-                                       int64_t time_out) {
-  req_count_++;
+VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
+                                               int64_t time_out) {
+  return AsyncSendMessage(ep, "BatchBarrierRPC", BATCH_BARRIER_MESSAGE,
+                          time_out);
 }
-void BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
-                                       int64_t time_out) {
+VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
+                                               int64_t time_out) {
+  auto ch_ptr = GetChannel(ep);
+  auto ch_ctx = ch_ptr->Pop();
+  brpc::Controller* cntl = new brpc::Controller();
+  sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
+  cntl->set_timeout_ms(time_out);
+  sendrecv::VariableMessage req;
+  req.set_varname(FETCH_BARRIER_MESSAGE);
+  const std::string method = "FetchBarrierRPC";
+  // var handle
+  VarHandlePtr var_h(
+      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
+  platform::RecordRPCEvent record_event(method, nullptr);
+  google::protobuf::Closure* done = brpc::NewCallback(
+      &HandleFetchBarrierResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
+  ch_ctx->stub->GetVariable(cntl, &req, response, done);
   req_count_++;
+  if (UNLIKELY(platform::IsProfileEnabled())) {
+    var_h->Wait();
+  }
+  return var_h;
 }
-void BRPCClient::Wait() {
-  std::unique_lock<std::mutex> lk(sync_mutex_);
-  sync_cond_.wait(lk, [this] { return req_count_ == 0; });
+bool BRPCClient::Wait() {
+  VLOG(9) << "begin to brpcclient wait";
+  {
+    std::unique_lock<std::mutex> lk(sync_mutex_);
+    sync_cond_.wait(lk, [this] { return req_count_ == 0; });
+  }
+  VLOG(9) << "end to brpcclient wait";
+  return true;
 }
 ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
+  VLOG(4) << "begin to GetChannel:" << ep;
   {
     std::lock_guard<std::mutex> guard(chan_mutex_);
     auto it = channels_.find(ep);
     if (it != channels_.end()) {
+      VLOG(4) << "end to GetChannel:" << ep;
       return it->second;
     }
   }
@@ -150,12 +331,20 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
   ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());
   brpc::ChannelOptions options;
+#ifdef PADDLE_WITH_BRPC_RDMA
+  options.use_rdma = true;
+#endif
   options.protocol = "baidu_std";
-  options.connection_type = "pooled";
-  options.connect_timeout_ms = 100;
+  // don't use pooled type. the server can't afford that.
+  options.connection_type = "single";
+  options.connect_timeout_ms = 1000;
   options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
   options.max_retry = FLAGS_max_retry;
-  for (int i = 0; i < FLAGS_brpc_channel_num; ++i) {
+  VLOG(1) << "create " << brpc_channel_num_per_server_
+          << " brpc channels to pserver:" << ep;
+  for (int i = 0; i < brpc_channel_num_per_server_; ++i) {
     std::shared_ptr<ChannelContext> c(new ChannelContext());
     if (c->channel.Init(ep.c_str(), &options) != 0) {
       LOG(FATAL) << "Fail to initialize channel";
@@ -172,9 +361,75 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
     channels_[ep] = q;
   }
+  VLOG(4) << "end to GetChannel:" << ep;
   return q;
 }
+VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
+                                           int64_t time_out) {
+  return AsyncSendMessage(ep, "SendCompleteRPC", COMPLETE_MESSAGE, time_out);
+}
+void BRPCClient::SendComplete() {
+  for (auto& kv : channels_) {
+    AsyncSendComplete(kv.first);
+  }
+}
+VarHandlePtr BRPCClient::AsyncSendVarMessage(
+    const std::string& ep, const std::string& method_name,
+    const sendrecv::VariableMessage& req, int64_t time_out) {
+  auto ch_ptr = GetChannel(ep);
+  auto ch_ctx = ch_ptr->Pop();
+  brpc::Controller* cntl = new brpc::Controller();
+  sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
+  cntl->set_timeout_ms(time_out);
+  platform::RecordRPCEvent record_event(method_name, nullptr);
+  VarHandlePtr var_h(
+      new VarHandle(ep, method_name, req.varname(), nullptr, nullptr));
+  google::protobuf::Closure* done = brpc::NewCallback(
+      &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
+  if (method_name == "CheckPointNotifyRPC") {
+    ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
+  } else if (method_name == "GetMonomerBarrier") {
+    ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
+  } else {
+    ch_ctx->stub->SendVariable(cntl, &req, response, done);
+  }
+  req_count_++;
+  if (UNLIKELY(platform::IsProfileEnabled())) {
+    var_h->Wait();
+  }
+  return var_h;
+}
+VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
+                                          const std::string& method_name,
+                                          const std::string& message,
+                                          int64_t time_out) {
+  sendrecv::VariableMessage req;
+  req.set_varname(message);
+  return AsyncSendVarMessage(ep, method_name, req, time_out);
+}
+VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
+                                               const std::string& dir,
+                                               int64_t time_out) {
+  sendrecv::VariableMessage req;
+  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
+  req.set_out_varname(dir);
+  return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
+}
 }  // namespace distributed
 }  // namespace operators
 }  // namespace paddle
@@ -31,6 +31,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
+#include "paddle/fluid/operators/distributed/request_handler.h"
 #include "paddle/fluid/operators/distributed/rpc_client.h"
 #include "paddle/fluid/operators/distributed/send_recv.pb.h"
 #include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
@@ -53,33 +55,94 @@ class BRPCClient : public RPCClient {
   BRPCClient() {}
   virtual ~BRPCClient();
-  bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
-                    const framework::Scope& scope, const std::string& var_name,
-                    int64_t time_out = FLAGS_rpc_deadline) override;
-  bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
-                   const framework::Scope& scope, const std::string& var_name,
-                   int64_t time_out = FLAGS_rpc_deadline) override;
-  bool AsyncPrefetchVar(const std::string& ep,
-                        const platform::DeviceContext& ctx,
-                        const framework::Scope& scope,
-                        const std::string& in_var_name,
-                        const std::string& out_var_name,
-                        int64_t time_out = FLAGS_rpc_deadline) override;
-  void AsyncSendBatchBarrier(const std::string& ep,
-                             int64_t time_out = FLAGS_rpc_deadline) override;
-  void AsyncSendFetchBarrier(const std::string& ep,
-                             int64_t time_out = FLAGS_rpc_deadline) override;
-  void Wait() override;
+  VarHandlePtr AsyncSendVar(const std::string& ep,
+                            const platform::DeviceContext& ctx,
+                            const framework::Scope& scope,
+                            const std::string& var_name,
+                            int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncGetVar(const std::string& ep,
+                           const platform::DeviceContext& ctx,
+                           const framework::Scope& scope,
+                           const std::string& var_name,
+                           int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncGetMonomerBarrier(
+      const std::string& ep, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncGetMonomerVariable(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncPrefetchVar(const std::string& ep,
+                                const platform::DeviceContext& ctx,
+                                const framework::Scope& scope,
+                                const std::string& in_var_name,
+                                const std::string& out_var_name,
+                                const std::string& table_name = "",
+                                int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncSendBatchBarrier(
+      const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncSendFetchBarrier(
+      const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncCheckpointNotify(
+      const std::string& ep, const std::string& dir,
+      int64_t time_out = FLAGS_rpc_deadline) override;
+  bool Wait() override;
+  void SendComplete() override;
 private:
+  VarHandlePtr _AsyncGetVar(const std::string& ep,
+                            const platform::DeviceContext& ctx,
+                            const framework::Scope& scope,
+                            const std::string& var_name,
+                            const std::string& method_name,
+                            int64_t time_out = FLAGS_rpc_deadline);
   void Proceed();
   ChannelQueuePtr GetChannel(const std::string& ep);
+  VarHandlePtr AsyncSendComplete(const std::string& ep,
+                                 int64_t time_out = FLAGS_rpc_deadline);
+  VarHandlePtr AsyncSendMessage(const std::string& ep,
+                                const std::string& method_name,
+                                const std::string& message, int64_t time_out);
+  VarHandlePtr AsyncSendVarMessage(const std::string& ep,
+                                   const std::string& method_name,
+                                   const sendrecv::VariableMessage& req,
+                                   int64_t time_out);
+  friend void HandleSendResponse(brpc::Controller* cntl,
+                                 sendrecv::VoidMessage* response,
+                                 VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
+                                 ChannelContextPtr ch_ctx, BRPCClient* cls);
+  friend void HandleGetResponse(brpc::Controller* cntl,
+                                sendrecv::VariableMessage* response,
+                                VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
+                                ChannelContextPtr ch_ctx, BRPCClient* cls);
+  friend void HandleFetchBarrierResponse(brpc::Controller* cntl,
+                                         sendrecv::VariableMessage* response,
+                                         VarHandlePtr var_h,
+                                         ChannelQueuePtr ch_ptr,
+                                         ChannelContextPtr ch_ctx,
+                                         BRPCClient* cls);
+  void DecreaseReqCount() {
+    if (--req_count_ <= 0) {
+      sync_cond_.notify_all();
+    }
+  }
 private:
   std::unordered_map<std::string, ChannelQueuePtr> channels_;
@@ -88,6 +151,8 @@ class BRPCClient : public RPCClient {
   std::condition_variable sync_cond_;
   std::atomic<int64_t> req_count_{0};
+  static constexpr int brpc_channel_num_per_server_ = 4;
   // mutex for GetChannel thread safety
   std::mutex chan_mutex_;
   DISABLE_COPY_AND_ASSIGN(BRPCClient);
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BRPC_RDMA
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "brpc/channel.h"
#include "brpc/rdma/rdma_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
RdmaMemPool& RdmaMemPool::Instance() {
static RdmaMemPool* g_rdma_mem_pool = new RdmaMemPool();
return *g_rdma_mem_pool;
}
void* RdmaMemPool::Find(const std::string& varname, int64_t size) {
pthread_rwlock_rdlock(&access_);
auto it = pool_.find(varname);
if (it == pool_.end()) {
pthread_rwlock_unlock(&access_);
return nullptr;
}
auto info = it->second;
if (info.data_size != size) {
pthread_rwlock_unlock(&access_);
PADDLE_ENFORCE(false, "var:%s size:%ld != %ld", varname, size,
info.data_size);
return nullptr;
}
pthread_rwlock_unlock(&access_);
return info.data;
}
void RdmaMemPool::Register(const std::string& varname, void* data,
int64_t data_size) {
void* old = Find(varname, data_size);
if (old != nullptr) {
if (data != old) {
PADDLE_ENFORCE(false, "var:%s data:%ld != %ld", varname, data, old);
}
VLOG(7) << "Find on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
return;
}
VarInfo info;
info.data = data;
info.data_size = data_size;
pthread_rwlock_wrlock(&access_);
pool_[varname] = info;
pthread_rwlock_unlock(&access_);
if (brpc::rdma::RegisterMemoryForRdma(data, data_size)) {
LOG(FATAL) << "register " << varname << " data:" << data
<< " data_size:" << data_size << " error";
}
VLOG(4) << "register on rdma:" << varname << " data:" << data
<< " data_size:" << data_size;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
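RdmaMemPool caches one registered buffer per variable name so the same memory is not registered with brpc::rdma twice. The same idea, reduced to a self-contained sketch with standard containers, is shown below; RegisterMemoryForRdma is replaced by a stub because the brpc RDMA API is not available outside the build, and all other names are illustrative.

#include <cstdint>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Stub for brpc::rdma::RegisterMemoryForRdma; the real call pins the pages
// with the RDMA NIC so they can be sent zero-copy.
static int FakeRegisterForRdma(void* /*data*/, int64_t /*size*/) { return 0; }

class NamedRegistrationPool {  // same register-once idea as RdmaMemPool
 public:
  // Register a (name, buffer) pair exactly once; repeated calls with the
  // same name and pointer are no-ops, a mismatching pointer is an error.
  bool Register(const std::string& name, void* data, int64_t size) {
    std::lock_guard<std::mutex> guard(mu_);
    auto it = pool_.find(name);
    if (it != pool_.end()) return it->second.first == data;
    if (FakeRegisterForRdma(data, size) != 0) return false;
    pool_[name] = std::make_pair(data, size);
    return true;
  }
  // Return the registered buffer only if the recorded size matches.
  void* Find(const std::string& name, int64_t size) {
    std::lock_guard<std::mutex> guard(mu_);
    auto it = pool_.find(name);
    return (it != pool_.end() && it->second.second == size) ? it->second.first
                                                            : nullptr;
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::pair<void*, int64_t>> pool_;
};

int main() {
  std::vector<char> grad(4096);
  NamedRegistrationPool pool;
  pool.Register("w@GRAD", grad.data(), static_cast<int64_t>(grad.size()));
  // Subsequent sends of the same variable reuse the registered buffer.
  std::cout << (pool.Find("w@GRAD", 4096) == grad.data()) << "\n";
}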
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_BRPC_RDMA
#include <pthread.h> // NOLINT
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
namespace distributed {
/*
* This class is used to avoid duplicated registion of brpc::rdma.
*/
class RdmaMemPool {
public:
static RdmaMemPool& Instance();
RdmaMemPool() : access_(PTHREAD_RWLOCK_INITIALIZER) {}
virtual ~RdmaMemPool() { pthread_rwlock_destroy(&access_); }
void Register(const std::string& varname, void* data, int64_t size);
void* Find(const std::string& varname, int64_t size);
private:
struct VarInfo {
void* data;
int64_t data_size;
VarInfo() : data(nullptr), data_size(0) {}
};
private:
std::unordered_map<std::string, VarInfo> pool_;
pthread_rwlock_t access_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
class IOBufWriter {
public:
static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
iobuf->append(v, vlen);
}
static void AppendTCPZeroCopy(butil::IOBuf* iobuf, int k, const char* v,
int64_t vlen, bool in_cuda_pinned,
void (*destroy)(void*), void* user_data) {
VLOG(7) << "AppendTCPZeroCopy "
<< " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
// FIXME(gongwb): use append_zerocopy
/*
if (in_cuda_pinned) {
iobuf->append_zerocopy(v, vlen, IOBufWriter::FreeMemory);
} else {
iobuf->append_zerocopy(v, vlen, nullptr);
}
*/
iobuf->append(v, vlen);
destroy(user_data);
}
#ifdef PADDLE_WITH_BRPC_RDMA
static void AppendRdmaZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
VLOG(7) << "AppendRdmaZeroCopy varname:" << varname << " k:" << k
<< " data:" << static_cast<void*>(const_cast<char*>(v))
<< " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
RdmaMemPool::Instance().Register(
varname, static_cast<void*>(const_cast<char*>(v)), vlen);
// FIXME(gongwb): use append_zerocopy
// iobuf->append_zerocopy(v, vlen, nullptr);
iobuf->append(v, vlen);
destroy(user_data);
return;
}
#endif
static void AppendZeroCopy(const std::string varname, butil::IOBuf* iobuf,
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
#ifdef PADDLE_WITH_BRPC_RDMA
IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
destroy, user_data);
#else
IOBufWriter::AppendTCPZeroCopy(iobuf, k, v, vlen, in_cuda_pinned, destroy,
user_data);
#endif
}
};
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, int trainer_id,
const std::string& table_name) {
std::unique_ptr<TensorPayload> payload;
request->set_varname(name);
request->set_trainer_id(trainer_id);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request->set_profile(platform::kEnableProfiler);
} else {
request->set_profile(platform::kDisableProfiler);
}
}
if (!out_varname.empty()) {
request->set_out_varname(out_varname);
}
if (!table_name.empty()) {
request->set_table_name(table_name);
}
if (var->IsType<framework::LoDTensor>()) {
request->set_type(::sendrecv::LOD_TENSOR);
payload.reset(new TensorPayload(GetTensorPayload(var, ctx, request)));
} else if (var->IsType<framework::SelectedRows>()) {
request->set_type(::sendrecv::SELECTED_ROWS);
payload.reset(new TensorPayload(GetSelectedRowsPayload(var, ctx, request)));
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request->set_type(::sendrecv::NCCL_ID);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
// TODO(gongwb): use append_zero to avoid data copy.
IOBufWriter::Append(iobuf,
sendrecv::VariableMessage::kSerializedFieldNumber,
uid.internal, NCCL_UNIQUE_ID_BYTES);
return;
#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
PADDLE_ENFORCE_NOT_NULL(payload);
// FIXME(gongwb): it seems that can use zero copy.
if (var_is_not_stable) {
IOBufWriter::Append(
iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size());
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
true, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
#endif
} else {
IOBufWriter::AppendZeroCopy(
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size(),
false, SerializeDestroyCallback, static_cast<void*>(payload.get()));
payload.release();
}
}
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
reinterpret_cast<const char*>(slr->rows().data()),
static_cast<int64_t>(rows_memory_size));
}
}
void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta,
const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id) {
operators::distributed::BRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(iobuf, meta) == 0, "parse iobuf to tensor error!");
*var = resp.GetVar();
*trainer_id = resp.GetTrainerId();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
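IOBufWriter frames each field it appends to the brpc attachment as a 4-byte field number, an 8-byte length, and then the raw bytes, placed next to the protobuf metadata message. A self-contained sketch of that framing is below, using std::string instead of butil::IOBuf; the byte order simply follows the host, as the raw memcpy of &k and &vlen in the code above does, and the field number 7 is an arbitrary value for the demo, not the real kSerializedFieldNumber.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>

// Append one framed field: 4-byte key, 8-byte length, then the payload,
// mirroring IOBufWriter::Append but on a std::string "attachment".
void AppendField(std::string* buf, int32_t k, const char* v, int64_t vlen) {
  buf->append(reinterpret_cast<const char*>(&k), 4);
  buf->append(reinterpret_cast<const char*>(&vlen), 8);
  buf->append(v, static_cast<size_t>(vlen));
}

// Read the frame back: returns the payload and reports the key and length.
std::string ReadField(const std::string& buf, size_t offset, int32_t* k,
                      int64_t* vlen) {
  std::memcpy(k, buf.data() + offset, 4);
  std::memcpy(vlen, buf.data() + offset + 4, 8);
  return buf.substr(offset + 12, static_cast<size_t>(*vlen));
}

int main() {
  std::string attachment;
  const std::string payload = "raw tensor bytes";
  AppendField(&attachment, /*field number, demo value*/ 7, payload.data(),
              static_cast<int64_t>(payload.size()));

  int32_t k = 0;
  int64_t len = 0;
  assert(ReadField(attachment, 0, &k, &len) == payload);
  assert(k == 7 && len == static_cast<int64_t>(payload.size()));
  return 0;
}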
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace distributed {
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf, const std::string& out_varname,
bool var_is_not_stable, const int trainer_id = 0,
const std::string& table_name = std::string());
void DeserializeFromIOBuf(const VarMsg& meta, const butil::IOBuf& iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
} // namespace distributed
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "brpc/channel.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 564 * 128;
// serialize var to IOBuf
{
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 32.7);
for (int i = 0; i < 564; ++i) rows->push_back(i);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// desrialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto* slr2 = var2->GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2->data<float>());
}
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
}
EXPECT_EQ(slr2->height(), 1000);
}
}
void RunTestLodTensor(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
butil::IOBuf iobuf;
sendrecv::VariableMessage msg;
int tensor_numel = 512 * 8 * 4 * 2;
{
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9);
operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
"", false);
}
// check sendrecv::VariableMessage meta data
{
EXPECT_EQ(msg.varname(), "myvar");
EXPECT_EQ(msg.type(), 0);
EXPECT_EQ(msg.dims()[0], 512);
EXPECT_EQ(msg.dims()[1], 8);
EXPECT_EQ(msg.dims()[2], 4);
EXPECT_EQ(msg.dims()[3], 2);
EXPECT_EQ(msg.lod_level(), 1);
EXPECT_EQ(msg.lod(0).lod_data(0), 1);
EXPECT_EQ(msg.lod(0).lod_data(1), 3);
EXPECT_EQ(msg.lod(0).lod_data(2), 8);
}
// deserialize
{
framework::Scope scope;
scope.Var("myvar");
operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(iobuf, msg), 0);
framework::Variable* var2 = resp.GetVar();
auto tensor2 = var2->Get<framework::LoDTensor>();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2.data<float>());
}
for (int i = 0; i < tensor_numel; ++i)
EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
}
}
TEST(LodTensor, Run) {
platform::CPUPlace place;
RunTestLodTensor(place);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu(0);
RunTestLodTensor(gpu);
#endif
}
TEST(SelectedRows, Run) {
platform::CPUPlace place;
RunSerdeTestSelectedRows(place);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu;
RunSerdeTestSelectedRows(gpu);
#endif
}
...@@ -13,84 +13,287 @@ ...@@ -13,84 +13,287 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/distributed/brpc_server.h" #include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
namespace sendrecv { namespace sendrecv {
typedef std::unordered_map<std::string, namespace distributed = paddle::operators::distributed;
paddle::operators::distributed::RequestHandler*>
typedef std::unordered_map<std::string, distributed::RequestHandler*>
HandlerMap; HandlerMap;
class BRPCServiceImpl : public SendRecvService { class BRPCServiceImpl : public SendRecvService {
public: public:
explicit BRPCServiceImpl(const HandlerMap& rpc_call_map) explicit BRPCServiceImpl(const HandlerMap& rpc_call_map,
: request_send_h_(nullptr), distributed::RPCServer* rpc_server)
request_get_h_(nullptr), : rpc_server_(rpc_server) {
request_prefetch_h_(nullptr) { VLOG(3) << "BRPCServiceImpl size: " << rpc_call_map.size();
auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend); auto it = rpc_call_map.find(distributed::kRequestSend);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_send_h_ = it->second; request_send_h_ = it->second;
send_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestSend)));
} }
it = rpc_call_map.find(paddle::operators::distributed::kRequestSend); it = rpc_call_map.find(distributed::kRequestGet);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_get_h_ = it->second; request_get_h_ = it->second;
get_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestGet)));
} }
it = rpc_call_map.find(paddle::operators::distributed::kRequestPrefetch); it = rpc_call_map.find(distributed::kRequestPrefetch);
if (it != rpc_call_map.end()) { if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second; request_prefetch_h_ = it->second;
prefetch_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
}
it = rpc_call_map.find(distributed::kRequestCheckpoint);
if (it != rpc_call_map.end()) {
request_checkpoint_h_ = it->second;
checkpoint_notify_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestCheckpoint)));
}
it = rpc_call_map.find(distributed::kRequestGetMonomerVariable);
if (it != rpc_call_map.end()) {
request_get_monomer_handler_h_ = it->second;
}
it = rpc_call_map.find(distributed::kRequestGetMonomerBarrier);
if (it != rpc_call_map.end()) {
request_get_monomer_barrier_handler_h_ = it->second;
} }
} }
virtual ~BRPCServiceImpl() {} virtual ~BRPCServiceImpl() {}
void SendVariable(google::protobuf::RpcController* cntl_butil, void SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response, const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override { google::protobuf::Closure* done) override {
send_threads_->Run(
[=] { _SendVariable(cntl_butil, request, response, done); });
}
void _SendVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_send_h_ != nullptr, PADDLE_ENFORCE(request_send_h_ != nullptr,
"RequestSend handler should be registed first!"); "RequestSend handler should be registed first!");
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
paddle::framework::Scope* local_scope = request_send_h_->scope();
paddle::framework::Variable* outvar = nullptr;
paddle::framework::Variable* invar = nullptr;
std::string varname = request->varname(); std::string varname = request->varname();
VLOG(3) << "RequestSend var_name:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
if (!request_send_h_->sync_mode()) { distributed::BRPCVariableResponse resp(request_send_h_->scope(),
local_scope = &request_send_h_->scope()->NewScope(); request_send_h_->dev_ctx(),
invar = local_scope->Var(varname); !request_send_h_->sync_mode());
} else { PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
invar = local_scope->FindVar(varname); "failed to parse iobuf to tensor!");
}
request_send_h_->Handle(varname, local_scope, invar, &outvar); auto scope = resp.GetMutableLocalScope();
auto invar = resp.GetVar();
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
if (!request_send_h_->sync_mode()) { request_send_h_->Handle(varname, scope, invar, &outvar, trainer_id);
request_send_h_->scope()->DeleteScope(local_scope);
}
} }
void GetVariable(google::protobuf::RpcController* cntl_butil, void GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response, const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) override { google::protobuf::Closure* done) override {
get_threads_->Run(
[=] { _GetVariable(cntl_butil, request, response, done); });
}
void _GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_get_h_ != nullptr, PADDLE_ENFORCE(request_get_h_ != nullptr,
"RequestGet handler should be registed first!"); "RequestGet handler should be registed first!");
}
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestGet varname:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
auto scope = request_get_h_->scope();
auto invar = scope->FindVar(varname);
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id);
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *request_get_h_->dev_ctx(),
response, &cntl->response_attachment(), "",
false);
}
}
void PrefetchVariable(google::protobuf::RpcController* cntl_butil, void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, const VariableMessage* request,
VariableMessage* response, VariableMessage* response,
google::protobuf::Closure* done) override { google::protobuf::Closure* done) override {
prefetch_threads_->Run(
[=] { _PrefetchVariable(cntl_butil, request, response, done); });
}
void _PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_prefetch_h_ != nullptr, PADDLE_ENFORCE(request_prefetch_h_ != nullptr,
"kRequestPrefetch handler should be registed first!"); "kRequestPrefetch handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// prefetch process...
std::string in_var_name = request->varname();
std::string out_var_name = request->out_varname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
<< ", out_var_name: " << out_var_name
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
distributed::BRPCVariableResponse resp(
request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true);
PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
"parse iobuf to tensor error!");
auto scope = resp.GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
std::string table_name = request->table_name();
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = scope->Var(out_var_name);
request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name, table_name);
distributed::SerializeToIOBuf(out_var_name, outvar,
*request_prefetch_h_->dev_ctx(), response,
&cntl->response_attachment(), "", true);
}
void CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
checkpoint_notify_threads_->Run(
[=] { _CheckpointNotify(cntl_butil, request, response, done); });
}
void _CheckpointNotify(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(
request_checkpoint_h_ != nullptr,
"kRequestCheckpointNotify handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(),
request_checkpoint_h_->dev_ctx());
auto scope = resp.GetMutableLocalScope();
std::string checkpoint_notify = request->varname();
std::string checkpoint_dir = request->out_varname();
int trainer_id = request->trainer_id();
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr,
trainer_id, checkpoint_dir);
}
void GetMonomerVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(
request_get_monomer_handler_h_ != nullptr,
"kRequestGetMonomerVariable handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
// proc request.
std::string varname = request->varname();
VLOG(3) << "GetMonomerVariable " << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
auto scope = h.scope_;
auto invar = scope->FindVar(varname);
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar,
request->trainer_id());
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response,
&cntl->response_attachment(), "", false);
}
}
void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VoidMessage* response,
google::protobuf::Closure* done) override {
PADDLE_ENFORCE(
request_get_monomer_barrier_handler_h_ != nullptr,
"RequestGetMonomerBarrier handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
VLOG(3) << "RequestGetMonomerBarrier var_name:" << varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
rpc_server_->WaitVarCond(varname);
distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
paddle::framework::Scope* scope = nullptr;
paddle::framework::Variable* invar = nullptr;
paddle::framework::Variable* outvar = nullptr;
request_get_monomer_barrier_handler_h_->Handle(
varname, scope, invar, &outvar, request->trainer_id());
} }
private: private:
paddle::operators::distributed::RequestHandler* request_send_h_; distributed::RequestHandler* request_send_h_{nullptr};
paddle::operators::distributed::RequestHandler* request_get_h_; distributed::RequestHandler* request_get_h_{nullptr};
paddle::operators::distributed::RequestHandler* request_prefetch_h_; distributed::RequestHandler* request_prefetch_h_{nullptr};
distributed::RequestHandler* request_checkpoint_h_{nullptr};
distributed::RequestHandler* request_get_monomer_handler_h_{nullptr};
distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr};
distributed::RPCServer* rpc_server_{nullptr};
// FIXME(gongwb): brpc should support running each type of RPC on its own thread pool.
std::unique_ptr<paddle::framework::ThreadPool> send_threads_;
std::unique_ptr<paddle::framework::ThreadPool> get_threads_;
std::unique_ptr<paddle::framework::ThreadPool> prefetch_threads_;
std::unique_ptr<paddle::framework::ThreadPool> checkpoint_notify_threads_;
}; };
} // namespace sendrecv } // namespace sendrecv
...@@ -100,7 +303,7 @@ namespace distributed { ...@@ -100,7 +303,7 @@ namespace distributed {
void AsyncBRPCServer::StartServer() { void AsyncBRPCServer::StartServer() {
// Instance of your service. // Instance of your service.
sendrecv::BRPCServiceImpl service_impl(rpc_call_map_); sendrecv::BRPCServiceImpl service_impl(rpc_call_map_, this);
// Add the service into server. Notice the second parameter, because the // Add the service into server. Notice the second parameter, because the
// service is put on stack, we don't want server to delete it, otherwise // service is put on stack, we don't want server to delete it, otherwise
...@@ -111,6 +314,9 @@ void AsyncBRPCServer::StartServer() { ...@@ -111,6 +314,9 @@ void AsyncBRPCServer::StartServer() {
} }
brpc::ServerOptions options; brpc::ServerOptions options;
#ifdef PADDLE_WITH_BRPC_RDMA
options.use_rdma = true;
#endif
options.idle_timeout_sec = idle_timeout_s_; options.idle_timeout_sec = idle_timeout_s_;
options.max_concurrency = max_concurrency_; options.max_concurrency = max_concurrency_;
if (server_.Start(bind_address_.c_str(), &options) != 0) { if (server_.Start(bind_address_.c_str(), &options) != 0) {
......
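End of the server diff. As a hedged reference, this is roughly how a stack-allocated service such as BRPCServiceImpl is hosted with stock brpc, the same AddService/Start pattern used in AsyncBRPCServer::StartServer above; the endpoint and option values here are placeholders, not the ones the RPC server actually uses.

#include "brpc/server.h"
#include "butil/logging.h"

// Illustrative only: host a service object that lives on the stack.
int RunBrpcServerSketch(google::protobuf::Service* service_impl) {
  brpc::Server server;
  // The second parameter matters: the server must not delete a stack object.
  if (server.AddService(service_impl, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "Fail to add service";
    return -1;
  }
  brpc::ServerOptions options;
  options.idle_timeout_sec = -1;  // placeholder: never close idle connections
  options.max_concurrency = 0;    // placeholder: no concurrency limit
  if (server.Start("127.0.0.1:8500", &options) != 0) {  // placeholder address
    LOG(ERROR) << "Fail to start server";
    return -1;
  }
  server.RunUntilAskedToQuit();  // block until the process is asked to quit
  return 0;
}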
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace distributed {
namespace pb = ::google::protobuf;
using vr = ::sendrecv::VariableMessage;
int BRPCVariableResponse::Parse(Source* source) {
pb::io::ZeroCopyInputStream* input_stream = source->contents();
pb::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (1) {
unsigned int tag = 0;
if (!input.ReadLittleEndian32(&tag)) {
break;
}
uint64_t num_bytes = 0;
if (!input.ReadLittleEndian64(&num_bytes)) {
break;
}
int field = static_cast<int>(tag);
int ret = field == 0 ? -1 : field;
switch (field) {
case vr::kSerializedFieldNumber: {
if (!ProcSerializedField(field, &input, num_bytes)) {
return ret;
}
break;
}
case vr::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return ret;
}
break;
}
default: {
PADDLE_ENFORCE(false, "not surpported %u fieldnumber", field);
return ret;
}
}
}
return 0;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
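Parse above assumes each record in the attachment is framed as a little-endian 32-bit field number, a little-endian 64-bit payload size, and then the raw payload bytes. A hedged writer-side sketch of just that framing follows; it is for illustration only, the real encoder lives in brpc_sendrecvop_utils.cc.

#include <cstdint>
#include <string>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

// Frame one record the way BRPCVariableResponse::Parse reads it back:
// [field number | LE32][payload size | LE64][payload bytes].
std::string FrameField(uint32_t field_number, const std::string& payload) {
  std::string out;
  {
    google::protobuf::io::StringOutputStream raw(&out);
    google::protobuf::io::CodedOutputStream coded(&raw);
    coded.WriteLittleEndian32(field_number);
    coded.WriteLittleEndian64(static_cast<uint64_t>(payload.size()));
    coded.WriteRaw(payload.data(), static_cast<int>(payload.size()));
  }  // CodedOutputStream flushes into `out` on destruction.
  return out;
}
// A tensor body would use vr::kSerializedFieldNumber and selected-rows data
// vr::kRowsFieldNumber, mirroring the switch in Parse.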
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace paddle {
namespace operators {
namespace distributed {
class BRPCSourceWrapper : public Source {
public:
explicit BRPCSourceWrapper(const butil::IOBuf& iobuf) : source_(iobuf) {}
::google::protobuf::io::ZeroCopyInputStream* contents() override {
return &source_;
}
private:
butil::IOBufAsZeroCopyInputStream source_;
};
class BRPCVariableResponse : public VariableResponse {
public:
BRPCVariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
bool create_scope = false)
: VariableResponse(scope, dev_ctx, create_scope) {}
virtual ~BRPCVariableResponse() {}
// parse attachment from iobuf
int Parse(Source* source) override;
int Parse(const butil::IOBuf& iobuf, const sendrecv::VariableMessage& meta) {
BRPCSourceWrapper wrapper(iobuf);
return VariableResponse::Parse(&wrapper, meta);
}
};
}  // namespace distributed
}  // namespace operators
}  // namespace paddle
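A compact usage sketch of the classes above, following the same receive pattern as BRPCServiceImpl and the serde test; ReceiveVar is an illustrative name, and the target variable is expected to exist in the scope already (the test creates it with scope.Var).

#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/platform/enforce.h"

// Illustrative only: rebuild a variable from (meta proto, IOBuf body).
void ReceiveVar(const sendrecv::VariableMessage& meta, const butil::IOBuf& body,
                paddle::framework::Scope* scope,
                const paddle::platform::DeviceContext* dev_ctx) {
  paddle::operators::distributed::BRPCVariableResponse resp(scope, dev_ctx);
  PADDLE_ENFORCE(resp.Parse(body, meta) == 0,
                 "failed to parse iobuf to tensor!");
  // The deserialized data now lives in scope->FindVar(meta.varname()).
}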
...@@ -293,8 +293,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep, ...@@ -293,8 +293,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendMonomerFetchBarrierRPC"; const std::string method = "SendMonomerFetchBarrierRPC";
VarHandlePtr h( VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out); s->Prepare(h, time_out);
VLOG(30) << s->GetVarHandlePtr()->String() << " begin"; VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
......
...@@ -32,13 +32,6 @@ namespace paddle { ...@@ -32,13 +32,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
static void SerializeDestroyCallback(void* payload) {
if (payload != nullptr) {
auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
delete shared_payload;
}
}
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, const std::string& out_name, ::grpc::ByteBuffer* msg, const std::string& out_name,
......
...@@ -75,6 +75,10 @@ class RPCServer { ...@@ -75,6 +75,10 @@ class RPCServer {
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler, void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
int thread_num = 5); int thread_num = 5);
int GetThreadNum(const std::string& rpc_name) {
return rpc_thread_num_[rpc_name];
}
// Wait until all the clients have reached the barrier for one // Wait until all the clients have reached the barrier for one
// rpc method. This function should be called in the // rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a // RequestHandler if you want to run the server/client in a
......
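GetThreadNum just returns whatever count was handed to RegisterRPC, which is how the brpc service earlier in this patch sizes its per-request thread pools. A hedged sketch of the registration side (handler construction omitted; the constants and both methods are the ones shown in this diff, the counts mirror the new flag defaults):

#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"

namespace d = paddle::operators::distributed;

// Illustrative only: register handlers with explicit thread counts.
void RegisterHandlersSketch(d::RPCServer* rpc_service,
                            d::RequestHandler* send_handler,
                            d::RequestHandler* get_handler) {
  rpc_service->RegisterRPC(d::kRequestSend, send_handler, /*thread_num=*/12);
  rpc_service->RegisterRPC(d::kRequestGet, get_handler, /*thread_num=*/12);
  // Later, e.g. inside BRPCServiceImpl, the pool size is read back:
  int send_threads = rpc_service->GetThreadNum(d::kRequestSend);  // 12
  VLOG(3) << "send thread num: " << send_threads;
}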
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
...@@ -45,7 +46,6 @@ static TensorPayload GetCommunicationAllocationFromTensor( ...@@ -45,7 +46,6 @@ static TensorPayload GetCommunicationAllocationFromTensor(
memory::Copy(cuda_pinned, result->ptr(), memory::Copy(cuda_pinned, result->ptr(),
boost::get<platform::CUDAPlace>(tensor.place()), boost::get<platform::CUDAPlace>(tensor.place()),
tensor.data<void>(), copy_size, gpu_dev_ctx.stream()); tensor.data<void>(), copy_size, gpu_dev_ctx.stream());
ctx.Wait(); ctx.Wait();
return TensorPayload(result); return TensorPayload(result);
#else #else
......
...@@ -50,6 +50,13 @@ class TensorPayload final { ...@@ -50,6 +50,13 @@ class TensorPayload final {
size_t memory_size_; size_t memory_size_;
}; };
inline void SerializeDestroyCallback(void* payload) {
if (payload != nullptr) {
auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
delete shared_payload;
}
}
TensorPayload GetTensorPayload(framework::Variable* var, TensorPayload GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
VarMsg* request); VarMsg* request);
......
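Hoisting SerializeDestroyCallback into this header lets both the gRPC and the new brpc serializers hand a heap-allocated TensorPayload to the transport and have it freed only after the bytes are sent. A hedged sketch of that ownership handshake (the transport hookup itself is elided; the function name is illustrative):

#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"

namespace d = paddle::operators::distributed;

// Illustrative only: the serializer allocates the payload on the heap,
// registers SerializeDestroyCallback with the transport, and the transport
// invokes it once payload->ptr()/payload->memory_size() are no longer needed.
void SerializeOwnershipSketch(paddle::framework::Variable* var,
                              const paddle::platform::DeviceContext& ctx,
                              d::VarMsg* request) {
  auto* payload = new d::TensorPayload(d::GetTensorPayload(var, ctx, request));
  // ... hand payload->ptr(), payload->memory_size() and the callback to the
  // transport (a grpc slice or a brpc IOBuf block) ...
  d::SerializeDestroyCallback(payload);  // normally called by the transport
}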
...@@ -2,9 +2,9 @@ include(operators) ...@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
if(WITH_GRPC) if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) set(DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else() else()
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node) set(DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA) if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs) find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
...@@ -26,10 +26,11 @@ limitations under the License. */ ...@@ -26,10 +26,11 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send"); DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get"); DEFINE_int32(rpc_get_thread_num, 12, "number of threads for rpc get");
DEFINE_int32(rpc_prefetch_thread_num, 5, "number of threads for rpc prefetch"); DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -58,7 +58,9 @@ class SendOp : public framework::OperatorBase { ...@@ -58,7 +58,9 @@ class SendOp : public framework::OperatorBase {
} }
if (sync_send) { if (sync_send) {
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
} }
} }
} }
......
...@@ -81,6 +81,14 @@ bool IsCompiledWithCUDA() { ...@@ -81,6 +81,14 @@ bool IsCompiledWithCUDA() {
#endif #endif
} }
bool IsCompiledWithBrpc() {
#if defined(PADDLE_WITH_BRPC) || defined(PADDLE_WITH_BRPC_RDMA)
return true;
#else
return false;
#endif
}
bool IsCompiledWithDIST() { bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
return true; return true;
...@@ -631,6 +639,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -631,6 +639,7 @@ All parameter, weight, gradient are variables in Paddle.
[](bool init_p2p) { framework::InitDevices(init_p2p); }); [](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_cuda", IsCompiledWithCUDA); m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
m.def("is_compiled_with_dist", IsCompiledWithDIST); m.def("is_compiled_with_dist", IsCompiledWithDIST);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool { m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
......
...@@ -152,6 +152,7 @@ def __bootstrap__(): ...@@ -152,6 +152,7 @@ def __bootstrap__():
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'selected_gpus' 'cudnn_exhaustive_search', 'selected_gpus'
] ]
core.init_gflags([sys.argv[0]] + core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)]) ["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0]) core.init_glog(sys.argv[0])
......