Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d9de6b86
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d9de6b86
编写于
6月 11, 2018
作者:
G
gongweibao
提交者:
GitHub
6月 11, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add brpc surpport. (#11263)
上级
bfa3fd6f
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
748 addition
and
71 deletion
+748
-71
CMakeLists.txt
CMakeLists.txt
+12
-2
cmake/configure.cmake
cmake/configure.cmake
+4
-0
cmake/external/brpc.cmake
cmake/external/brpc.cmake
+58
-0
cmake/external/leveldb.cmake
cmake/external/leveldb.cmake
+44
-0
cmake/generic.cmake
cmake/generic.cmake
+18
-0
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+2
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+13
-3
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+30
-4
paddle/fluid/operators/detail/brpc_client.cc
paddle/fluid/operators/detail/brpc_client.cc
+180
-0
paddle/fluid/operators/detail/brpc_client.h
paddle/fluid/operators/detail/brpc_client.h
+100
-0
paddle/fluid/operators/detail/brpc_server.cc
paddle/fluid/operators/detail/brpc_server.cc
+144
-0
paddle/fluid/operators/detail/brpc_server.h
paddle/fluid/operators/detail/brpc_server.h
+53
-0
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+1
-0
paddle/fluid/operators/detail/grpc_serde_test.cc
paddle/fluid/operators/detail/grpc_serde_test.cc
+0
-0
paddle/fluid/operators/detail/macros.h
paddle/fluid/operators/detail/macros.h
+27
-0
paddle/fluid/operators/detail/request_handler.h
paddle/fluid/operators/detail/request_handler.h
+4
-1
paddle/fluid/operators/detail/request_handler_impl.cc
paddle/fluid/operators/detail/request_handler_impl.cc
+0
-3
paddle/fluid/operators/detail/request_handler_impl.h
paddle/fluid/operators/detail/request_handler_impl.h
+0
-1
paddle/fluid/operators/detail/rpc_client.h
paddle/fluid/operators/detail/rpc_client.h
+2
-0
paddle/fluid/operators/detail/rpc_server_test.cc
paddle/fluid/operators/detail/rpc_server_test.cc
+7
-9
paddle/fluid/operators/detail/send_recv.proto
paddle/fluid/operators/detail/send_recv.proto
+2
-0
paddle/fluid/operators/detail/sendrecvop_utils.h
paddle/fluid/operators/detail/sendrecvop_utils.h
+0
-10
paddle/fluid/operators/fetch_barrier_op.cc
paddle/fluid/operators/fetch_barrier_op.cc
+2
-4
paddle/fluid/operators/gen_nccl_id_op.cc
paddle/fluid/operators/gen_nccl_id_op.cc
+13
-11
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+12
-5
paddle/fluid/operators/prefetch_op.cc
paddle/fluid/operators/prefetch_op.cc
+2
-2
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+2
-3
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+2
-2
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+2
-2
paddle/fluid/operators/test_send_nccl_id.cc
paddle/fluid/operators/test_send_nccl_id.cc
+12
-8
未找到文件。
CMakeLists.txt
浏览文件 @
d9de6b86
...
...
@@ -55,12 +55,13 @@ option(WITH_FLUID_ONLY "Compile PaddlePaddle fluid only" OFF)
option
(
WITH_GOLANG
"Compile PaddlePaddle with GOLANG"
OFF
)
option
(
GLIDE_INSTALL
"Download and install go dependencies "
ON
)
option
(
USE_NNPACK
"Compile PaddlePaddle with NNPACK library"
OFF
)
option
(
WITH_DISTRIBUTE
"Compile with
grpc distributed support"
OFF
)
option
(
WITH_DISTRIBUTE
"Compile with
distributed support"
OFF
)
option
(
USE_EIGEN_FOR_BLAS
"Use matrix multiplication in Eigen"
OFF
)
option
(
EIGEN_USE_THREADS
"Compile with multi-threaded Eigen"
OFF
)
option
(
WITH_ARM_FP16
"Use half precision support on armv8.2-a cpu"
OFF
)
option
(
WITH_FAST_BUNDLE_TEST
"Bundle tests that can be run in a single process together to reduce launch overhead"
OFF
)
option
(
WITH_CONTRIB
"Compile the third-party contributation"
OFF
)
option
(
WITH_GRPC
"Use grpc as the default rpc framework"
${
WITH_DISTRIBUTE
}
)
# CMAKE_BUILD_TYPE
if
(
NOT CMAKE_BUILD_TYPE
)
...
...
@@ -147,7 +148,16 @@ include(external/any) # download libn::any
include
(
external/eigen
)
# download eigen3
include
(
external/pybind11
)
# download pybind11
include
(
external/cares
)
include
(
external/grpc
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_GRPC
)
include
(
external/grpc
)
else
()
include
(
external/leveldb
)
include
(
external/brpc
)
endif
()
endif
()
include
(
external/snappy
)
# download snappy
include
(
external/snappystream
)
include
(
external/threadpool
)
...
...
cmake/configure.cmake
浏览文件 @
d9de6b86
...
...
@@ -166,3 +166,7 @@ if(WITH_GOLANG)
endif
()
endif
(
WITH_GOLANG
)
if
(
WITH_GRPC
)
add_definitions
(
-DPADDLE_WITH_GRPC
)
endif
(
WITH_GRPC
)
cmake/external/brpc.cmake
0 → 100644
浏览文件 @
d9de6b86
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE
(
ExternalProject
)
SET
(
BRPC_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/brpc
)
SET
(
BRPC_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/brpc
)
SET
(
BRPC_INCLUDE_DIR
"
${
BRPC_INSTALL_DIR
}
/include"
CACHE PATH
"brpc include directory."
FORCE
)
SET
(
BRPC_LIBRARIES
"
${
BRPC_INSTALL_DIR
}
/lib/libbrpc.a"
CACHE FILEPATH
"brpc library."
FORCE
)
INCLUDE_DIRECTORIES
(
${
BRPC_INCLUDE_DIR
}
)
# Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
set
(
prefix_path
"
${
THIRD_PARTY_PATH
}
/install/gflags|
${
THIRD_PARTY_PATH
}
/install/leveldb|
${
THIRD_PARTY_PATH
}
/install/snappy|
${
THIRD_PARTY_PATH
}
/install/gtest|
${
THIRD_PARTY_PATH
}
/install/protobuf"
)
# If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
ExternalProject_Add
(
extern_brpc
${
EXTERNAL_PROJECT_LOG_ARGS
}
GIT_REPOSITORY
"https://github.com/brpc/brpc"
GIT_TAG
"6d153dd7ff00f960ae6895c9c5fff0ce9f07aff2"
PREFIX
${
BRPC_SOURCES_DIR
}
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
-DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
-DCMAKE_INSTALL_PREFIX=
${
BRPC_INSTALL_DIR
}
-DCMAKE_INSTALL_LIBDIR=
${
BRPC_INSTALL_DIR
}
/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
-DCMAKE_PREFIX_PATH=
${
prefix_path
}
-DBRPC_WITH_GLOG=ON
${
EXTERNAL_OPTIONAL_ARGS
}
LIST_SEPARATOR |
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=
${
BRPC_INSTALL_DIR
}
-DCMAKE_INSTALL_LIBDIR:PATH=
${
BRPC_INSTALL_DIR
}
/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=
${
THIRD_PARTY_BUILD_TYPE
}
)
ADD_DEPENDENCIES
(
extern_brpc protobuf leveldb gflags glog gtest snappy
)
ADD_LIBRARY
(
brpc STATIC IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET brpc PROPERTY IMPORTED_LOCATION
${
BRPC_LIBRARIES
}
)
ADD_DEPENDENCIES
(
brpc extern_brpc
)
LIST
(
APPEND external_project_dependencies brpc
)
cmake/external/leveldb.cmake
0 → 100644
浏览文件 @
d9de6b86
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE
(
ExternalProject
)
SET
(
LEVELDB_SOURCES_DIR
${
THIRD_PARTY_PATH
}
/leveldb
)
SET
(
LEVELDB_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/leveldb
)
SET
(
LEVELDB_INCLUDE_DIR
"
${
LEVELDB_INSTALL_DIR
}
/include"
CACHE PATH
"leveldb include directory."
FORCE
)
SET
(
LEVELDB_LIBRARIES
"
${
LEVELDB_INSTALL_DIR
}
/lib/libleveldb.a"
CACHE FILEPATH
"leveldb library."
FORCE
)
INCLUDE_DIRECTORIES
(
${
LEVELDB_INCLUDE_DIR
}
)
ExternalProject_Add
(
extern_leveldb
${
EXTERNAL_PROJECT_LOG_ARGS
}
PREFIX
${
LEVELDB_SOURCES_DIR
}
URL
"https://github.com/google/leveldb/archive/v1.18.tar.gz"
URL_MD5
"73770de34a2a5ab34498d2e05b2b7fa0"
CONFIGURE_COMMAND
""
BUILD_COMMAND CXXFLAGS=-fPIC make -j
${
NUM_OF_PROCESSOR
}
libleveldb.a
INSTALL_COMMAND mkdir -p
${
LEVELDB_INSTALL_DIR
}
/lib/
&& cp
${
LEVELDB_SOURCES_DIR
}
/src/extern_leveldb/libleveldb.a
${
LEVELDB_LIBRARIES
}
&& cp -r
${
LEVELDB_SOURCES_DIR
}
/src/extern_leveldb/include
${
LEVELDB_INSTALL_DIR
}
/
BUILD_IN_SOURCE 1
)
ADD_DEPENDENCIES
(
extern_leveldb snappy
)
ADD_LIBRARY
(
leveldb STATIC IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET leveldb PROPERTY IMPORTED_LOCATION
${
LEVELDB_LIBRARIES
}
)
ADD_DEPENDENCIES
(
leveldb extern_leveldb
)
LIST
(
APPEND external_project_dependencies leveldb
)
cmake/generic.cmake
浏览文件 @
d9de6b86
...
...
@@ -610,3 +610,21 @@ function(grpc_library TARGET_NAME)
COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
cc_library
(
"
${
TARGET_NAME
}
"
SRCS
"
${
grpc_library_SRCS
}
"
DEPS
"
${
TARGET_NAME
}
_grpc"
"
${
TARGET_NAME
}
_proto"
"
${
grpc_library_DEPS
}
"
)
endfunction
()
function
(
brpc_library TARGET_NAME
)
set
(
oneValueArgs PROTO
)
set
(
multiValueArgs SRCS DEPS
)
set
(
options
""
)
cmake_parse_arguments
(
brpc_library
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
message
(
STATUS
"generating brpc
${
brpc_library_PROTO
}
"
)
get_filename_component
(
ABS_PROTO
${
brpc_library_PROTO
}
ABSOLUTE
)
get_filename_component
(
PROTO_WE
${
brpc_library_PROTO
}
NAME_WE
)
get_filename_component
(
PROTO_PATH
${
ABS_PROTO
}
PATH
)
protobuf_generate_cpp
(
brpc_proto_srcs brpc_proto_hdrs
"
${
ABS_PROTO
}
"
)
cc_library
(
"
${
TARGET_NAME
}
_proto"
SRCS
"
${
brpc_proto_srcs
}
"
)
cc_library
(
"
${
TARGET_NAME
}
"
SRCS
"
${
brpc_library_SRCS
}
"
DEPS
"
${
TARGET_NAME
}
_proto"
"
${
brpc_library_DEPS
}
"
)
endfunction
()
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
d9de6b86
...
...
@@ -64,7 +64,8 @@ class TRTConvertValidation {
TRTConvertValidation
(
int
batch_size
,
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
framework
::
Scope
&
scope
,
int
workspace_size
=
1
<<
10
)
framework
::
Scope
&
scope
,
// NOLINT
int
workspace_size
=
1
<<
10
)
:
parameters_
(
parameters
),
scope_
(
scope
)
{
// create engine.
engine_
.
reset
(
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
));
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
d9de6b86
...
...
@@ -186,8 +186,14 @@ endif()
add_subdirectory
(
detail
)
if
(
WITH_DISTRIBUTE
)
set
(
DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
)
set
(
DISTRIBUTE_DEPS
""
)
if
(
WITH_GRPC
)
set
(
DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
)
else
()
set
(
DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib
)
endif
()
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
op_library
(
prefetch_op DEPS
${
DISTRIBUTE_DEPS
}
)
set_source_files_properties
(
prefetch_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
@@ -207,7 +213,11 @@ if(WITH_DISTRIBUTE)
if
(
WITH_GPU
)
set_source_files_properties
(
test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op executor SERIAL
)
op_library
(
gen_nccl_id_op DEPS nccl_common sendrecvop_grpc
)
if
(
WITH_GRPC
)
op_library
(
gen_nccl_id_op DEPS nccl_common sendrecvop_grpc
)
else
()
op_library
(
gen_nccl_id_op DEPS nccl_common sendrecvop_brpc
)
endif
()
set_source_files_properties
(
gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
set
(
DEPS_OPS
${
DEPS_OPS
}
gen_nccl_id_op
)
...
...
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
d9de6b86
if
(
WITH_DISTRIBUTE
)
if
(
NOT WITH_DISTRIBUTE
)
return
()
endif
()
if
(
WITH_GRPC
)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
serde_test.cc g
rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
set_source_files_properties
(
grpc_serde_test.cc
rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
serde_test SRCS
grpc_
serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc SERIAL
)
cc_test
(
grpc_server_test SRCS
g
rpc_server_test.cc DEPS sendrecvop_grpc
cc_test
(
grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
proto_desc lookup_table_op SERIAL
)
return
()
endif
()
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
brpc_library
(
sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory
)
find_library
(
OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so
)
ADD_LIBRARY
(
crypto SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET crypto PROPERTY IMPORTED_LOCATION
${
OPENSSL_CRYPTO_LIBRARY_STATIC
}
)
find_library
(
OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so
)
ADD_LIBRARY
(
ssl SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET ssl PROPERTY IMPORTED_LOCATION
${
OPENSSL_SSL_LIBRARY_STATIC
}
)
cc_test
(
brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc
brpc protobuf leveldb gflags glog
protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL
)
paddle/fluid/operators/detail/brpc_client.cc
0 → 100644
浏览文件 @
d9de6b86
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/brpc_client.h"
#include "paddle/fluid/framework/threadpool.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
DEFINE_int32
(
brpc_channel_num
,
24
,
"Number of channels to send requests connected to one server"
);
DEFINE_int32
(
timeout_ms
,
30000
,
"RPC timeout in milliseconds"
);
DEFINE_int32
(
max_retry
,
3
,
"Max retries(not including the first RPC)"
);
BRPCClient
::~
BRPCClient
()
{
Wait
();
}
void
HandleSendResponse
(
brpc
::
Controller
*
cntl
,
sendrecv
::
VoidMessage
*
response
)
{
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std
::
unique_ptr
<
brpc
::
Controller
>
cntl_guard
(
cntl
);
std
::
unique_ptr
<
sendrecv
::
VoidMessage
>
response_guard
(
response
);
if
(
cntl
->
Failed
())
{
LOG
(
WARNING
)
<<
"Fail to send EchoRequest, "
<<
cntl
->
ErrorText
();
return
;
}
LOG
(
INFO
)
<<
"Received response from "
<<
cntl
->
remote_side
()
<<
" latency="
<<
cntl
->
latency_us
()
<<
"us"
;
}
bool
BRPCClient
::
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch_ptr
=
GetChannel
(
ep_val
);
framework
::
AsyncIO
(
[
var_name_val
,
p_ctx
,
ep_val
,
p_scope
,
time_out
,
ch_ptr
,
this
]
{
auto
ch_ctx
=
ch_ptr
->
Pop
();
brpc
::
Controller
*
cntl
=
new
brpc
::
Controller
();
sendrecv
::
VoidMessage
*
response
=
new
sendrecv
::
VoidMessage
();
cntl
->
set_timeout_ms
(
time_out
);
google
::
protobuf
::
Closure
*
done
=
brpc
::
NewCallback
(
&
HandleSendResponse
,
cntl
,
response
);
sendrecv
::
VariableMessage
request
;
ch_ctx
->
stub
->
SendVariable
(
cntl
,
&
request
,
response
,
done
);
});
req_count_
++
;
return
true
;
}
void
HandleGetResponse
(
brpc
::
Controller
*
cntl
,
sendrecv
::
VariableMessage
*
response
)
{
// std::unique_ptr makes sure cntl/response will be deleted before returning.
std
::
unique_ptr
<
brpc
::
Controller
>
cntl_guard
(
cntl
);
std
::
unique_ptr
<
sendrecv
::
VariableMessage
>
response_guard
(
response
);
if
(
cntl
->
Failed
())
{
LOG
(
WARNING
)
<<
"Fail to send EchoRequest, "
<<
cntl
->
ErrorText
();
return
;
}
LOG
(
INFO
)
<<
"Received response from "
<<
cntl
->
remote_side
()
<<
" latency="
<<
cntl
->
latency_us
()
<<
"us"
;
// framework::Variable* outvar = nullptr;
// DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
}
bool
BRPCClient
::
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
framework
::
AsyncIO
(
[
var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
this
]
{});
req_count_
++
;
return
true
;
}
bool
BRPCClient
::
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
in_var_name_val
=
in_var_name
;
const
std
::
string
out_var_name_val
=
out_var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
framework
::
AsyncIO
([
in_var_name_val
,
out_var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
time_out
,
ch
,
this
]
{});
req_count_
++
;
return
true
;
}
void
BRPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
req_count_
++
;
}
void
BRPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
req_count_
++
;
}
void
BRPCClient
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
}
ChannelQueuePtr
BRPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_mutex_
);
auto
it
=
channels_
.
find
(
ep
);
if
(
it
!=
channels_
.
end
())
{
return
it
->
second
;
}
}
ChannelQueuePtr
q
(
new
framework
::
BlockingQueue
<
ChannelContextPtr
>
());
brpc
::
ChannelOptions
options
;
options
.
protocol
=
"baidu_std"
;
options
.
connection_type
=
"pooled"
;
options
.
connect_timeout_ms
=
100
;
options
.
timeout_ms
=
FLAGS_timeout_ms
/*milliseconds*/
;
options
.
max_retry
=
FLAGS_max_retry
;
for
(
int
i
=
0
;
i
<
FLAGS_brpc_channel_num
;
++
i
)
{
std
::
shared_ptr
<
ChannelContext
>
c
(
new
ChannelContext
());
if
(
c
->
channel
.
Init
(
ep
.
c_str
(),
&
options
)
!=
0
)
{
LOG
(
ERROR
)
<<
"Fail to initialize channel"
;
return
nullptr
;
}
c
->
stub
.
reset
(
new
sendrecv
::
SendRecvService_Stub
(
static_cast
<
google
::
protobuf
::
RpcChannel
*>
(
&
c
->
channel
)));
q
->
Push
(
c
);
}
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_mutex_
);
channels_
[
ep
]
=
q
;
}
return
q
;
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/brpc_client.h
0 → 100644
浏览文件 @
d9de6b86
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <time.h>
#include <chrono> // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace
paddle
{
namespace
operators
{
namespace
detail
{
struct
ChannelContext
{
brpc
::
Channel
channel
;
std
::
shared_ptr
<
sendrecv
::
SendRecvService_Stub
>
stub
;
};
typedef
std
::
shared_ptr
<
ChannelContext
>
ChannelContextPtr
;
typedef
std
::
shared_ptr
<
framework
::
BlockingQueue
<
ChannelContextPtr
>>
ChannelQueuePtr
;
class
BRPCClient
:
public
RPCClient
{
public:
BRPCClient
()
{}
virtual
~
BRPCClient
();
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
bool
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
bool
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
void
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
RPCClient
::
rpc_time_out
)
override
;
void
Wait
()
override
;
private:
void
Proceed
();
ChannelQueuePtr
GetChannel
(
const
std
::
string
&
ep
);
private:
std
::
unordered_map
<
std
::
string
,
ChannelQueuePtr
>
channels_
;
// mutex for Wait client sync
std
::
mutex
sync_mutex_
;
std
::
condition_variable
sync_cond_
;
std
::
atomic
<
int64_t
>
req_count_
{
0
};
// mutex for GetChannel thread safety
std
::
mutex
chan_mutex_
;
DISABLE_COPY_AND_ASSIGN
(
BRPCClient
);
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/brpc_server.cc
0 → 100644
浏览文件 @
d9de6b86
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/brpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h"
namespace
sendrecv
{
typedef
std
::
unordered_map
<
std
::
string
,
paddle
::
operators
::
detail
::
RequestHandler
*>
HandlerMap
;
class
BRPCServiceImpl
:
public
SendRecvService
{
public:
explicit
BRPCServiceImpl
(
const
HandlerMap
&
rpc_call_map
)
:
request_send_h_
(
nullptr
),
request_get_h_
(
nullptr
),
request_prefetch_h_
(
nullptr
)
{
auto
it
=
rpc_call_map
.
find
(
paddle
::
operators
::
detail
::
kRequestSend
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_send_h_
=
it
->
second
;
}
it
=
rpc_call_map
.
find
(
paddle
::
operators
::
detail
::
kRequestSend
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_get_h_
=
it
->
second
;
}
it
=
rpc_call_map
.
find
(
paddle
::
operators
::
detail
::
kRequestPrefetch
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_prefetch_h_
=
it
->
second
;
}
}
virtual
~
BRPCServiceImpl
()
{}
void
SendVariable
(
google
::
protobuf
::
RpcController
*
cntl_butil
,
const
VariableMessage
*
request
,
VoidMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
override
{
PADDLE_ENFORCE
(
request_send_h_
!=
nullptr
,
"RequestSend handler should be registed first!"
);
brpc
::
ClosureGuard
done_guard
(
done
);
paddle
::
framework
::
Scope
*
local_scope
=
request_send_h_
->
scope
();
paddle
::
framework
::
Variable
*
outvar
=
nullptr
;
paddle
::
framework
::
Variable
*
invar
=
nullptr
;
std
::
string
varname
=
request
->
varname
();
if
(
!
request_send_h_
->
sync_mode
())
{
local_scope
=
&
request_send_h_
->
scope
()
->
NewScope
();
invar
=
local_scope
->
Var
(
varname
);
}
else
{
invar
=
local_scope
->
FindVar
(
varname
);
}
request_send_h_
->
Handle
(
varname
,
local_scope
,
invar
,
&
outvar
);
if
(
!
request_send_h_
->
sync_mode
())
{
request_send_h_
->
scope
()
->
DeleteScope
(
local_scope
);
}
}
void
GetVariable
(
google
::
protobuf
::
RpcController
*
cntl_butil
,
const
VariableMessage
*
request
,
VariableMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
override
{
PADDLE_ENFORCE
(
request_get_h_
!=
nullptr
,
"RequestGet handler should be registed first!"
);
}
void
PrefetchVariable
(
google
::
protobuf
::
RpcController
*
cntl_butil
,
const
VariableMessage
*
request
,
VariableMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
override
{
PADDLE_ENFORCE
(
request_prefetch_h_
!=
nullptr
,
"kRequestPrefetch handler should be registed first!"
);
}
private:
paddle
::
operators
::
detail
::
RequestHandler
*
request_send_h_
;
paddle
::
operators
::
detail
::
RequestHandler
*
request_get_h_
;
paddle
::
operators
::
detail
::
RequestHandler
*
request_prefetch_h_
;
};
}
// namespace sendrecv
namespace
paddle
{
namespace
operators
{
namespace
detail
{
void
AsyncBRPCServer
::
StartServer
()
{
// Instance of your service.
sendrecv
::
BRPCServiceImpl
service_impl
(
rpc_call_map_
);
// Add the service into server. Notice the second parameter, because the
// service is put on stack, we don't want server to delete it, otherwise
// use brpc::SERVER_OWNS_SERVICE.
if
(
server_
.
AddService
(
&
service_impl
,
brpc
::
SERVER_DOESNT_OWN_SERVICE
)
!=
0
)
{
LOG
(
FATAL
)
<<
"Fail to add service"
;
return
;
}
brpc
::
ServerOptions
options
;
options
.
idle_timeout_sec
=
idle_timeout_s_
;
options
.
max_concurrency
=
max_concurrency_
;
if
(
server_
.
Start
(
bind_address_
.
c_str
(),
&
options
)
!=
0
)
{
LOG
(
FATAL
)
<<
"Fail to start EchoServer"
<<
bind_address_
;
return
;
}
butil
::
EndPoint
ep
=
server_
.
listen_address
();
selected_port_
=
ep
.
port
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
ready_
=
1
;
}
condition_ready_
.
notify_all
();
server_
.
Join
();
}
void
AsyncBRPCServer
::
ShutDownImpl
()
{
server_
.
Stop
(
1000
);
}
void
AsyncBRPCServer
::
WaitServerReady
()
{
VLOG
(
3
)
<<
"AsyncGRPCServer is wait server ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
condition_ready_
.
wait
(
lock
,
[
=
]
{
return
this
->
ready_
==
1
;
});
VLOG
(
3
)
<<
"AsyncGRPCServer WaitSeverReady"
;
}
};
// namespace detail
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/detail/brpc_server.h
0 → 100644
浏览文件 @
d9de6b86
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <string>
#include "brpc/server.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
AsyncBRPCServer
final
:
public
RPCServer
{
public:
explicit
AsyncBRPCServer
(
const
std
::
string
&
address
,
int
client_num
)
:
RPCServer
(
address
,
client_num
),
ready_
(
0
)
{}
virtual
~
AsyncBRPCServer
()
{}
void
StartServer
()
override
;
void
WaitServerReady
()
override
;
private:
void
ShutDownImpl
()
override
;
brpc
::
Server
server_
;
static
constexpr
int
idle_timeout_s_
=
-
1
;
static
constexpr
int
max_concurrency_
=
0
;
std
::
mutex
mutex_ready_
;
std
::
condition_variable
condition_ready_
;
int
ready_
;
};
};
// namespace detail
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
d9de6b86
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include <limits>
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
...
...
paddle/fluid/operators/detail/serde_test.cc
→
paddle/fluid/operators/detail/
grpc_
serde_test.cc
浏览文件 @
d9de6b86
文件已移动
paddle/fluid/operators/detail/macros.h
0 → 100644
浏览文件 @
d9de6b86
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#define RPCSERVER_T detail::AsyncGRPCServer
#define RPCCLIENT_T detail::GRPCClient
#else
#include "paddle/fluid/operators/detail/brpc_client.h"
#include "paddle/fluid/operators/detail/brpc_server.h"
#define RPCSERVER_T detail::AsyncBRPCServer
#define RPCCLIENT_T detail::BRPCClient
#endif
paddle/fluid/operators/detail/request_handler.h
浏览文件 @
d9de6b86
...
...
@@ -28,7 +28,6 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -38,6 +37,10 @@ constexpr char kRequestSend[] = "RequestSend";
constexpr
char
kRequestGet
[]
=
"RequestGet"
;
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
class
RPCServer
;
class
RequestHandler
{
...
...
paddle/fluid/operators/detail/request_handler_impl.cc
浏览文件 @
d9de6b86
...
...
@@ -16,15 +16,12 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/detail/request_handler_impl.h
浏览文件 @
d9de6b86
...
...
@@ -29,7 +29,6 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/detail/rpc_client.h
浏览文件 @
d9de6b86
...
...
@@ -26,6 +26,8 @@ namespace detail {
class
RPCClient
{
public:
RPCClient
()
{}
virtual
~
RPCClient
()
{}
virtual
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
...
...
paddle/fluid/operators/detail/
g
rpc_server_test.cc
→
paddle/fluid/operators/detail/rpc_server_test.cc
浏览文件 @
d9de6b86
...
...
@@ -17,15 +17,14 @@ limitations under the License. */
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
...
...
@@ -33,7 +32,7 @@ namespace detail = paddle::operators::detail;
USE_OP
(
lookup_table
);
std
::
unique_ptr
<
detail
::
AsyncG
RPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RequestHandler
>
g_req_handler
;
framework
::
BlockDesc
*
AppendPrefetchBlcok
(
framework
::
ProgramDesc
*
program
)
{
...
...
@@ -112,20 +111,19 @@ void StartServer() {
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncG
RPCServer
::
StartServer
,
g_rpc_service
.
get
()));
std
::
bind
(
&
detail
::
RPCServer
::
StartServer
,
g_rpc_service
.
get
()));
server_thread
.
join
();
}
TEST
(
PREFETCH
,
CPU
)
{
g_req_handler
.
reset
(
new
detail
::
RequestPrefetchHandler
(
true
));
g_rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
1
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
...
...
paddle/fluid/operators/detail/send_recv.proto
浏览文件 @
d9de6b86
...
...
@@ -14,6 +14,8 @@ limitations under the License. */
syntax
=
"proto3"
;
package
sendrecv
;
// option cc_generic_services = true;
service
SendRecvService
{
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
...
...
paddle/fluid/operators/detail/sendrecvop_utils.h
浏览文件 @
d9de6b86
...
...
@@ -32,16 +32,6 @@ namespace paddle {
namespace
operators
{
namespace
detail
{
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
static
int64_t
GetTimestamp
()
{
struct
timeval
tp
;
gettimeofday
(
&
tp
,
NULL
);
return
tp
.
tv_sec
*
1000
+
tp
.
tv_usec
/
1000
;
}
typedef
void
(
*
DestroyCallback
)(
void
*
);
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
...
...
paddle/fluid/operators/fetch_barrier_op.cc
浏览文件 @
d9de6b86
...
...
@@ -19,9 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/rpc_client.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
...
...
@@ -45,7 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase {
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
rpc_client
->
Wait
();
...
...
paddle/fluid/operators/gen_nccl_id_op.cc
浏览文件 @
d9de6b86
...
...
@@ -21,8 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
...
...
@@ -61,8 +60,8 @@ class GenNCCLIdOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoint_list"
);
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
for
(
auto
&
ep
:
endpoint_list
)
{
VLOG
(
3
)
<<
"sending nccl id to "
<<
ep
;
client
->
AsyncSendVar
(
ep
,
dev_ctx
,
*
scope
,
NCCL_ID_VARNAME
);
...
...
@@ -78,9 +77,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
detail
::
RequestSendHandler
rpc_h
(
true
);
detail
::
AsyncGRPCServer
rpc_service
(
endpoint
,
1
);
rpc_service
.
RegisterRPC
(
detail
::
kRequestSend
,
&
rpc_h
);
rpc_h
.
SetRPCServer
(
&
rpc_service
);
std
::
unique_ptr
<
detail
::
RPCServer
>
rpc_service
(
new
RPCSERVER_T
(
endpoint
,
1
));
rpc_service
->
RegisterRPC
(
detail
::
kRequestSend
,
&
rpc_h
);
rpc_h
.
SetRPCServer
(
rpc_service
.
get
());
framework
::
ProgramDesc
empty_program
;
framework
::
Executor
executor
(
dev_ctx
.
GetPlace
());
...
...
@@ -90,12 +91,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
rpc_h
.
SetExecutor
(
&
executor
);
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
&
rpc_service
));
rpc_service
.
SetCond
(
detail
::
kRequestSend
);
std
::
bind
(
&
detail
::
RPCServer
::
StartServer
,
rpc_service
.
get
()));
rpc_service
->
SetCond
(
detail
::
kRequestSend
);
VLOG
(
3
)
<<
"start getting nccl id from trainer 0..."
;
rpc_service
.
WaitBarrier
(
detail
::
kRequestSend
);
rpc_service
->
WaitBarrier
(
detail
::
kRequestSend
);
VLOG
(
3
)
<<
"got nccl id and stop server..."
;
rpc_service
.
ShutDown
();
rpc_service
->
ShutDown
();
VLOG
(
3
)
<<
"rpc server stopped"
;
server_thread
.
join
();
}
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
d9de6b86
...
...
@@ -19,7 +19,8 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -89,6 +90,12 @@ void ListenAndServOp::SavePort() const {
rpc_service_
->
SavePort
();
}
static
int64_t
GetTimestamp
()
{
struct
timeval
tp
;
gettimeofday
(
&
tp
,
NULL
);
return
tp
.
tv_sec
*
1000
+
tp
.
tv_usec
/
1000
;
}
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
...
...
@@ -127,7 +134,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
int32_t
last_parent_blkid
=
program
->
Block
(
1
).
Parent
();
std
::
vector
<
size_t
>
parallel_blkids
;
parallel_blkids
.
push_back
(
1
);
double
ts
=
detail
::
GetTimestamp
();
double
ts
=
GetTimestamp
();
for
(
size_t
blkid
=
2
;
blkid
<
num_blocks
;
++
blkid
)
{
if
(
blkid
!=
static_cast
<
size_t
>
(
prefetch_block
->
ID
()))
{
if
(
program
->
Block
(
blkid
).
Parent
()
!=
last_parent_blkid
)
{
...
...
@@ -141,7 +148,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
}
ParallelExecuteBlocks
(
parallel_blkids
,
executor
,
optimize_prepared
,
program
,
recv_scope
);
VLOG
(
2
)
<<
"run all blocks spent "
<<
detail
::
GetTimestamp
()
-
ts
<<
"(ms)"
;
VLOG
(
2
)
<<
"run all blocks spent "
<<
GetTimestamp
()
-
ts
<<
"(ms)"
;
rpc_service_
->
SetCond
(
detail
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
detail
::
kRequestGet
);
...
...
@@ -235,8 +242,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
LOG
(
INFO
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode
));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
fan_in
));
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
request_send_handler_
.
reset
(
new
detail
::
RequestSendHandler
(
sync_mode
));
request_get_handler_
.
reset
(
new
detail
::
RequestGetHandler
(
sync_mode
));
request_prefetch_handler_
.
reset
(
...
...
paddle/fluid/operators/prefetch_op.cc
浏览文件 @
d9de6b86
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/
grpc_client
.h"
#include "paddle/fluid/operators/detail/
macros
.h"
#include "paddle/fluid/operators/send_recv_util.h"
namespace
paddle
{
...
...
@@ -42,7 +42,7 @@ class PrefetchOp : public framework::OperatorBase {
auto
&
ctx
=
*
pool
.
Get
(
place
);
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
...
paddle/fluid/operators/recv_op.cc
浏览文件 @
d9de6b86
...
...
@@ -19,8 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
...
...
@@ -45,7 +44,7 @@ class RecvOp : public framework::OperatorBase {
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
d9de6b86
...
...
@@ -19,8 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
...
...
@@ -45,7 +45,7 @@ class SendBarrierOp : public framework::OperatorBase {
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
d9de6b86
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/
grpc_client
.h"
#include "paddle/fluid/operators/detail/
macros
.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -46,7 +46,7 @@ class SendOp : public framework::OperatorBase {
platform
::
RecordEvent
record_event
(
Type
(),
&
ctx
);
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
NeedSend
(
scope
,
ins
[
i
]))
{
...
...
paddle/fluid/operators/test_send_nccl_id.cc
浏览文件 @
d9de6b86
...
...
@@ -20,8 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
...
...
@@ -29,6 +28,10 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/printf.h"
#ifdef PADDLE_WITH_GRPC
#include "paddle/fluid/operators/send_recv_util.h"
#endif
USE_NO_KERNEL_OP
(
listen_and_serv
);
namespace
f
=
paddle
::
framework
;
...
...
@@ -37,7 +40,7 @@ namespace m = paddle::operators::math;
namespace
detail
=
paddle
::
operators
::
detail
;
namespace
string
=
paddle
::
string
;
std
::
unique_ptr
<
detail
::
AsyncG
RPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RequestHandler
>
g_req_handler
;
void
StartServer
()
{
...
...
@@ -58,7 +61,7 @@ void StartServer() {
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncG
RPCServer
::
StartServer
,
g_rpc_service
.
get
()));
std
::
bind
(
&
detail
::
RPCServer
::
StartServer
,
g_rpc_service
.
get
()));
g_rpc_service
->
SetCond
(
detail
::
kRequestSend
);
g_rpc_service
->
WaitBarrier
(
detail
::
kRequestSend
);
...
...
@@ -68,9 +71,9 @@ void StartServer() {
server_thread
.
join
();
}
TEST
(
SendNcclId
,
Grpc
Server
)
{
TEST
(
SendNcclId
,
RPC
Server
)
{
g_req_handler
.
reset
(
new
detail
::
RequestSendHandler
(
true
));
g_rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
1
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
...
...
@@ -87,8 +90,9 @@ TEST(SendNcclId, GrpcServer) {
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
detail
::
GRPCClient
>
();
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
LOG
(
INFO
)
<<
"connect to server"
<<
ep
;
client
->
AsyncSendVar
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
->
Wait
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录