Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0b1c7d83
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0b1c7d83
编写于
12月 14, 2018
作者:
G
gongweibao
提交者:
GitHub
12月 14, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add brpc serialization support. (#11430)
上级
37c2e245
变更
28
隐藏空白更改
内联
并排
Showing
28 changed file
with
1422 addition
and
153 deletion
+1422
-153
benchmark/fluid/fluid_benchmark.py
benchmark/fluid/fluid_benchmark.py
+3
-1
cmake/external/brpc.cmake
cmake/external/brpc.cmake
+12
-8
cmake/external/gtest.cmake
cmake/external/gtest.cmake
+7
-3
cmake/external/leveldb.cmake
cmake/external/leveldb.cmake
+2
-2
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+6
-3
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+9
-2
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+3
-3
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+19
-12
paddle/fluid/operators/distributed/brpc_client.cc
paddle/fluid/operators/distributed/brpc_client.cc
+313
-58
paddle/fluid/operators/distributed/brpc_client.h
paddle/fluid/operators/distributed/brpc_client.h
+82
-17
paddle/fluid/operators/distributed/brpc_rdma_pool.cc
paddle/fluid/operators/distributed/brpc_rdma_pool.cc
+84
-0
paddle/fluid/operators/distributed/brpc_rdma_pool.h
paddle/fluid/operators/distributed/brpc_rdma_pool.h
+56
-0
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.cc
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.cc
+196
-0
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h
+49
-0
paddle/fluid/operators/distributed/brpc_serde_test.cc
paddle/fluid/operators/distributed/brpc_serde_test.cc
+175
-0
paddle/fluid/operators/distributed/brpc_server.cc
paddle/fluid/operators/distributed/brpc_server.cc
+235
-29
paddle/fluid/operators/distributed/brpc_variable_response.cc
paddle/fluid/operators/distributed/brpc_variable_response.cc
+73
-0
paddle/fluid/operators/distributed/brpc_variable_response.h
paddle/fluid/operators/distributed/brpc_variable_response.h
+67
-0
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+1
-2
paddle/fluid/operators/distributed/grpc_serde.cc
paddle/fluid/operators/distributed/grpc_serde.cc
+0
-7
paddle/fluid/operators/distributed/rpc_server.h
paddle/fluid/operators/distributed/rpc_server.h
+4
-0
paddle/fluid/operators/distributed/sendrecvop_utils.cc
paddle/fluid/operators/distributed/sendrecvop_utils.cc
+1
-1
paddle/fluid/operators/distributed/sendrecvop_utils.h
paddle/fluid/operators/distributed/sendrecvop_utils.h
+7
-0
paddle/fluid/operators/distributed_ops/CMakeLists.txt
paddle/fluid/operators/distributed_ops/CMakeLists.txt
+2
-2
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+4
-3
paddle/fluid/operators/distributed_ops/send_op.cc
paddle/fluid/operators/distributed_ops/send_op.cc
+2
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+9
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-0
未找到文件。
benchmark/fluid/fluid_benchmark.py
浏览文件 @
0b1c7d83
...
...
@@ -81,9 +81,11 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
# the role, should be either PSERVER or TRAINER
training_role
=
os
.
getenv
(
"PADDLE_TRAINING_ROLE"
)
config
=
distribute_transpiler
.
DistributeTranspilerConfig
()
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
slice_var_up
=
not
args
.
no_split_var
config
.
min_block_size
=
1048576
t
=
distribute_transpiler
.
DistributeTranspiler
(
config
=
config
)
t
.
transpile
(
trainer_id
,
# NOTE: *MUST* use train_prog, for we are using with guard to
...
...
cmake/external/brpc.cmake
浏览文件 @
0b1c7d83
...
...
@@ -14,14 +14,16 @@
INCLUDE
(
ExternalProject
)
find_library
(
SSL_LIBRARY NAMES ssl
)
find_package
(
OpenSSL REQUIRED
)
message
(
STATUS
"ssl:"
${
OPENSSL_SSL_LIBRARY
}
)
message
(
STATUS
"crypto:"
${
OPENSSL_CRYPTO_LIBRARY
}
)
ADD_LIBRARY
(
ssl SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET ssl PROPERTY IMPORTED_LOCATION
${
SSL_LIBRARY
}
)
SET_PROPERTY
(
TARGET ssl PROPERTY IMPORTED_LOCATION
${
OPENSSL_
SSL_LIBRARY
}
)
find_library
(
CRYPTO_LIBRARY NAMES crypto
)
ADD_LIBRARY
(
crypto SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET crypto PROPERTY IMPORTED_LOCATION
${
CRYPTO_LIBRARY
}
)
SET_PROPERTY
(
TARGET crypto PROPERTY IMPORTED_LOCATION
${
OPENSSL_CRYPTO_LIBRARY
}
)
SET
(
BRPC_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/brpc
)
SET
(
BRPC_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/brpc
)
...
...
@@ -31,14 +33,15 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
INCLUDE_DIRECTORIES
(
${
BRPC_INCLUDE_DIR
}
)
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set
(
prefix_path
"
${
THIRD_PARTY_PATH
}
/install/gflags|
${
THIRD_PARTY_PATH
}
/install/leveldb|
${
THIRD_PARTY_PATH
}
/install/snappy|
${
THIRD_PARTY_PATH
}
/install/gtest|
${
THIRD_PARTY_PATH
}
/install/protobuf|
${
THIRD_PARTY_PATH
}
/install/zlib"
)
set
(
prefix_path
"
${
THIRD_PARTY_PATH
}
/install/gflags|
${
THIRD_PARTY_PATH
}
/install/leveldb|
${
THIRD_PARTY_PATH
}
/install/snappy|
${
THIRD_PARTY_PATH
}
/install/gtest|
${
THIRD_PARTY_PATH
}
/install/protobuf|
${
THIRD_PARTY_PATH
}
/install/zlib
|
${
THIRD_PARTY_PATH
}
/install/glog
"
)
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add
(
extern_brpc
${
EXTERNAL_PROJECT_LOG_ARGS
}
# TODO(gongwb): change to de newst repo when they changed.
GIT_REPOSITORY
"https://github.com/gongweibao/brpc"
GIT_TAG
"
7dc04defad1fd4173aae170c3fcbde131b65155a
"
GIT_TAG
"
e9b67ec1b7458f2af5fae76451afe1e27e01b4b4
"
PREFIX
${
BRPC_SOURCES_DIR
}
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
...
...
@@ -50,7 +53,7 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
-DCMAKE_PREFIX_PATH=
${
prefix_path
}
-D
BRPC_
WITH_GLOG=ON
-DWITH_GLOG=ON
-DIOBUF_WITH_HUGE_BLOCK=ON
-DBRPC_WITH_RDMA=
${
WITH_BRPC_RDMA
}
${
EXTERNAL_OPTIONAL_ARGS
}
...
...
@@ -65,5 +68,6 @@ ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
SET_PROPERTY
(
TARGET brpc PROPERTY IMPORTED_LOCATION
${
BRPC_LIBRARIES
}
)
ADD_DEPENDENCIES
(
brpc extern_brpc
)
add_definitions
(
-DBRPC_WITH_GLOG
)
LIST
(
APPEND external_project_dependencies brpc
)
cmake/external/gtest.cmake
浏览文件 @
0b1c7d83
...
...
@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
IF
(
WITH_TESTING
)
ENABLE_TESTING
()
#FIXME:(gongwb) Move brpc's gtest dependency.
IF
(
WITH_TESTING
OR
(
WITH_DISTRIBUTE AND NOT WITH_GRPC
))
IF
(
WITH_TESTING
)
ENABLE_TESTING
()
ENDIF
(
WITH_TESTING
)
INCLUDE
(
ExternalProject
)
SET
(
GTEST_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/gtest
)
...
...
@@ -76,4 +80,4 @@ IF(WITH_TESTING)
ADD_DEPENDENCIES
(
gtest_main extern_gtest
)
LIST
(
APPEND external_project_dependencies gtest gtest_main
)
ENDIF
(
WITH_TESTING
)
ENDIF
(
WITH_TESTING
OR
(
WITH_DISTRIBUTE AND NOT WITH_GRPC
)
)
cmake/external/leveldb.cmake
浏览文件 @
0b1c7d83
...
...
@@ -24,8 +24,8 @@ ExternalProject_Add(
extern_leveldb
${
EXTERNAL_PROJECT_LOG_ARGS
}
PREFIX
${
LEVELDB_SOURCES_DIR
}
URL
"https://github.com/google/leveldb/archive/v1.18.tar.gz
"
URL_MD5
"73770de34a2a5ab34498d2e05b2b7fa0"
GIT_REPOSITORY
"https://github.com/google/leveldb
"
GIT_TAG v1.18
CONFIGURE_COMMAND
""
BUILD_COMMAND CXXFLAGS=-fPIC make -j
${
NUM_OF_PROCESSOR
}
libleveldb.a
INSTALL_COMMAND mkdir -p
${
LEVELDB_INSTALL_DIR
}
/lib/
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
0b1c7d83
...
...
@@ -169,9 +169,12 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
cc_library
(
naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper
)
if
(
WITH_DISTRIBUTE
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc
${
GLOB_DISTRIBUTE_DEPS
}
graph_to_program_pass variable_helper
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
if
(
WITH_NGRAPH
)
if
(
NOT WIN32
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
0b1c7d83
...
...
@@ -12,12 +12,19 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
if
(
WITH_DISTRIBUTE
)
if
(
NOT WITH_GRPC
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
reduce_op_handle.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
endif
()
endif
()
if
(
WITH_GPU
)
nv_library
(
all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor
)
if
(
WITH_DISTRIBUTE
)
nv_library
(
reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor sendrecvop_
g
rpc
)
ddim dynload_cuda selected_rows_functor sendrecvop_rpc
)
else
()
nv_library
(
reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor
)
...
...
@@ -30,7 +37,7 @@ else()
variable_visitor
)
if
(
WITH_DISTRIBUTE
)
cc_library
(
reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_
g
rpc
)
ddim selected_rows_functor sendrecvop_rpc
)
else
()
cc_library
(
reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor
)
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
0b1c7d83
...
...
@@ -157,9 +157,9 @@ void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE
// TODO(typhoonzero): complete message will need to use real trainer_id,
// except 0.
::
paddle
::
operators
::
distributed
::
RPCClient
::
GetInstance
<
::
paddle
::
operators
::
distributed
::
GRPCClient
>
(
0
)
->
SendComplete
();
auto
client
=
paddle
::
operators
::
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
client
->
SendComplete
();
#endif
}
...
...
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
0b1c7d83
...
...
@@ -12,7 +12,7 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
if
(
WITH_GRPC
)
grpc_library
(
sendrecvop_
g
rpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_library
(
sendrecvop_rpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory
)
...
...
@@ -20,36 +20,43 @@ if(WITH_GRPC)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_
g
rpc scope profiler math_function SERIAL
)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_rpc scope profiler math_function SERIAL
)
cc_test
(
rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_
g
rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL
)
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL
)
cc_test
(
varhandle_test SRCS varhandle_test.cc DEPS profiler
)
if
(
WITH_GPU
)
cc_test
(
collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_
g
rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
selected_rows_functor scope math_function SERIAL
)
endif
()
cc_library
(
parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_
g
rpc memory
)
cc_library
(
parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory
)
else
()
set_source_files_properties
(
brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
brpc_server.cc parameter_prefetch.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc collective_server.cc collective_server_test.cc
collective_client.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
brpc_library
(
sendrecvop_
b
rpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
brpc_library
(
sendrecvop_rpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
collective_client.cc collective_server.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory
)
cc_library
(
parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_
b
rpc memory
)
cc_library
(
parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory
)
set
(
brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy
)
set
(
brpc_test_depends sendrecvop_rpc brpc ssl crypto protobuf leveldb gflags glog executor
proto_desc lookup_sparse_table_op snappystream snappy zlib
)
cc_test
(
b
rpc_server_test SRCS rpc_server_test.cc
cc_test
(
rpc_server_test SRCS rpc_server_test.cc
DEPS
${
brpc_test_depends
}
SERIAL
)
cc_test
(
brpc_serde_test SRCS brpc_serde_test.cc
DEPS
${
brpc_test_depends
}
SERIAL
)
if
(
WITH_GPU
)
cc_test
(
collective_server_test SRCS collective_server_test.cc
DEPS
${
brpc_test_depends
}
selected_rows_functor scope math_function SERIAL
)
endif
()
endif
()
paddle/fluid/operators/distributed/brpc_client.cc
浏览文件 @
0b1c7d83
...
...
@@ -14,135 +14,316 @@
#include "paddle/fluid/operators/distributed/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
DEFINE_int32
(
brpc_channel_num
,
24
,
"Number of channels to send requests connected to one server"
);
DEFINE_int32
(
timeout_ms
,
30000
,
"RPC timeout in milliseconds"
);
DEFINE_int32
(
max_retry
,
3
,
"Max retries(not including the first RPC)"
);
BRPCClient
::~
BRPCClient
()
{
Wait
();
}
void
HandleSendResponse
(
brpc
::
Controller
*
cntl
,
sendrecv
::
VoidMessage
*
response
)
{
void
HandleSendResponse
(
brpc
::
Controller
*
cntl
,
sendrecv
::
VoidMessage
*
response
,
VarHandlePtr
var_h
,
ChannelQueuePtr
ch_ptr
,
ChannelContextPtr
ch_ctx
,
BRPCClient
*
cls
)
{
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std
::
unique_ptr
<
brpc
::
Controller
>
cntl_guard
(
cntl
);
std
::
unique_ptr
<
sendrecv
::
VoidMessage
>
response_guard
(
response
);
// this channel can be used by other now.
ch_ptr
->
Push
(
ch_ctx
);
if
(
cntl
->
Failed
())
{
LOG
(
WARNING
)
<<
"Fail to send EchoRequest, "
<<
cntl
->
ErrorText
();
LOG
(
FATAL
)
<<
"Fail to send SendVar: "
<<
var_h
->
name
()
<<
", error text: "
<<
cntl
->
ErrorText
();
var_h
->
Finish
(
false
);
cls
->
DecreaseReqCount
();
return
;
}
LOG
(
INFO
)
<<
"Received response from "
<<
cntl
->
remote_side
()
<<
" latency="
<<
cntl
->
latency_us
()
<<
"us"
;
var_h
->
Finish
(
true
);
cls
->
DecreaseReqCount
();
VLOG
(
4
)
<<
"HandleSendResponse from: "
<<
cntl
->
remote_side
()
<<
", varname: "
<<
var_h
->
name
()
<<
", latency: "
<<
cntl
->
latency_us
()
<<
"us"
;
VLOG
(
4
)
<<
"Finish HandleSendResponse"
;
}
bool
BRPCClient
::
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
VarHandlePtr
BRPCClient
::
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch_ptr
=
GetChannel
(
ep_val
);
const
std
::
string
method
=
"SendRPC"
;
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method
,
var_name_val
,
p_ctx
,
p_scope
));
framework
::
AsyncIO
([
=
]
{
auto
ch_ctx
=
ch_ptr
->
Pop
();
brpc
::
Controller
*
cntl
=
new
brpc
::
Controller
();
sendrecv
::
VoidMessage
*
response
=
new
sendrecv
::
VoidMessage
();
cntl
->
set_timeout_ms
(
time_out
);
framework
::
AsyncIO
(
[
var_name_val
,
p_ctx
,
ep_val
,
p_scope
,
time_out
,
ch_ptr
,
this
]
{
auto
ch_ctx
=
ch_ptr
->
Pop
();
brpc
::
Controller
*
cntl
=
new
brpc
::
Controller
();
sendrecv
::
VoidMessage
*
response
=
new
sendrecv
::
VoidMessage
();
cntl
->
set_timeout_ms
(
time_out
);
auto
*
var
=
p_scope
->
FindVar
(
var_name_val
);
sendrecv
::
VariableMessage
request
;
distributed
::
SerializeToIOBuf
(
var_name_val
,
var
,
*
p_ctx
,
&
request
,
&
cntl
->
request_attachment
(),
""
,
false
,
trainer_id_
);
google
::
protobuf
::
Closure
*
done
=
brpc
::
NewCallback
(
&
HandleSendResponse
,
cntl
,
response
);
google
::
protobuf
::
Closure
*
done
=
brpc
::
NewCallback
(
&
HandleSendResponse
,
cntl
,
response
,
var_h
,
ch_ptr
,
ch_ctx
,
this
);
sendrecv
::
VariableMessage
request
;
ch_ctx
->
stub
->
SendVariable
(
cntl
,
&
request
,
response
,
done
);
});
platform
::
RecordRPCEvent
record_event
(
method
,
p_ctx
);
ch_ctx
->
stub
->
SendVariable
(
cntl
,
&
request
,
response
,
done
);
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
var_h
->
Wait
();
}
});
req_count_
++
;
return
true
;
return
var_h
;
}
void
HandleFetchBarrierResponse
(
brpc
::
Controller
*
cntl
,
sendrecv
::
VariableMessage
*
response
,
VarHandlePtr
var_h
,
ChannelQueuePtr
ch_ptr
,
ChannelContextPtr
ch_ctx
,
BRPCClient
*
cls
)
{
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std
::
unique_ptr
<
brpc
::
Controller
>
cntl_guard
(
cntl
);
std
::
unique_ptr
<
sendrecv
::
VariableMessage
>
response_guard
(
response
);
// this channel can be used other now.
ch_ptr
->
Push
(
ch_ctx
);
if
(
cntl
->
Failed
())
{
LOG
(
FATAL
)
<<
"Fail to get HandleFetchBarrierResponse: "
<<
var_h
->
name
()
<<
", error text: "
<<
cntl
->
ErrorText
();
var_h
->
Finish
(
false
);
cls
->
DecreaseReqCount
();
return
;
}
var_h
->
Finish
(
true
);
cls
->
DecreaseReqCount
();
VLOG
(
4
)
<<
"HandleFetchBarrierResponse from: "
<<
cntl
->
remote_side
()
<<
", varname: "
<<
var_h
->
name
()
<<
", latency: "
<<
cntl
->
latency_us
()
<<
"us"
;
VLOG
(
4
)
<<
"Finish HandleFetchBarrierResponse"
;
}
void
HandleGetResponse
(
brpc
::
Controller
*
cntl
,
sendrecv
::
VariableMessage
*
response
)
{
sendrecv
::
VariableMessage
*
response
,
VarHandlePtr
var_h
,
ChannelQueuePtr
ch_ptr
,
ChannelContextPtr
ch_ctx
,
BRPCClient
*
cls
)
{
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std
::
unique_ptr
<
brpc
::
Controller
>
cntl_guard
(
cntl
);
std
::
unique_ptr
<
sendrecv
::
VariableMessage
>
response_guard
(
response
);
// this channel can be used other now.
ch_ptr
->
Push
(
ch_ctx
);
if
(
cntl
->
Failed
())
{
LOG
(
WARNING
)
<<
"Fail to send EchoRequest, "
<<
cntl
->
ErrorText
();
LOG
(
FATAL
)
<<
"Fail to GetVar: "
<<
var_h
->
name
()
<<
", error text: "
<<
cntl
->
ErrorText
();
cls
->
DecreaseReqCount
();
var_h
->
Finish
(
false
);
return
;
}
LOG
(
INFO
)
<<
"Received response from "
<<
cntl
->
remote_side
()
<<
" latency="
<<
cntl
->
latency_us
()
<<
"us"
;
// framework::Variable* outvar = nullptr;
// DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
VLOG
(
4
)
<<
"HandleGetResponse from: "
<<
cntl
->
remote_side
()
<<
", varname: "
<<
var_h
->
name
()
<<
", latency: "
<<
cntl
->
latency_us
()
<<
"us"
;
framework
::
Variable
*
outvar
=
nullptr
;
int
trainer_id
;
distributed
::
DeserializeFromIOBuf
(
*
response
,
cntl
->
response_attachment
(),
*
var_h
->
ctx
(),
var_h
->
scope
(),
&
outvar
,
&
trainer_id
);
VLOG
(
4
)
<<
"Finish HandleGetResponse"
;
cls
->
DecreaseReqCount
();
var_h
->
Finish
(
true
);
}
bool
BRPCClient
::
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
VarHandlePtr
BRPCClient
::
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
method_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
const
auto
ch_ptr
=
GetChannel
(
ep_val
);
const
std
::
string
method
=
"GetRPC"
;
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method
,
var_name_val
,
p_ctx
,
p_scope
));
framework
::
AsyncIO
([
=
]
{
auto
ch_ctx
=
ch_ptr
->
Pop
();
brpc
::
Controller
*
cntl
=
new
brpc
::
Controller
();
sendrecv
::
VariableMessage
*
response
=
new
sendrecv
::
VariableMessage
();
cntl
->
set_timeout_ms
(
time_out
);
framework
::
AsyncIO
(
[
var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
this
]
{});
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name_val
);
req
.
set_trainer_id
(
trainer_id_
);
google
::
protobuf
::
Closure
*
done
=
brpc
::
NewCallback
(
&
HandleGetResponse
,
cntl
,
response
,
var_h
,
ch_ptr
,
ch_ctx
,
this
);
platform
::
RecordRPCEvent
record_event
(
method
,
p_ctx
);
if
(
method_name
==
"GetMonomerVariable"
)
{
ch_ctx
->
stub
->
GetMonomerVariable
(
cntl
,
&
req
,
response
,
done
);
}
else
{
ch_ctx
->
stub
->
GetVariable
(
cntl
,
&
req
,
response
,
done
);
}
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
var_h
->
Wait
();
}
});
req_count_
++
;
return
true
;
return
var_h
;
}
VarHandlePtr
BRPCClient
::
AsyncGetMonomerVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
"GetMonomerVariable"
,
time_out
);
}
VarHandlePtr
BRPCClient
::
AsyncGetMonomerBarrier
(
const
std
::
string
&
ep
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
return
AsyncSendMessage
(
ep
,
"GetMonomerBarrier"
,
var_name
,
time_out
);
}
bool
BRPCClient
::
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
)
{
VarHandlePtr
BRPCClient
::
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
"GetVariable"
,
time_out
);
}
VarHandlePtr
BRPCClient
::
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
in_var_name_val
=
in_var_name
;
const
std
::
string
out_var_name_val
=
out_var_name
;
const
std
::
string
table_name_val
=
table_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
const
auto
ch_ptr
=
GetChannel
(
ep_val
);
const
std
::
string
method
=
"PrefetchRPC"
;
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method
,
out_var_name_val
,
p_ctx
,
p_scope
));
framework
::
AsyncIO
([
=
]
{
auto
ch_ctx
=
ch_ptr
->
Pop
();
brpc
::
Controller
*
cntl
=
new
brpc
::
Controller
();
sendrecv
::
VariableMessage
*
response
=
new
sendrecv
::
VariableMessage
();
cntl
->
set_timeout_ms
(
time_out
);
auto
*
var
=
p_scope
->
FindVar
(
in_var_name_val
);
sendrecv
::
VariableMessage
req
;
distributed
::
SerializeToIOBuf
(
in_var_name_val
,
var
,
*
p_ctx
,
&
req
,
&
cntl
->
request_attachment
(),
out_var_name_val
,
false
,
0
,
table_name_val
);
platform
::
RecordRPCEvent
record_event
(
method
,
p_ctx
);
google
::
protobuf
::
Closure
*
done
=
brpc
::
NewCallback
(
&
HandleGetResponse
,
cntl
,
response
,
var_h
,
ch_ptr
,
ch_ctx
,
this
);
framework
::
AsyncIO
([
in_var_name_val
,
out_var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
this
]
{});
ch_ctx
->
stub
->
PrefetchVariable
(
cntl
,
&
req
,
response
,
done
);
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
var_h
->
Wait
();
}
});
req_count_
++
;
return
true
;
return
var_h
;
}
void
BRPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
req_count_
++
;
VarHandlePtr
BRPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
return
AsyncSendMessage
(
ep
,
"BatchBarrierRPC"
,
BATCH_BARRIER_MESSAGE
,
time_out
);
}
void
BRPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
VarHandlePtr
BRPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
auto
ch_ptr
=
GetChannel
(
ep
);
auto
ch_ctx
=
ch_ptr
->
Pop
();
brpc
::
Controller
*
cntl
=
new
brpc
::
Controller
();
sendrecv
::
VariableMessage
*
response
=
new
sendrecv
::
VariableMessage
();
cntl
->
set_timeout_ms
(
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
FETCH_BARRIER_MESSAGE
);
const
std
::
string
method
=
"FetchBarrierRPC"
;
// var handle
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method
,
FETCH_BARRIER_MESSAGE
,
nullptr
,
nullptr
));
platform
::
RecordRPCEvent
record_event
(
method
,
nullptr
);
google
::
protobuf
::
Closure
*
done
=
brpc
::
NewCallback
(
&
HandleFetchBarrierResponse
,
cntl
,
response
,
var_h
,
ch_ptr
,
ch_ctx
,
this
);
ch_ctx
->
stub
->
GetVariable
(
cntl
,
&
req
,
response
,
done
);
req_count_
++
;
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
var_h
->
Wait
();
}
return
var_h
;
}
void
BRPCClient
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
bool
BRPCClient
::
Wait
()
{
VLOG
(
9
)
<<
"begin to brpcclient wait"
;
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
}
VLOG
(
9
)
<<
"end to brpcclient wait"
;
return
true
;
}
ChannelQueuePtr
BRPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
VLOG
(
4
)
<<
"begin to GetChannel:"
<<
ep
;
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_mutex_
);
auto
it
=
channels_
.
find
(
ep
);
if
(
it
!=
channels_
.
end
())
{
VLOG
(
4
)
<<
"end to GetChannel:"
<<
ep
;
return
it
->
second
;
}
}
...
...
@@ -150,12 +331,20 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
ChannelQueuePtr
q
(
new
framework
::
BlockingQueue
<
ChannelContextPtr
>
());
brpc
::
ChannelOptions
options
;
#ifdef PADDLE_WITH_BRPC_RDMA
options
.
use_rdma
=
true
;
#endif
options
.
protocol
=
"baidu_std"
;
options
.
connection_type
=
"pooled"
;
options
.
connect_timeout_ms
=
100
;
// don't use pooled type. the server can't afford that.
options
.
connection_type
=
"single"
;
options
.
connect_timeout_ms
=
1000
;
options
.
timeout_ms
=
FLAGS_timeout_ms
/*milliseconds*/
;
options
.
max_retry
=
FLAGS_max_retry
;
for
(
int
i
=
0
;
i
<
FLAGS_brpc_channel_num
;
++
i
)
{
VLOG
(
1
)
<<
"create "
<<
brpc_channel_num_per_server_
<<
" brpc channels to pserver:"
<<
ep
;
for
(
int
i
=
0
;
i
<
brpc_channel_num_per_server_
;
++
i
)
{
std
::
shared_ptr
<
ChannelContext
>
c
(
new
ChannelContext
());
if
(
c
->
channel
.
Init
(
ep
.
c_str
(),
&
options
)
!=
0
)
{
LOG
(
FATAL
)
<<
"Fail to initialize channel"
;
...
...
@@ -172,9 +361,75 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
channels_
[
ep
]
=
q
;
}
VLOG
(
4
)
<<
"end to GetChannel:"
<<
ep
;
return
q
;
}
VarHandlePtr
BRPCClient
::
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
return
AsyncSendMessage
(
ep
,
"SendCompleteRPC"
,
COMPLETE_MESSAGE
,
time_out
);
}
void
BRPCClient
::
SendComplete
()
{
for
(
auto
&
kv
:
channels_
)
{
AsyncSendComplete
(
kv
.
first
);
}
}
VarHandlePtr
BRPCClient
::
AsyncSendVarMessage
(
const
std
::
string
&
ep
,
const
std
::
string
&
method_name
,
const
sendrecv
::
VariableMessage
&
req
,
int64_t
time_out
)
{
auto
ch_ptr
=
GetChannel
(
ep
);
auto
ch_ctx
=
ch_ptr
->
Pop
();
brpc
::
Controller
*
cntl
=
new
brpc
::
Controller
();
sendrecv
::
VoidMessage
*
response
=
new
sendrecv
::
VoidMessage
();
cntl
->
set_timeout_ms
(
time_out
);
platform
::
RecordRPCEvent
record_event
(
method_name
,
nullptr
);
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method_name
,
req
.
varname
(),
nullptr
,
nullptr
));
google
::
protobuf
::
Closure
*
done
=
brpc
::
NewCallback
(
&
HandleSendResponse
,
cntl
,
response
,
var_h
,
ch_ptr
,
ch_ctx
,
this
);
if
(
method_name
==
"CheckPointNotifyRPC"
)
{
ch_ctx
->
stub
->
CheckpointNotify
(
cntl
,
&
req
,
response
,
done
);
}
else
if
(
method_name
==
"GetMonomerBarrier"
)
{
ch_ctx
->
stub
->
GetMonomerBarrier
(
cntl
,
&
req
,
response
,
done
);
}
else
{
ch_ctx
->
stub
->
SendVariable
(
cntl
,
&
req
,
response
,
done
);
}
req_count_
++
;
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
var_h
->
Wait
();
}
return
var_h
;
}
VarHandlePtr
BRPCClient
::
AsyncSendMessage
(
const
std
::
string
&
ep
,
const
std
::
string
&
method_name
,
const
std
::
string
&
message
,
int64_t
time_out
)
{
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
message
);
return
AsyncSendVarMessage
(
ep
,
method_name
,
req
,
time_out
);
}
VarHandlePtr
BRPCClient
::
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
int64_t
time_out
)
{
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
CHECKPOINT_SAVE_MESSAGE
);
req
.
set_out_varname
(
dir
);
return
AsyncSendVarMessage
(
ep
,
"CheckPointNotifyRPC"
,
req
,
time_out
);
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/brpc_client.h
浏览文件 @
0b1c7d83
...
...
@@ -31,6 +31,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
...
...
@@ -53,33 +55,94 @@ class BRPCClient : public RPCClient {
BRPCClient
()
{}
virtual
~
BRPCClient
();
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
bool
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
bool
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetMonomerBarrier
(
const
std
::
string
&
ep
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
void
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetMonomerVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
void
Wait
()
override
;
VarHandlePtr
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
bool
Wait
()
override
;
void
SendComplete
()
override
;
private:
VarHandlePtr
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
method_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
void
Proceed
();
ChannelQueuePtr
GetChannel
(
const
std
::
string
&
ep
);
VarHandlePtr
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
VarHandlePtr
AsyncSendMessage
(
const
std
::
string
&
ep
,
const
std
::
string
&
method_name
,
const
std
::
string
&
message
,
int64_t
time_out
);
VarHandlePtr
AsyncSendVarMessage
(
const
std
::
string
&
ep
,
const
std
::
string
&
method_name
,
const
sendrecv
::
VariableMessage
&
req
,
int64_t
time_out
);
friend
void
HandleSendResponse
(
brpc
::
Controller
*
cntl
,
sendrecv
::
VoidMessage
*
response
,
VarHandlePtr
var_h
,
ChannelQueuePtr
ch_ptr
,
ChannelContextPtr
ch_ctx
,
BRPCClient
*
cls
);
friend
void
HandleGetResponse
(
brpc
::
Controller
*
cntl
,
sendrecv
::
VariableMessage
*
response
,
VarHandlePtr
var_h
,
ChannelQueuePtr
ch_ptr
,
ChannelContextPtr
ch_ctx
,
BRPCClient
*
cls
);
friend
void
HandleFetchBarrierResponse
(
brpc
::
Controller
*
cntl
,
sendrecv
::
VariableMessage
*
response
,
VarHandlePtr
var_h
,
ChannelQueuePtr
ch_ptr
,
ChannelContextPtr
ch_ctx
,
BRPCClient
*
cls
);
// Atomically decrements the outstanding-request counter and, once it is no
// longer positive, wakes every thread waiting on sync_cond_.
void DecreaseReqCount() {
  const int64_t remaining = --req_count_;
  if (remaining > 0) {
    return;
  }
  sync_cond_.notify_all();
}
private:
std
::
unordered_map
<
std
::
string
,
ChannelQueuePtr
>
channels_
;
...
...
@@ -88,6 +151,8 @@ class BRPCClient : public RPCClient {
std
::
condition_variable
sync_cond_
;
std
::
atomic
<
int64_t
>
req_count_
{
0
};
static
constexpr
int
brpc_channel_num_per_server_
=
4
;
// mutex for GetChannel thread safety
std
::
mutex
chan_mutex_
;
DISABLE_COPY_AND_ASSIGN
(
BRPCClient
);
...
...
paddle/fluid/operators/distributed/brpc_rdma_pool.cc
0 → 100644
浏览文件 @
0b1c7d83
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BRPC_RDMA
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "brpc/channel.h"
#include "brpc/rdma/rdma_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
// Returns the process-wide pool.
//
// The instance is created on first use and stored in a function-local static
// pointer; it is never deleted, so its destructor does not run at exit.
RdmaMemPool& RdmaMemPool::Instance() {
  static RdmaMemPool* const pool = new RdmaMemPool();
  return *pool;
}
void
*
RdmaMemPool
::
Find
(
const
std
::
string
&
varname
,
int64_t
size
)
{
pthread_rwlock_rdlock
(
&
access_
);
auto
it
=
pool_
.
find
(
varname
);
if
(
it
==
pool_
.
end
())
{
pthread_rwlock_unlock
(
&
access_
);
return
nullptr
;
}
auto
info
=
it
->
second
;
if
(
info
.
data_size
!=
size
)
{
pthread_rwlock_unlock
(
&
access_
);
PADDLE_ENFORCE
(
false
,
"var:%s size:%ld != %ld"
,
varname
,
size
,
info
.
data_size
);
return
nullptr
;
}
pthread_rwlock_unlock
(
&
access_
);
return
info
.
data
;
}
void
RdmaMemPool
::
Register
(
const
std
::
string
&
varname
,
void
*
data
,
int64_t
data_size
)
{
void
*
old
=
Find
(
varname
,
data_size
);
if
(
old
!=
nullptr
)
{
if
(
data
!=
old
)
{
PADDLE_ENFORCE
(
false
,
"var:%s data:%ld != %ld"
,
varname
,
data
,
old
);
}
VLOG
(
7
)
<<
"Find on rdma:"
<<
varname
<<
" data:"
<<
data
<<
" data_size:"
<<
data_size
;
return
;
}
VarInfo
info
;
info
.
data
=
data
;
info
.
data_size
=
data_size
;
pthread_rwlock_wrlock
(
&
access_
);
pool_
[
varname
]
=
info
;
pthread_rwlock_unlock
(
&
access_
);
if
(
brpc
::
rdma
::
RegisterMemoryForRdma
(
data
,
data_size
))
{
LOG
(
FATAL
)
<<
"register "
<<
varname
<<
" data:"
<<
data
<<
" data_size:"
<<
data_size
<<
" error"
;
}
VLOG
(
4
)
<<
"register on rdma:"
<<
varname
<<
" data:"
<<
data
<<
" data_size:"
<<
data_size
;
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/operators/distributed/brpc_rdma_pool.h
0 → 100644
浏览文件 @
0b1c7d83
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_BRPC_RDMA
#include <pthread.h> // NOLINT
#include <string>
#include <unordered_map>
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
/*
 * Registry of buffers that have been registered with brpc::rdma, keyed by
 * variable name. It exists to avoid duplicated registration of the same
 * buffer. Access to the map is guarded by a pthread read-write lock.
 */
class RdmaMemPool {
 public:
  // Returns the shared pool instance.
  static RdmaMemPool& Instance();

  // The lock is set up with pthread_rwlock_init() rather than
  // PTHREAD_RWLOCK_INITIALIZER: POSIX defines the macro only for statically
  // allocated locks, while pool objects may live on the heap.
  RdmaMemPool() { pthread_rwlock_init(&access_, nullptr); }

  virtual ~RdmaMemPool() { pthread_rwlock_destroy(&access_); }

  // Copying would duplicate the raw pthread_rwlock_t, which is not a valid
  // way to create a lock, so the pool is non-copyable.
  RdmaMemPool(const RdmaMemPool&) = delete;
  RdmaMemPool& operator=(const RdmaMemPool&) = delete;

  // Records (varname -> data, size) and registers the buffer for RDMA if the
  // name has not been seen before.
  void Register(const std::string& varname, void* data, int64_t size);

  // Returns the buffer recorded for `varname`, or nullptr if there is none.
  void* Find(const std::string& varname, int64_t size);

 private:
  // One recorded buffer: its address and its size in bytes.
  struct VarInfo {
    void* data;
    int64_t data_size;

    VarInfo() : data(nullptr), data_size(0) {}
  };

 private:
  std::unordered_map<std::string, VarInfo> pool_;
  pthread_rwlock_t access_;
};
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.cc
0 → 100644
浏览文件 @
0b1c7d83
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
// Writes framed records into a butil::IOBuf.
//
// Every record is laid out as: the 4 raw bytes of the int key `k`, the 8 raw
// bytes of the int64_t length `vlen` (both in host byte order), followed by
// `vlen` bytes of payload copied from `v`.
//
// The "zero copy" entry points currently copy the payload as well (see the
// FIXME notes) and then invoke `destroy(user_data)` so the caller's buffer
// owner is released as soon as the bytes have been copied.
class IOBufWriter {
 public:
  // Appends one record, copying `vlen` bytes from `v`.
  static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
    AppendHeader(iobuf, k, vlen);
    iobuf->append(v, vlen);
  }

  // TCP flavour of AppendZeroCopy. `destroy` must be callable; it is invoked
  // with `user_data` after the payload has been copied into `iobuf`.
  static void AppendTCPZeroCopy(butil::IOBuf* iobuf, int k, const char* v,
                                int64_t vlen, bool in_cuda_pinned,
                                void (*destroy)(void*), void* user_data) {
    VLOG(7) << "AppendTCPZeroCopy "
            << " k:" << k
            << " data:" << static_cast<void*>(const_cast<char*>(v))
            << " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;

    AppendHeader(iobuf, k, vlen);

    // FIXME(gongwb): use append_zerocopy
    /*
    if (in_cuda_pinned) {
      iobuf->append_zerocopy(v, vlen, IOBufWriter::FreeMemory);
    } else {
      iobuf->append_zerocopy(v, vlen, nullptr);
    }
    */
    iobuf->append(v, vlen);
    destroy(user_data);
  }

#ifdef PADDLE_WITH_BRPC_RDMA
  // RDMA flavour of AppendZeroCopy: additionally records/registers the buffer
  // in RdmaMemPool under `varname` before copying it into `iobuf`.
  static void AppendRdmaZeroCopy(const std::string& varname,
                                 butil::IOBuf* iobuf, int k, const char* v,
                                 int64_t vlen, bool in_cuda_pinned,
                                 void (*destroy)(void*), void* user_data) {
    VLOG(7) << "AppendRdmaZeroCopy varname:" << varname << " k:" << k
            << " data:" << static_cast<void*>(const_cast<char*>(v))
            << " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;

    AppendHeader(iobuf, k, vlen);

    RdmaMemPool::Instance().Register(
        varname, static_cast<void*>(const_cast<char*>(v)), vlen);

    // FIXME(gongwb): use append_zerocopy
    // iobuf->append_zerocopy(v, vlen, nullptr);
    iobuf->append(v, vlen);
    destroy(user_data);
  }
#endif

  // Dispatches to the RDMA or TCP implementation depending on whether
  // PADDLE_WITH_BRPC_RDMA is defined. `varname` is only used by the RDMA path.
  static void AppendZeroCopy(const std::string& varname, butil::IOBuf* iobuf,
                             int k, const char* v, int64_t vlen,
                             bool in_cuda_pinned, void (*destroy)(void*),
                             void* user_data) {
#ifdef PADDLE_WITH_BRPC_RDMA
    IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
                                    destroy, user_data);
#else
    IOBufWriter::AppendTCPZeroCopy(iobuf, k, v, vlen, in_cuda_pinned, destroy,
                                   user_data);
#endif
  }

 private:
  // Writes the fixed-size record header: 4 bytes of `k`, then 8 bytes of
  // `vlen`. The byte counts are part of the wire format, hence the asserts.
  static void AppendHeader(butil::IOBuf* iobuf, int k, int64_t vlen) {
    static_assert(sizeof(int) == 4, "record key must be 4 bytes");
    static_assert(sizeof(int64_t) == 8, "record length must be 8 bytes");
    iobuf->append(reinterpret_cast<char*>(&k), 4);
    iobuf->append(reinterpret_cast<char*>(&vlen), 8);
  }
};
// Serializes variable `var` for sending over brpc.
//
// Metadata (name, trainer id, optional profiling state, optional output
// variable name and table name, variable type) is written into `request`;
// the raw bytes are appended to `iobuf` as IOBufWriter records:
//   - LoDTensor / SelectedRows: the tensor payload under
//     kSerializedFieldNumber; SelectedRows additionally appends its row ids
//     under kRowsFieldNumber.
//   - ncclUniqueId (CUDA builds only): the id bytes under
//     kSerializedFieldNumber, after which the function returns.
// Any other variable type raises via PADDLE_THROW.
//
// When `var_is_not_stable` is true the payload is appended with a plain copy
// and stays owned by the local unique_ptr. Otherwise it goes through
// IOBufWriter::AppendZeroCopy, which receives SerializeDestroyCallback and the
// raw payload pointer, and the unique_ptr gives up ownership right after.
void SerializeToIOBuf(const std::string& name, framework::Variable* var,
                      const platform::DeviceContext& ctx, VarMsg* request,
                      butil::IOBuf* iobuf, const std::string& out_varname,
                      bool var_is_not_stable, int trainer_id,
                      const std::string& table_name) {
  const int kSerializedField =
      ::sendrecv::VariableMessage::kSerializedFieldNumber;

  std::unique_ptr<TensorPayload> payload;

  request->set_varname(name);
  request->set_trainer_id(trainer_id);

  // Note: normally the profiler is enabled in 1 trainer, hence only
  // 1 trainer returns true for ShouldSendProfileState(). It tells PS
  // servers the trainer's profiling state so that PS can follow the
  // trainer.
  if (platform::ShouldSendProfileState()) {
    request->set_profile(platform::IsProfileEnabled()
                             ? platform::kEnableProfiler
                             : platform::kDisableProfiler);
  }

  if (!out_varname.empty()) {
    request->set_out_varname(out_varname);
  }
  if (!table_name.empty()) {
    request->set_table_name(table_name);
  }

  if (var->IsType<framework::LoDTensor>()) {
    request->set_type(::sendrecv::LOD_TENSOR);
    payload.reset(new TensorPayload(GetTensorPayload(var, ctx, request)));
  } else if (var->IsType<framework::SelectedRows>()) {
    request->set_type(::sendrecv::SELECTED_ROWS);
    payload.reset(
        new TensorPayload(GetSelectedRowsPayload(var, ctx, request)));
  }
#ifdef PADDLE_WITH_CUDA
  else if (var->IsType<ncclUniqueId>()) {  // NOLINT
    request->set_type(::sendrecv::NCCL_ID);
    const ncclUniqueId& uid = var->Get<ncclUniqueId>();
    // TODO(gongwb): use append_zero to avoid data copy.
    IOBufWriter::Append(iobuf, kSerializedField, uid.internal,
                        NCCL_UNIQUE_ID_BYTES);
    return;
  }
#endif
  else {  // NOLINT
    // NOTE(review): typeid() is applied to the value returned by Type(), so
    // the printed name may not be the held variable type — confirm.
    PADDLE_THROW("Serialize does not support type: %s",
                 typeid(var->Type()).name());
  }

  PADDLE_ENFORCE_NOT_NULL(payload);

  // FIXME(gongwb): it seems that can use zero copy.
  if (var_is_not_stable) {
    IOBufWriter::Append(iobuf, kSerializedField,
                        static_cast<const char*>(payload->ptr()),
                        payload->memory_size());
  } else if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
    IOBufWriter::AppendZeroCopy(name, iobuf, kSerializedField,
                                static_cast<const char*>(payload->ptr()),
                                payload->memory_size(), true,
                                SerializeDestroyCallback,
                                static_cast<void*>(payload.get()));
    payload.release();
#endif
  } else {
    IOBufWriter::AppendZeroCopy(name, iobuf, kSerializedField,
                                static_cast<const char*>(payload->ptr()),
                                payload->memory_size(), false,
                                SerializeDestroyCallback,
                                static_cast<void*>(payload.get()));
    payload.release();
  }

  if (!var->IsType<framework::SelectedRows>()) {
    return;
  }

  // SelectedRows also ships its row index vector as a separate record.
  auto* selected_rows = var->GetMutable<framework::SelectedRows>();
  size_t rows_memory_size =
      selected_rows->rows().size() * framework::SizeOfType(typeid(int64_t));
  IOBufWriter::Append(
      iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
      reinterpret_cast<const char*>(selected_rows->rows().data()),
      static_cast<int64_t>(rows_memory_size));
}
// Parses `iobuf` (described by `meta`) into a variable inside `scope`.
//
// On success `*var` receives the variable produced by the response parser and
// `*trainer_id` the trainer id it reports. A non-zero parse result fails the
// PADDLE_ENFORCE below.
void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta,
                          const butil::IOBuf& iobuf,
                          const platform::DeviceContext& ctx,
                          const framework::Scope* scope,
                          framework::Variable** var, int* trainer_id) {
  operators::distributed::BRPCVariableResponse parser(scope, &ctx);

  PADDLE_ENFORCE(parser.Parse(iobuf, meta) == 0,
                 "parse iobuf to tensor error!");

  *var = parser.GetVar();
  *trainer_id = parser.GetTrainerId();
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h
0 → 100644
浏览文件 @
0b1c7d83
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
void
SerializeToIOBuf
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
,
butil
::
IOBuf
*
iobuf
,
const
std
::
string
&
out_varname
,
bool
var_is_not_stable
,
const
int
trainer_id
=
0
,
const
std
::
string
&
table_name
=
std
::
string
());
void
DeserializeFromIOBuf
(
const
VarMsg
&
meta
,
const
butil
::
IOBuf
&
iobuf
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
,
int
*
trainer_id
);
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/brpc_serde_test.cc
0 → 100644
浏览文件 @
0b1c7d83
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "brpc/channel.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
operators
=
paddle
::
operators
;
namespace
math
=
paddle
::
operators
::
math
;
namespace
memory
=
paddle
::
memory
;
// Round-trips a SelectedRows variable (564 x 128 floats, rows 0..563,
// height 1000) through SerializeToIOBuf and BRPCVariableResponse::Parse on
// `place`, then checks the values, the row ids and the height.
void RunSerdeTestSelectedRows(platform::Place place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  butil::IOBuf iobuf;
  sendrecv::VariableMessage msg;
  int tensor_numel = 564 * 128;

  // serialize var to IOBuf
  {
    framework::Variable var;
    auto* slr = var.GetMutable<framework::SelectedRows>();
    slr->set_height(1000);
    auto* tensor = slr->mutable_value();
    auto* rows = slr->mutable_rows();
    tensor->Resize(framework::make_ddim({564, 128}));
    tensor->mutable_data<float>(place);
    math::set_constant(ctx, tensor, 32.7);
    for (int i = 0; i < 564; ++i) rows->push_back(i);

    operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
                                             "", false);
  }

  // deserialize
  {
    framework::Scope scope;
    scope.Var("myvar");
    operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
    EXPECT_EQ(resp.Parse(iobuf, msg), 0);

    framework::Variable* var2 = resp.GetVar();

    auto* slr2 = var2->GetMutable<framework::SelectedRows>();
    auto* tensor2 = slr2->mutable_value();
    auto* rows2 = slr2->mutable_rows();

    // Read the values from host memory: GPU tensors are first copied into a
    // temporary CPU tensor.
    float* tensor_data2 = nullptr;
    framework::Tensor tmp_tensor;
    if (platform::is_gpu_place(ctx.GetPlace())) {
      platform::CPUPlace cpu;
      framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
      tensor_data2 = tmp_tensor.data<float>();
    } else {
      tensor_data2 = const_cast<float*>(tensor2->data<float>());
    }
    const int64_t* rows_data2 = rows2->data();

    for (int i = 0; i < tensor_numel; ++i) {
      EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
    }

    // Without this check the loop below would pass vacuously if no rows were
    // deserialized at all.
    EXPECT_EQ(rows2->size(), static_cast<size_t>(564));
    for (size_t i = 0; i < rows2->size(); ++i) {
      EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
    }
    EXPECT_EQ(slr2->height(), 1000);
  }
}
// Round-trips a LoDTensor (512 x 8 x 4 x 2 floats, one LoD level {1, 3, 8})
// through SerializeToIOBuf and BRPCVariableResponse::Parse on `place`,
// checking both the metadata written to the message and the tensor values.
void RunTestLodTensor(platform::Place place) {
  platform::DeviceContextPool& ctx_pool =
      platform::DeviceContextPool::Instance();
  auto& ctx = *ctx_pool.Get(place);

  // serialize var to IOBuf
  butil::IOBuf iobuf;
  sendrecv::VariableMessage msg;
  int tensor_numel = 512 * 8 * 4 * 2;
  {
    framework::Variable src_var;
    auto* src_tensor = src_var.GetMutable<framework::LoDTensor>();
    src_tensor->Resize(framework::make_ddim({512, 8, 4, 2}));

    framework::LoD lod;
    lod.push_back(framework::Vector<size_t>({1, 3, 8}));
    src_tensor->set_lod(lod);

    src_tensor->mutable_data<float>(place);
    math::set_constant(ctx, src_tensor, 31.9);

    operators::distributed::SerializeToIOBuf("myvar", &src_var, ctx, &msg,
                                             &iobuf, "", false);
  }

  // check sendrecv::VariableMessage meta data
  {
    EXPECT_EQ(msg.varname(), "myvar");
    EXPECT_EQ(msg.type(), 0);
    EXPECT_EQ(msg.dims()[0], 512);
    EXPECT_EQ(msg.dims()[1], 8);
    EXPECT_EQ(msg.dims()[2], 4);
    EXPECT_EQ(msg.dims()[3], 2);
    EXPECT_EQ(msg.lod_level(), 1);
    EXPECT_EQ(msg.lod(0).lod_data(0), 1);
    EXPECT_EQ(msg.lod(0).lod_data(1), 3);
    EXPECT_EQ(msg.lod(0).lod_data(2), 8);
  }

  // deserialize
  {
    framework::Scope scope;
    scope.Var("myvar");
    operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
    EXPECT_EQ(resp.Parse(iobuf, msg), 0);

    framework::Variable* var2 = resp.GetVar();
    auto tensor2 = var2->Get<framework::LoDTensor>();

    // Values are compared on the host: GPU tensors are copied into a
    // temporary CPU tensor first.
    float* tensor_data2 = nullptr;
    framework::Tensor host_copy;
    if (platform::is_gpu_place(ctx.GetPlace())) {
      platform::CPUPlace cpu;
      framework::TensorCopy(tensor2, cpu, &host_copy);
      tensor_data2 = host_copy.data<float>();
    } else {
      tensor_data2 = const_cast<float*>(tensor2.data<float>());
    }

    for (int idx = 0; idx < tensor_numel; ++idx) {
      EXPECT_FLOAT_EQ(tensor_data2[idx], 31.9);
    }
  }
}
// LoDTensor round trip on CPU, and on CUDA device 0 when built with CUDA.
TEST(LodTensor, Run) {
  RunTestLodTensor(platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
  RunTestLodTensor(platform::CUDAPlace(0));
#endif
}
// SelectedRows round trip on CPU, and on a default-constructed CUDAPlace when
// built with CUDA.
TEST(SelectedRows, Run) {
  RunSerdeTestSelectedRows(platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
  RunSerdeTestSelectedRows(platform::CUDAPlace());
#endif
}
paddle/fluid/operators/distributed/brpc_server.cc
浏览文件 @
0b1c7d83
...
...
@@ -13,84 +13,287 @@
// limitations under the License.
#include "paddle/fluid/operators/distributed/brpc_server.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
namespace
sendrecv
{
typedef
std
::
unordered_map
<
std
::
string
,
paddle
::
operators
::
distributed
::
RequestHandler
*>
namespace
distributed
=
paddle
::
operators
::
distributed
;
typedef
std
::
unordered_map
<
std
::
string
,
distributed
::
RequestHandler
*>
HandlerMap
;
class
BRPCServiceImpl
:
public
SendRecvService
{
public:
explicit
BRPCServiceImpl
(
const
HandlerMap
&
rpc_call_map
)
:
request_send_h_
(
nullptr
),
request_get_h_
(
nullptr
),
request_prefetch_h_
(
nullptr
)
{
auto
it
=
rpc_call_map
.
find
(
paddle
::
operators
::
distributed
::
kRequestSend
);
explicit
BRPCServiceImpl
(
const
HandlerMap
&
rpc_call_map
,
distributed
::
RPCServer
*
rpc_server
)
:
rpc_server_
(
rpc_server
)
{
VLOG
(
3
)
<<
"BRPCServiceImpl size: "
<<
rpc_call_map
.
size
();
auto
it
=
rpc_call_map
.
find
(
distributed
::
kRequestSend
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_send_h_
=
it
->
second
;
send_threads_
.
reset
(
new
paddle
::
framework
::
ThreadPool
(
rpc_server_
->
GetThreadNum
(
distributed
::
kRequestSend
)));
}
it
=
rpc_call_map
.
find
(
paddle
::
operators
::
distributed
::
kRequestSend
);
it
=
rpc_call_map
.
find
(
distributed
::
kRequestGet
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_get_h_
=
it
->
second
;
get_threads_
.
reset
(
new
paddle
::
framework
::
ThreadPool
(
rpc_server_
->
GetThreadNum
(
distributed
::
kRequestGet
)));
}
it
=
rpc_call_map
.
find
(
paddle
::
operators
::
distributed
::
kRequestPrefetch
);
it
=
rpc_call_map
.
find
(
distributed
::
kRequestPrefetch
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_prefetch_h_
=
it
->
second
;
prefetch_threads_
.
reset
(
new
paddle
::
framework
::
ThreadPool
(
rpc_server_
->
GetThreadNum
(
distributed
::
kRequestPrefetch
)));
}
it
=
rpc_call_map
.
find
(
distributed
::
kRequestCheckpoint
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_checkpoint_h_
=
it
->
second
;
checkpoint_notify_threads_
.
reset
(
new
paddle
::
framework
::
ThreadPool
(
rpc_server_
->
GetThreadNum
(
distributed
::
kRequestPrefetch
)));
}
it
=
rpc_call_map
.
find
(
distributed
::
kRequestGetMonomerVariable
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_get_monomer_handler_h_
=
it
->
second
;
}
it
=
rpc_call_map
.
find
(
distributed
::
kRequestGetMonomerBarrier
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_get_monomer_barrier_handler_h_
=
it
->
second
;
}
}
// Nothing to release explicitly; members clean up after themselves.
virtual ~BRPCServiceImpl() = default;
// RPC entry point for SendVariable: queues the request on the send worker
// pool, where _SendVariable does the actual work.
//
// Note: send_threads_ is only created by the constructor when a kRequestSend
// handler is registered.
void SendVariable(google::protobuf::RpcController* cntl_butil,
                  const VariableMessage* request, VoidMessage* response,
                  google::protobuf::Closure* done) override {
  send_threads_->Run([this, cntl_butil, request, response, done] {
    _SendVariable(cntl_butil, request, response, done);
  });
}
void
_SendVariable
(
google
::
protobuf
::
RpcController
*
cntl_butil
,
const
VariableMessage
*
request
,
VoidMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
PADDLE_ENFORCE
(
request_send_h_
!=
nullptr
,
"RequestSend handler should be registed first!"
);
brpc
::
ClosureGuard
done_guard
(
done
);
paddle
::
framework
::
Scope
*
local_scope
=
request_send_h_
->
scope
();
paddle
::
framework
::
Variable
*
outvar
=
nullptr
;
paddle
::
framework
::
Variable
*
invar
=
nullptr
;
brpc
::
Controller
*
cntl
=
static_cast
<
brpc
::
Controller
*>
(
cntl_butil
);
std
::
string
varname
=
request
->
varname
();
VLOG
(
3
)
<<
"RequestSend var_name:"
<<
varname
<<
", trainer_id:"
<<
request
->
trainer_id
()
<<
", from:"
<<
cntl
->
remote_side
();
if
(
!
request_send_h_
->
sync_mode
())
{
local_scope
=
&
request_send_h_
->
scope
()
->
NewScope
();
invar
=
local_scope
->
Var
(
varname
);
}
else
{
invar
=
local_scope
->
FindVar
(
varname
);
}
distributed
::
BRPCVariableResponse
resp
(
request_send_h_
->
scope
(),
request_send_h_
->
dev_ctx
(),
!
request_send_h_
->
sync_mode
());
PADDLE_ENFORCE
(
resp
.
Parse
(
cntl
->
request_attachment
(),
*
request
)
==
0
,
"parse iobuf to tensor error!"
);
request_send_h_
->
Handle
(
varname
,
local_scope
,
invar
,
&
outvar
);
auto
scope
=
resp
.
GetMutableLocalScope
();
auto
invar
=
resp
.
GetVar
();
int
trainer_id
=
request
->
trainer_id
();
paddle
::
framework
::
Variable
*
outvar
=
nullptr
;
if
(
!
request_send_h_
->
sync_mode
())
{
request_send_h_
->
scope
()
->
DeleteScope
(
local_scope
);
}
request_send_h_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
}
// RPC entry point for GetVariable: queues the request on the get worker pool,
// where _GetVariable does the actual work.
//
// Note: get_threads_ is only created by the constructor when a kRequestGet
// handler is registered.
void GetVariable(google::protobuf::RpcController* cntl_butil,
                 const VariableMessage* request, VariableMessage* response,
                 google::protobuf::Closure* done) override {
  get_threads_->Run([this, cntl_butil, request, response, done] {
    _GetVariable(cntl_butil, request, response, done);
  });
}
void
_GetVariable
(
google
::
protobuf
::
RpcController
*
cntl_butil
,
const
VariableMessage
*
request
,
VariableMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
PADDLE_ENFORCE
(
request_get_h_
!=
nullptr
,
"RequestGet handler should be registed first!"
);
}
brpc
::
ClosureGuard
done_guard
(
done
);
brpc
::
Controller
*
cntl
=
static_cast
<
brpc
::
Controller
*>
(
cntl_butil
);
std
::
string
varname
=
request
->
varname
();
VLOG
(
3
)
<<
"RequestGet varname:"
<<
varname
<<
", trainer_id:"
<<
request
->
trainer_id
()
<<
", from:"
<<
cntl
->
remote_side
();
auto
scope
=
request_get_h_
->
scope
();
auto
invar
=
scope
->
FindVar
(
varname
);
int
trainer_id
=
request
->
trainer_id
();
paddle
::
framework
::
Variable
*
outvar
=
nullptr
;
request_get_h_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
if
(
outvar
)
{
distributed
::
SerializeToIOBuf
(
varname
,
outvar
,
*
request_get_h_
->
dev_ctx
(),
response
,
&
cntl
->
response_attachment
(),
""
,
false
);
}
}
// RPC entry point for PrefetchVariable: queues the request on the prefetch
// worker pool, where _PrefetchVariable does the actual work.
//
// Note: prefetch_threads_ is only created by the constructor when a
// kRequestPrefetch handler is registered.
void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
                      const VariableMessage* request,
                      VariableMessage* response,
                      google::protobuf::Closure* done) override {
  prefetch_threads_->Run([this, cntl_butil, request, response, done] {
    _PrefetchVariable(cntl_butil, request, response, done);
  });
}
// Handles one prefetch request on a worker thread.
//
// Steps: parse the request attachment into the response parser's local scope,
// look up the input variable `varname` there, create the output variable
// `out_varname` in the same scope, run the prefetch handler, and serialize
// the output variable into the response attachment.
void _PrefetchVariable(google::protobuf::RpcController* cntl_butil,
                       const VariableMessage* request,
                       VariableMessage* response,
                       google::protobuf::Closure* done) {
  PADDLE_ENFORCE(request_prefetch_h_ != nullptr,
                 "kRequestPrefetch handler should be registed first!");

  brpc::ClosureGuard done_guard(done);
  auto* cntl = static_cast<brpc::Controller*>(cntl_butil);

  // prefetch process...
  const std::string in_var_name = request->varname();
  const std::string out_var_name = request->out_varname();
  VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
          << ", out_var_name: " << out_var_name
          << ", trainer_id:" << request->trainer_id()
          << ", from:" << cntl->remote_side();

  distributed::BRPCVariableResponse resp(request_prefetch_h_->scope(),
                                         request_prefetch_h_->dev_ctx(), true);
  PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
                 "parse iobuf to tensor error!");

  auto scope = resp.GetMutableLocalScope();
  auto invar = scope->FindVar(in_var_name);
  const std::string table_name = request->table_name();
  const int trainer_id = request->trainer_id();
  paddle::framework::Variable* outvar = scope->Var(out_var_name);

  request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
                              out_var_name, table_name);

  distributed::SerializeToIOBuf(out_var_name, outvar,
                                *request_prefetch_h_->dev_ctx(), response,
                                &cntl->response_attachment(), "", true);
}
// RPC entry point for CheckpointNotify: queues the request on the checkpoint
// worker pool, where _CheckpointNotify does the actual work.
//
// Note: checkpoint_notify_threads_ is only created by the constructor when a
// kRequestCheckpoint handler is registered.
void CheckpointNotify(google::protobuf::RpcController* cntl_butil,
                      const VariableMessage* request, VoidMessage* response,
                      google::protobuf::Closure* done) override {
  checkpoint_notify_threads_->Run([this, cntl_butil, request, response, done] {
    _CheckpointNotify(cntl_butil, request, response, done);
  });
}
// Handles one checkpoint notification on a worker thread.
//
// The request's `varname` holds the notify message and `out_varname` the
// checkpoint directory; both are forwarded to the checkpoint handler together
// with the trainer id. No input or output variable is passed.
void _CheckpointNotify(google::protobuf::RpcController* cntl_butil,
                       const VariableMessage* request, VoidMessage* response,
                       google::protobuf::Closure* done) {
  PADDLE_ENFORCE(
      request_checkpoint_h_ != nullptr,
      "kRequestCheckpointNotify handler should be registed first!");

  brpc::ClosureGuard done_guard(done);
  auto* cntl = static_cast<brpc::Controller*>(cntl_butil);

  distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(),
                                         request_checkpoint_h_->dev_ctx());
  auto scope = resp.GetMutableLocalScope();

  const std::string checkpoint_notify = request->varname();
  const std::string checkpoint_dir = request->out_varname();
  const int trainer_id = request->trainer_id();

  VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
          << ", dir: " << checkpoint_dir
          << ", trainer_id:" << request->trainer_id()
          << ", from:" << cntl->remote_side();

  request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr,
                                trainer_id, checkpoint_dir);
}
// RPC entry point for GetMonomerVariable; runs on the calling thread.
//
// Waits on the server's condition for `varname`, fetches the monomer handle
// registered for it, runs the monomer handler against the handle's scope and,
// if the handler produced an output variable, serializes it into the response
// attachment using the handle's device context.
void GetMonomerVariable(google::protobuf::RpcController* cntl_butil,
                        const VariableMessage* request,
                        VariableMessage* response,
                        google::protobuf::Closure* done) override {
  PADDLE_ENFORCE(
      request_get_monomer_handler_h_ != nullptr,
      "kRequestGetMonomerVariable handler should be registed first!");

  brpc::ClosureGuard done_guard(done);
  auto* cntl = static_cast<brpc::Controller*>(cntl_butil);

  // proc request.
  const std::string varname = request->varname();
  VLOG(3) << "GetMonomerVariable " << varname
          << ", trainer_id:" << request->trainer_id()
          << ", from:" << cntl->remote_side();

  rpc_server_->WaitVarCond(varname);
  distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);

  auto scope = h.scope_;
  auto invar = scope->FindVar(varname);
  paddle::framework::Variable* outvar = nullptr;

  request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar,
                                         request->trainer_id());

  if (!outvar) {
    return;
  }
  distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response,
                                &cntl->response_attachment(), "", false);
}
// RPC entry point for the monomer barrier.
//
// Waits on rpc_server_ for the variable named by request->varname(), fetches
// its MonomerHandle, and then invokes the registered barrier handler with a
// null scope and null input/output variables. `done` is run on return through
// the brpc::ClosureGuard.
void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil,
                       const VariableMessage* request, VoidMessage* response,
                       google::protobuf::Closure* done) override {
  PADDLE_ENFORCE(
      request_get_monomer_barrier_handler_h_ != nullptr,
      "RequestGetMonomerBarrier handler should be registed first!");

  brpc::ClosureGuard done_guard(done);
  auto* controller = static_cast<brpc::Controller*>(cntl_butil);

  std::string var_name = request->varname();
  VLOG(3) << "RequestGetMonomerBarrier var_name:" << var_name
          << ", trainer_id:" << request->trainer_id()
          << ", from:" << controller->remote_side();

  rpc_server_->WaitVarCond(var_name);
  // The handle is fetched but none of its fields are read below.
  distributed::MonomerHandle monomer = rpc_server_->GetMonomer(var_name);

  // The barrier handler is called without a scope or variables.
  paddle::framework::Scope* no_scope = nullptr;
  paddle::framework::Variable* no_input = nullptr;
  paddle::framework::Variable* output_var = nullptr;

  request_get_monomer_barrier_handler_h_->Handle(
      var_name, no_scope, no_input, &output_var, request->trainer_id());
}
private:
paddle
::
operators
::
distributed
::
RequestHandler
*
request_send_h_
;
paddle
::
operators
::
distributed
::
RequestHandler
*
request_get_h_
;
paddle
::
operators
::
distributed
::
RequestHandler
*
request_prefetch_h_
;
distributed
::
RequestHandler
*
request_send_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_get_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_prefetch_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_checkpoint_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_get_monomer_handler_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_get_monomer_barrier_handler_h_
{
nullptr
};
distributed
::
RPCServer
*
rpc_server_
{
nullptr
};
// FIXME(gongwb): brpc should support processing each rpc with its own thread pool.
std
::
unique_ptr
<
paddle
::
framework
::
ThreadPool
>
send_threads_
;
std
::
unique_ptr
<
paddle
::
framework
::
ThreadPool
>
get_threads_
;
std
::
unique_ptr
<
paddle
::
framework
::
ThreadPool
>
prefetch_threads_
;
std
::
unique_ptr
<
paddle
::
framework
::
ThreadPool
>
checkpoint_notify_threads_
;
};
}
// namespace sendrecv
...
...
@@ -100,7 +303,7 @@ namespace distributed {
void
AsyncBRPCServer
::
StartServer
()
{
// Instance of your service.
sendrecv
::
BRPCServiceImpl
service_impl
(
rpc_call_map_
);
sendrecv
::
BRPCServiceImpl
service_impl
(
rpc_call_map_
,
this
);
// Add the service into server. Notice the second parameter, because the
// service is put on stack, we don't want server to delete it, otherwise
...
...
@@ -111,6 +314,9 @@ void AsyncBRPCServer::StartServer() {
}
brpc
::
ServerOptions
options
;
#ifdef PADDLE_WITH_BRPC_RDMA
options
.
use_rdma
=
true
;
#endif
options
.
idle_timeout_sec
=
idle_timeout_s_
;
options
.
max_concurrency
=
max_concurrency_
;
if
(
server_
.
Start
(
bind_address_
.
c_str
(),
&
options
)
!=
0
)
{
...
...
paddle/fluid/operators/distributed/brpc_variable_response.cc
0 → 100644
浏览文件 @
0b1c7d83
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
namespace
pb
=
::
google
::
protobuf
;
using
vr
=
::
sendrecv
::
VariableMessage
;
// Parses the attachment exposed by `source` as a sequence of records.
//
// Each record starts with a 4-byte little-endian field number followed by an
// 8-byte little-endian payload length; the payload is then consumed by the
// helper matching the field number.
//
// Returns 0 when every record was consumed. On failure returns the field
// number of the offending record, or -1 when that field number is 0.
int BRPCVariableResponse::Parse(Source* source) {
  pb::io::ZeroCopyInputStream* input_stream = source->contents();
  pb::io::CodedInputStream input(input_stream);
  input.SetTotalBytesLimit(INT_MAX, INT_MAX);

  while (true) {
    unsigned int tag = 0;
    if (!input.ReadLittleEndian32(&tag)) {
      // No further record header: the attachment has been fully consumed.
      break;
    }

    const int field = static_cast<int>(tag);
    const int ret = field == 0 ? -1 : field;

    uint64_t num_bytes = 0;
    if (!input.ReadLittleEndian64(&num_bytes)) {
      // A field number without its length means the attachment is truncated;
      // report it instead of treating the partial record as success.
      return ret;
    }

    switch (field) {
      case vr::kSerializedFieldNumber: {
        if (!ProcSerializedField(field, &input, num_bytes)) {
          return ret;
        }
        break;
      }
      case vr::kRowsFieldNumber: {
        // Row data can only be placed once the variable's meta is known.
        PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
                        meta_.type() == sendrecv::LOD_TENSOR) &&
                           meta_.varname() != "",
                       "meta info should be got first!");
        if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
          return ret;
        }
        break;
      }
      default: {
        PADDLE_ENFORCE(false, "not surpported %u fieldnumber", field);
        return ret;
      }
    }
  }

  return 0;
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/brpc_variable_response.h
0 → 100644
浏览文件 @
0b1c7d83
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
class
BRPCSourceWrapper
:
public
Source
{
public:
explicit
BRPCSourceWrapper
(
const
butil
::
IOBuf
&
iobuf
)
:
source_
(
iobuf
)
{}
::
google
::
protobuf
::
io
::
ZeroCopyInputStream
*
contents
()
override
{
return
&
source_
;
}
private:
butil
::
IOBufAsZeroCopyInputStream
source_
;
};
// VariableResponse specialization that reads its payload from a brpc
// attachment (butil::IOBuf).
class BRPCVariableResponse : public VariableResponse {
 public:
  BRPCVariableResponse(const framework::Scope* scope,
                       const platform::DeviceContext* dev_ctx,
                       bool create_scope = false)
      : VariableResponse(scope, dev_ctx, create_scope) {}

  virtual ~BRPCVariableResponse() = default;

  // Parses the attachment exposed by `source`.
  int Parse(Source* source) override;

  // Convenience overload: wraps `iobuf` in a BRPCSourceWrapper and delegates
  // to VariableResponse::Parse together with the already-decoded `meta`.
  int Parse(const butil::IOBuf& iobuf, const sendrecv::VariableMessage& meta) {
    BRPCSourceWrapper attachment_source(iobuf);
    Source* as_source = &attachment_source;
    return VariableResponse::Parse(as_source, meta);
  }
};
};
// namespace distributed
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
0b1c7d83
...
...
@@ -293,8 +293,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
const
std
::
string
method
=
"SendMonomerFetchBarrierRPC"
;
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
FETCH_BARRIER_MESSAGE
,
nullptr
,
nullptr
));
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
var_name
,
nullptr
,
nullptr
));
s
->
Prepare
(
h
,
time_out
);
VLOG
(
30
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
...
...
paddle/fluid/operators/distributed/grpc_serde.cc
浏览文件 @
0b1c7d83
...
...
@@ -32,13 +32,6 @@ namespace paddle {
namespace
operators
{
namespace
distributed
{
// Destroy callback: `payload` is treated as a TensorPayload* and deleted.
// A null `payload` is ignored.
static void SerializeDestroyCallback(void* payload) {
  if (payload == nullptr) {
    return;
  }
  delete reinterpret_cast<TensorPayload*>(payload);
}
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
,
...
...
paddle/fluid/operators/distributed/rpc_server.h
浏览文件 @
0b1c7d83
...
...
@@ -75,6 +75,10 @@ class RPCServer {
void
RegisterRPC
(
const
std
::
string
&
rpc_name
,
RequestHandler
*
handler
,
int
thread_num
=
5
);
int
GetThreadNum
(
const
std
::
string
&
rpc_name
)
{
return
rpc_thread_num_
[
rpc_name
];
}
// Wait until all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
...
...
paddle/fluid/operators/distributed/sendrecvop_utils.cc
浏览文件 @
0b1c7d83
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/port.h"
...
...
@@ -45,7 +46,6 @@ static TensorPayload GetCommunicationAllocationFromTensor(
memory
::
Copy
(
cuda_pinned
,
result
->
ptr
(),
boost
::
get
<
platform
::
CUDAPlace
>
(
tensor
.
place
()),
tensor
.
data
<
void
>
(),
copy_size
,
gpu_dev_ctx
.
stream
());
ctx
.
Wait
();
return
TensorPayload
(
result
);
#else
...
...
paddle/fluid/operators/distributed/sendrecvop_utils.h
浏览文件 @
0b1c7d83
...
...
@@ -50,6 +50,13 @@ class TensorPayload final {
size_t
memory_size_
;
};
// Destroy callback: `payload` is treated as a TensorPayload* and deleted.
// A null `payload` is ignored.
inline void SerializeDestroyCallback(void* payload) {
  if (payload == nullptr) {
    return;
  }
  auto* owned_payload = reinterpret_cast<TensorPayload*>(payload);
  delete owned_payload;
}
TensorPayload
GetTensorPayload
(
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
);
...
...
paddle/fluid/operators/distributed_ops/CMakeLists.txt
浏览文件 @
0b1c7d83
...
...
@@ -2,9 +2,9 @@ include(operators)
set
(
DISTRIBUTE_DEPS
""
)
if
(
WITH_GRPC
)
set
(
DISTRIBUTE_DEPS sendrecvop_
g
rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node
)
else
()
set
(
DISTRIBUTE_DEPS sendrecvop_
b
rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node
)
if
(
WITH_BRPC_RDMA
)
find_library
(
IBVERBS_LIBRARY NAMES ibverbs
)
ADD_LIBRARY
(
ibverbs SHARED IMPORTED GLOBAL
)
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
0b1c7d83
...
...
@@ -26,10 +26,11 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32
(
rpc_send_thread_num
,
5
,
"number of threads for rpc send"
);
DEFINE_int32
(
rpc_get_thread_num
,
5
,
"number of threads for rpc get"
);
DEFINE_int32
(
rpc_prefetch_thread_num
,
5
,
"number of threads for rpc prefetch"
);
DEFINE_int32
(
rpc_send_thread_num
,
12
,
"number of threads for rpc send"
);
DEFINE_int32
(
rpc_get_thread_num
,
12
,
"number of threads for rpc get"
);
DEFINE_int32
(
rpc_prefetch_thread_num
,
12
,
"number of threads for rpc prefetch"
);
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/distributed_ops/send_op.cc
浏览文件 @
0b1c7d83
...
...
@@ -58,7 +58,9 @@ class SendOp : public framework::OperatorBase {
}
if
(
sync_send
)
{
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
VLOG
(
7
)
<<
"before sync_send "
<<
ins
[
i
]
<<
"from "
<<
epmap
[
i
];
PADDLE_ENFORCE
(
rets
[
i
]
->
Wait
(),
"internal error in RPCClient"
);
VLOG
(
7
)
<<
"after sync_send "
<<
ins
[
i
]
<<
"from "
<<
epmap
[
i
];
}
}
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
0b1c7d83
...
...
@@ -81,6 +81,14 @@ bool IsCompiledWithCUDA() {
#endif
}
// Reports whether this build was compiled with brpc support, i.e. whether
// PADDLE_WITH_BRPC or PADDLE_WITH_BRPC_RDMA was defined at compile time.
bool IsCompiledWithBrpc() {
#if !defined(PADDLE_WITH_BRPC) && !defined(PADDLE_WITH_BRPC_RDMA)
  return false;
#else
  return true;
#endif
}
bool
IsCompiledWithDIST
()
{
#ifdef PADDLE_WITH_DISTRIBUTE
return
true
;
...
...
@@ -631,6 +639,7 @@ All parameter, weight, gradient are variables in Paddle.
[](
bool
init_p2p
)
{
framework
::
InitDevices
(
init_p2p
);
});
m
.
def
(
"is_compiled_with_cuda"
,
IsCompiledWithCUDA
);
m
.
def
(
"is_compiled_with_brpc"
,
IsCompiledWithBrpc
);
m
.
def
(
"is_compiled_with_dist"
,
IsCompiledWithDIST
);
#ifdef PADDLE_WITH_CUDA
m
.
def
(
"is_float16_supported"
,
[](
const
platform
::
CUDAPlace
&
place
)
->
bool
{
...
...
python/paddle/fluid/__init__.py
浏览文件 @
0b1c7d83
...
...
@@ -152,6 +152,7 @@ def __bootstrap__():
'enable_cublas_tensor_op_math'
,
'conv_workspace_size_limit'
,
'cudnn_exhaustive_search'
,
'selected_gpus'
]
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
[
"--tryfromenv="
+
","
.
join
(
read_env_flags
)])
core
.
init_glog
(
sys
.
argv
[
0
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录