PaddlePaddle / Paddle
Commit 4fb7cc7f (unverified)
Authored by gongweibao on May 31, 2018; committed via GitHub on May 31, 2018.
Move sync_mode device ctx from grpc server (#10881)
Parent: 5870a6b4

Showing 28 changed files with 886 additions and 553 deletions (+886 −553).
benchmark/fluid/kube_gen_job.py                                           +1   −1
paddle/fluid/inference/analysis/data_flow_graph.h                         +3   −0
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc   +3   −3
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc          +3   −1
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h           +2   −0
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc   +3   −3
paddle/fluid/inference/analysis/helper.h                                  +4   −2
paddle/fluid/inference/analysis/pass.h                                    +1   −0
paddle/fluid/inference/analysis/subgraph_splitter.h                       +2   −0
paddle/fluid/inference/analysis/ut_helper.h                               +1   −0
paddle/fluid/inference/tensorrt/convert/ut_helper.h                       +4   −1
paddle/fluid/operators/detail/CMakeLists.txt                              +2   −1
paddle/fluid/operators/detail/grpc_client.cc                              +2   −0
paddle/fluid/operators/detail/grpc_server.cc                              +147 −225
paddle/fluid/operators/detail/grpc_server.h                               +22  −76
paddle/fluid/operators/detail/grpc_server_test.cc                         +53  −34
paddle/fluid/operators/detail/request_handler.h                           +127 −0
paddle/fluid/operators/detail/request_handler_impl.cc                     +115 −0
paddle/fluid/operators/detail/request_handler_impl.h                      +64  −0
paddle/fluid/operators/detail/rpc_server.cc                               +113 −0
paddle/fluid/operators/detail/rpc_server.h                                +91  −0
paddle/fluid/operators/detail/variable_response.h                         +2   −2
paddle/fluid/operators/gen_nccl_id_op.cc                                  +13  −8
paddle/fluid/operators/listen_and_serv_op.cc                              +63  −148
paddle/fluid/operators/listen_and_serv_op.h                               +9   −22
paddle/fluid/operators/send_barrier_op.cc                                 +2   −0
paddle/fluid/operators/test_send_nccl_id.cc                               +33  −26
paddle/fluid/platform/nccl_helper.h                                       +1   −0
benchmark/fluid/kube_gen_job.py (view file @ 4fb7cc7f)

@@ -49,7 +49,7 @@ def parse_args():
     parser.add_argument(
         '--fluid', default=1, type=int, help='whether is fluid job')
     parser.add_argument(
-        '--rdma', action='store_ture', help='whether mount rdma libs')
+        '--rdma', action='store_true', help='whether mount rdma libs')
     parser.add_argument(
         '--disttype',
         default="pserver",
paddle/fluid/inference/analysis/data_flow_graph.h (view file @ 4fb7cc7f)

@@ -21,7 +21,10 @@ limitations under the License. */
 #include <deque>
 #include <stack>
 #include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
 
 #include "paddle/fluid/inference/analysis/graph_traits.h"
 #include "paddle/fluid/inference/analysis/node.h"
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc (view file @ 4fb7cc7f)

@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
   LOG(INFO) << graph.nodes.size();
 }
 
-}  // analysis
-}  // inference
-}  // paddle
+};  // namespace analysis
+};  // namespace inference
+};  // namespace paddle
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc (view file @ 4fb7cc7f)

@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
+#include <string>
+#include <vector>
+#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
 
 namespace paddle {
 namespace inference {
 namespace analysis {
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h (view file @ 4fb7cc7f)

@@ -19,6 +19,8 @@
 #pragma once
 
+#include <string>
+
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/inference/analysis/data_flow_graph.h"
 #include "paddle/fluid/inference/analysis/pass.h"
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc (view file @ 4fb7cc7f)

@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
   LOG(INFO) << '\n' << graph.DotString();
 }
 
-}  // analysis
-}  // inference
-}  // paddle
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
paddle/fluid/inference/analysis/helper.h (view file @ 4fb7cc7f)

@@ -50,7 +50,7 @@ struct DataTypeNamer {
     return dic_.at(x);
   }
 
-  const std::string &repr(size_t &hash) const {
+  const std::string &repr(size_t &hash) const {  // NOLINT
     PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
     return dic_.at(hash);
   }

@@ -62,7 +62,9 @@ struct DataTypeNamer {
     SET_TYPE(float);
   }
 
-  std::unordered_map<decltype(typeid(int).hash_code()), std::string> dic_;
+  std::unordered_map<decltype(typeid(int).hash_code()),  // NOLINT
+                     std::string>
+      dic_;
 };
 #undef SET_TYPE
paddle/fluid/inference/analysis/pass.h (view file @ 4fb7cc7f)

@@ -16,6 +16,7 @@ limitations under the License. */
 #include <glog/logging.h>
 #include <iosfwd>
+#include <string>
 
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/inference/analysis/data_flow_graph.h"
paddle/fluid/inference/analysis/subgraph_splitter.h (view file @ 4fb7cc7f)

@@ -18,6 +18,8 @@ limitations under the License. */
 #pragma once
 
+#include <vector>
+
 #include "paddle/fluid/inference/analysis/data_flow_graph.h"
 #include "paddle/fluid/inference/analysis/node.h"
paddle/fluid/inference/analysis/ut_helper.h (view file @ 4fb7cc7f)

@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <gflags/gflags.h>
 #include <gtest/gtest.h>
+#include <string>
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/inference/analysis/data_flow_graph.h"
 #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
paddle/fluid/inference/tensorrt/convert/ut_helper.h (view file @ 4fb7cc7f)

@@ -19,6 +19,9 @@ limitations under the License. */
 #pragma once
 
+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/inference/analysis/helper.h"

@@ -58,7 +61,7 @@ class TRTConvertValidation {
  public:
   TRTConvertValidation() = delete;
 
-  TRTConvertValidation(int batch_size, int workspace_size = 1 << 10) {
+  explicit TRTConvertValidation(int batch_size, int workspace_size = 1024) {
     // create engine.
     engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
     engine_->InitNetwork();
paddle/fluid/operators/detail/CMakeLists.txt (view file @ 4fb7cc7f)

 if(WITH_DISTRIBUTE)
   grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
-    grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
+    request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
+    selected_rows memory)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr ...
paddle/fluid/operators/detail/grpc_client.cc (view file @ 4fb7cc7f)

@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
 }
 
 bool RPCClient::Wait() {
+  VLOG(3) << "RPCClient begin Wait()"
+          << " req_count_:" << req_count_;
   if (req_count_ <= 0) {
     return true;
   }
paddle/fluid/operators/detail/grpc_server.cc (view file @ 4fb7cc7f)

[This diff (+147 −225) is collapsed in the page view; expand it in the original page to see it.]
paddle/fluid/operators/detail/grpc_server.h (view file @ 4fb7cc7f)

@@ -14,6 +14,8 @@ limitations under the License. */
 #pragma once
 
+#include <map>
+#include <set>
 #include <string>
 #include <thread>  // NOLINT
 #include <utility>

@@ -28,6 +30,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/operators/detail/grpc_service.h"
+#include "paddle/fluid/operators/detail/request_handler.h"
+#include "paddle/fluid/operators/detail/rpc_server.h"
 #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
 #include "paddle/fluid/operators/detail/send_recv.pb.h"
 #include "paddle/fluid/operators/detail/sendrecvop_utils.h"

@@ -37,106 +41,48 @@ namespace paddle {
 namespace operators {
 namespace detail {
 
-typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
-    ReceivedMessage;
-typedef framework::BlockingQueue<ReceivedMessage> ReceivedQueue;
-
-typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
 class RequestBase;
 
-class AsyncGRPCServer final {
+class AsyncGRPCServer final : public RPCServer {
  public:
-  explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
-      : address_(address), sync_mode_(sync_mode), ready_(0) {}
-
-  ~AsyncGRPCServer() {}
-  void WaitServerReady();
-  void RunSyncUpdate();
-
-  // functions to sync server barrier status.
-  void WaitCond(int cond);
-  void SetCond(int cond);
-  void WaitClientGet(int count);
-
-  void SetScope(framework::Scope *scope) { scope_ = scope; }
-
-  void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
-
-  void SetProgram(framework::ProgramDesc *program) { program_ = program; }
-
-  void SetExecutor(framework::Executor *executor) { executor_ = executor; }
-
-  void SetPrefetchPreparedCtx(
-      std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
-    prefetch_ctx_.reset(prepared.release());
-  }
-
-  int GetSelectedPort() const { return selected_port_; }
-
-  const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
-
-  void Push(const std::string &msg_name) {
-    this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr));
-  }
+  explicit AsyncGRPCServer(const std::string &address, int client_num)
+      : RPCServer(address, client_num), ready_(0) {}
 
-  void ShutDown();
+  virtual ~AsyncGRPCServer() {}
+  void WaitServerReady() override;
+  void StartServer() override;
 
- protected:
+ private:
   void HandleRequest(::grpc::ServerCompletionQueue *cq,
-                     const std::string &cq_name,
-                     std::function<void(int)> TryToRegisterNewOne);
-  void TryToRegisterNewSendOne(int req_id);
-  void TryToRegisterNewGetOne(int req_id);
-  void TryToRegisterNewPrefetchOne(int req_id);
+                     const std::string &rpc_name,
+                     std::function<void(const std::string &, int)>
+                         TryToRegisterNewOne);
+
+  void TryToRegisterNewOne(const std::string &rpc_name, int req_id);
   void ShutdownQueue();
+  void ShutDownImpl() override;
 
  private:
-  static const int kSendReqsBufSize = 100;
-  static const int kGetReqsBufSize = 100;
-  static const int kPrefetchReqsBufSize = 10;
+  static const int kRequestBufSize = 100;
 
   std::mutex cq_mutex_;
   volatile bool is_shut_down_ = false;
-  std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
-  std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
-  std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
-
-  RequestBase *send_reqs_[kSendReqsBufSize];
-  RequestBase *get_reqs_[kGetReqsBufSize];
-  RequestBase *prefetch_reqs_[kPrefetchReqsBufSize];
 
   GrpcService::AsyncService service_;
   std::unique_ptr<::grpc::Server> server_;
 
-  std::string address_;
-  const bool sync_mode_;
-  framework::Scope *scope_;
-  const platform::DeviceContext *dev_ctx_;
-
-  // received variable from RPC, operators fetch variable from this queue.
-  framework::BlockingQueue<MessageWithName> var_get_queue_;
-  // client send variable to this queue.
-  ReceivedQueue var_recv_queue_;
-
   // condition of the sub program
-  std::mutex barrier_mutex_;
-  mutable int barrier_cond_step_;
   std::condition_variable barrier_condition_;
 
-  std::vector<std::unique_ptr<std::thread>> t_sends_;
-  std::vector<std::unique_ptr<std::thread>> t_gets_;
-  std::vector<std::unique_ptr<std::thread>> t_prefetchs_;
-
-  std::unique_ptr<std::thread> t_prefetch_;
-
-  std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
-  framework::ProgramDesc *program_;
-  framework::Executor *executor_;
-  int selected_port_;
-
   std::mutex mutex_ready_;
   std::condition_variable condition_ready_;
+
   int ready_;
+
+  std::map<std::string, std::unique_ptr<::grpc::ServerCompletionQueue>> rpc_cq_;
+  std::map<std::string, std::vector<std::unique_ptr<std::thread>>>
+      rpc_threads_;
+  std::map<std::string, std::vector<RequestBase *>> rpc_reqs_;
 };
 
 };  // namespace detail
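The core of this header change is replacing dedicated per-RPC members (cq_send_/cq_get_/cq_prefetch_ and the matching thread vectors) with maps keyed by rpc_name, so each RPC registered at runtime gets its own completion queue, request buffer and worker pool. A minimal sketch of that registry pattern, independent of gRPC (MiniServer and its names are illustrative, not from the patch):

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <vector>

// Sketch: one worker pool per registered RPC name, mirroring how
// AsyncGRPCServer keys rpc_cq_ / rpc_threads_ / rpc_reqs_ by rpc_name.
class MiniServer {
 public:
  using Handler = std::function<void(const std::string&)>;

  void RegisterRPC(const std::string& rpc_name, Handler h, int thread_num) {
    handlers_[rpc_name] = std::move(h);
    thread_num_[rpc_name] = thread_num;
  }

  void StartServer() {
    for (auto& kv : handlers_) {
      for (int i = 0; i < thread_num_[kv.first]; ++i) {
        // Capture handler and name by value so worker threads never
        // touch the maps after startup.
        threads_[kv.first].emplace_back(new std::thread(
            [h = kv.second, name = kv.first] { h(name); }));
      }
    }
    for (auto& kv : threads_)
      for (auto& t : kv.second) t->join();
  }

 private:
  std::map<std::string, Handler> handlers_;
  std::map<std::string, int> thread_num_;
  std::map<std::string, std::vector<std::unique_ptr<std::thread>>> threads_;
};

int main() {
  MiniServer s;
  s.RegisterRPC("RequestSend",
                [](const std::string& n) { std::cout << "serving " << n << "\n"; },
                2);
  s.StartServer();
  return 0;
}

The payoff is the same as in the patch: adding a new RPC kind becomes one RegisterRPC call instead of a new set of members plus a new TryToRegisterNewXxxOne method.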
paddle/fluid/operators/detail/grpc_server_test.cc (view file @ 4fb7cc7f)

@@ -24,13 +24,16 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/detail/request_handler_impl.h"
 
 namespace framework = paddle::framework;
 namespace platform = paddle::platform;
 namespace detail = paddle::operators::detail;
 
 USE_OP(lookup_table);
 
-std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
+std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
+std::unique_ptr<detail::RequestHandler> g_req_handler;
 
 framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
   auto root_block = program->MutableBlock(0);

@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
   }
 }
 
-void StartServer(const std::string& endpoint) {
-  rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
+void StartServer() {
   framework::ProgramDesc program;
   framework::Scope scope;
   platform::CPUPlace place;

@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
   auto prepared = exe.Prepare(program, block->ID());
   InitTensorsOnServer(&scope, &place, 10);
 
-  rpc_service_->SetProgram(&program);
-  rpc_service_->SetPrefetchPreparedCtx(std::move(prepared));
-  rpc_service_->SetDevCtx(&ctx);
-  rpc_service_->SetScope(&scope);
-  rpc_service_->SetExecutor(&exe);
+  g_req_handler->SetProgram(&program);
+  g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
+  g_req_handler->SetDevCtx(&ctx);
+  g_req_handler->SetScope(&scope);
+  g_req_handler->SetExecutor(&exe);
+
+  g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get());
+  g_req_handler->SetRPCServer(g_rpc_service.get());
+
+  std::thread server_thread(
+      std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
 
-  rpc_service_->RunSyncUpdate();
+  // FIXME(gongwb): don't use hard time.
+  sleep(10);
+  LOG(INFO) << "got nccl id and stop server...";
+  g_rpc_service->ShutDown();
+  server_thread.join();
 }
 
-TEST(PREFETCH, DISABLED_CPU) {
-  // start up a server instance backend
-  std::thread server_thread(StartServer, "127.0.0.1:8889");
-  sleep(2);
+TEST(PREFETCH, CPU) {
+  g_req_handler.reset(new detail::RequestPrefetchHandler(true));
+  g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
+
+  std::thread server_thread(StartServer);
+  g_rpc_service->WaitServerReady();
+
+  detail::RPCClient client;
+  int port = g_rpc_service->GetSelectedPort();
+  std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
+
   framework::Scope scope;
   platform::CPUPlace place;
   platform::CPUDeviceContext ctx(place);
-  // create var on local scope
-  int64_t rows_numel = 5;
-  InitTensorsOnClient(&scope, &place, rows_numel);
-  std::string in_var_name("ids");
-  std::string out_var_name("out");
-
-  auto client = detail::RPCClient::GetInstance();
-  client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
-                                out_var_name);
-  client->Wait();
-
-  auto var = scope.Var(out_var_name);
-  auto value = var->GetMutable<framework::SelectedRows>()->value();
-  auto ptr = value.mutable_data<float>(place);
-
-  rpc_service_->ShutDown();
-  server_thread.join();
-  rpc_service_.reset(nullptr);
-
-  for (int64_t i = 0; i < rows_numel; ++i) {
-    EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
+  {
+    // create var on local scope
+    int64_t rows_numel = 5;
+    InitTensorsOnClient(&scope, &place, rows_numel);
+    std::string in_var_name("ids");
+    std::string out_var_name("out");
+
+    client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
+    client.Wait();
+
+    auto var = scope.Var(out_var_name);
+    auto value = var->GetMutable<framework::SelectedRows>()->value();
+    auto ptr = value.mutable_data<float>(place);
+
+    for (int64_t i = 0; i < rows_numel; ++i) {
+      EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
    }
  }
 
  server_thread.join();
  LOG(INFO) << "begin reset";
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
 }
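One visible improvement in the rewritten test is replacing a hard-coded port plus `sleep(2)` with binding to port 0 and blocking on `WaitServerReady()`. A self-contained sketch of that ready-latch handshake using only the standard library (TinyServer is illustrative, not the patch's class):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Sketch of the WaitServerReady() handshake: the server thread flips
// ready_ once it is actually listening; the test thread blocks on the
// condition variable instead of sleeping for a guessed duration.
class TinyServer {
 public:
  void Start() {
    {
      std::lock_guard<std::mutex> guard(mu_);
      ready_ = true;  // in the real server: set after the port is bound
    }
    cv_.notify_all();
    // ... serve requests ...
  }

  void WaitServerReady() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return ready_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool ready_ = false;
};

int main() {
  TinyServer server;
  std::thread t([&] { server.Start(); });
  server.WaitServerReady();  // no sleep(2) guesswork
  std::cout << "server is ready; run client calls here\n";
  t.join();
  return 0;
}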
paddle/fluid/operators/detail/request_handler.h (new file, mode 100644; view file @ 4fb7cc7f)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"

namespace paddle {
namespace operators {
namespace detail {

constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";

class RPCServer;

class RequestHandler {
 public:
  explicit RequestHandler(bool sync_mode)
      : sync_mode_(sync_mode),
        dev_ctx_(nullptr),
        executor_(nullptr),
        scope_(nullptr),
        program_(nullptr),
        rpc_server_(nullptr) {}

  virtual ~RequestHandler() {}

  // Set attributes.
  void SetScope(framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  void SetProgram(framework::ProgramDesc* program) { program_ = program; }
  void SetExecutor(framework::Executor* executor) { executor_ = executor; }
  void SetPrefetchPreparedCtx(
      std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
    prefetch_ctx_.reset(prepared.release());
  }

  // Used for async.
  void SetGradToPreparedCtx(
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    grad_to_prepared_ctx_ = g;
  }

  void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }

  // Get attributes.
  bool sync_mode() { return sync_mode_; }
  framework::Scope* scope() { return scope_; }
  const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
  framework::ExecutorPrepareContext* prefetch_ctx() {
    return prefetch_ctx_.get();
  }
  framework::ProgramDesc* program() { return program_; }
  framework::Executor* executor() { return executor_; }
  std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }

  // This function processes user's rpc request.
  // The implemention is in request_handler_impl.
  // example:
  //    std::string varname = request_.varname();
  //
  //    auto scope = request_handler_->scope();
  //    auto invar = scope->FindVar(varname);
  //    framework::Variable* outvar = nullptr;
  //
  //    request_handler_->Handle(varname, scope, invar, &outvar);
  //    if (outvar) {
  //      SerializeToByteBuffer(varname, outvar,
  //         *request_handler_->dev_ctx(), &reply_);
  //    }
  virtual bool Handle(const std::string& varname, framework::Scope* scope,
                      framework::Variable* var,
                      framework::Variable** outvar) = 0;

 protected:
  const bool sync_mode_;

  const platform::DeviceContext* dev_ctx_;
  framework::Executor* executor_;
  framework::Scope* scope_;
  framework::ProgramDesc* program_;
  std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;

  // Used for async.
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      grad_to_prepared_ctx_;

  // Record received sparse variables, so that
  // we could reset those after execute optimize program
  std::vector<framework::Variable*> sparse_vars_;
  RPCServer* rpc_server_;

  std::mutex sparse_var_mutex_;
};

}  // namespace detail
}  // namespace operators
}  // namespace paddle
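The header's own comment spells out the calling convention: the transport looks up the input variable, calls Handle(), and serializes *outvar only if the handler set it. A toy handler driven exactly that way, against minimal stand-in types (Scope and Variable below are mocks for illustration, not the framework classes):

#include <iostream>
#include <map>
#include <string>

// Minimal stand-ins for framework::Scope / framework::Variable.
struct Variable { std::string data; };
struct Scope {
  Variable* FindVar(const std::string& name) {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
  std::map<std::string, Variable> vars;
};

// The RequestHandler contract: return success, optionally set *outvar.
struct RequestHandler {
  virtual ~RequestHandler() {}
  virtual bool Handle(const std::string& varname, Scope* scope,
                      Variable* invar, Variable** outvar) = 0;
};

// A toy "Get"-style handler: the output is just a scope lookup.
struct GetHandler : RequestHandler {
  bool Handle(const std::string& varname, Scope* scope, Variable* /*invar*/,
              Variable** outvar) override {
    *outvar = scope->FindVar(varname);
    return *outvar != nullptr;
  }
};

int main() {
  Scope scope;
  scope.vars["w"] = Variable{"weights"};

  GetHandler handler;
  Variable* outvar = nullptr;
  // Mirrors the sequence in the header comment:
  // FindVar -> Handle -> serialize the reply only if outvar was produced.
  if (handler.Handle("w", &scope, scope.FindVar("w"), &outvar) && outvar) {
    std::cout << "reply: " << outvar->data << "\n";
  }
  return 0;
}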
paddle/fluid/operators/detail/request_handler_impl.cc (new file, mode 100644; view file @ 4fb7cc7f)

(The file opens with the standard 2018 PaddlePaddle Apache-2.0 license header, identical to request_handler.h above.)

#include <iostream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"

namespace paddle {
namespace operators {
namespace detail {

bool RequestSendHandler::Handle(const std::string& varname,
                                framework::Scope* scope,
                                framework::Variable* invar,
                                framework::Variable** outvar) {
  VLOG(4) << "RequestSendHandler:" << varname;

  // Async
  if (!sync_mode_) {
    try {
      executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
                                    scope);
    } catch (std::exception& e) {
      LOG(ERROR) << "async: run sub program error " << e.what();
      return false;
    }
    return true;
  }

  // Sync
  if (varname == BATCH_BARRIER_MESSAGE) {
    VLOG(3) << "sync: recv batch barrier message";
    rpc_server_->IncreaseBatchBarrier(kRequestSend);
  } else {
    VLOG(3) << "sync: received var_name: " << varname;
    if (sync_mode_) {
      rpc_server_->WaitCond(kRequestSend);
    }

    if (invar == nullptr) {
      LOG(ERROR) << "sync: Can not find server side var: " << varname;
      PADDLE_THROW("sync: Can not find server side var");
      return false;
    }
    if (invar->IsType<framework::SelectedRows>()) {
      std::unique_lock<std::mutex> lock(sparse_var_mutex_);
      sparse_vars_.push_back(invar);
    }
  }

  return true;
}

bool RequestGetHandler::Handle(const std::string& varname,
                               framework::Scope* scope,
                               framework::Variable* invar,
                               framework::Variable** outvar) {
  VLOG(4) << "RequestGetHandler:" << varname;

  if (varname != FETCH_BARRIER_MESSAGE) {
    if (sync_mode_) {
      rpc_server_->WaitCond(kRequestGet);
    }
    *outvar = scope_->FindVar(varname);
    return true;
  }

  // FETCH_BARRIER_MESSAGE
  if (sync_mode_) {
    VLOG(3) << "sync: recv fetch barrier message";
    rpc_server_->IncreaseBatchBarrier(kRequestGet);
  }

  return true;
}

bool RequestPrefetchHandler::Handle(const std::string& varname,
                                    framework::Scope* scope,
                                    framework::Variable* invar,
                                    framework::Variable** outvar) {
  VLOG(4) << "RequestPrefetchHandler " << varname;

  auto var_desc = program_->Block(0).FindVar(varname);
  *outvar = scope->FindVar(varname);
  InitializeVariable(*outvar, var_desc->GetType());
  executor_->RunPreparedContext(prefetch_ctx_.get(), scope);

  return true;
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/detail/request_handler_impl.h (new file, mode 100644; view file @ 4fb7cc7f)

(The file opens with the standard 2018 PaddlePaddle Apache-2.0 license header, identical to request_handler.h above.)

#pragma once

#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"

namespace paddle {
namespace operators {
namespace detail {

class RequestSendHandler final : public RequestHandler {
 public:
  explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
  virtual ~RequestSendHandler() {}
  bool Handle(const std::string& varname, framework::Scope* scope,
              framework::Variable* var, framework::Variable** outvar) override;
};

class RequestGetHandler final : public RequestHandler {
 public:
  explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
  virtual ~RequestGetHandler() {}
  bool Handle(const std::string& varname, framework::Scope* scope,
              framework::Variable* var, framework::Variable** outvar) override;
};

class RequestPrefetchHandler final : public RequestHandler {
 public:
  explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
  virtual ~RequestPrefetchHandler() {}
  bool Handle(const std::string& varname, framework::Scope* scope,
              framework::Variable* var, framework::Variable** outvar) override;
};

}  // namespace detail
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/detail/rpc_server.cc (new file, mode 100644; view file @ 4fb7cc7f)

(The file opens with the standard 2018 PaddlePaddle Apache-2.0 license header, identical to request_handler.h above.)

#include <fstream>
#include <iostream>
#include <limits>
#include <string>

#include "paddle/fluid/operators/detail/rpc_server.h"

namespace paddle {
namespace operators {
namespace detail {

void RPCServer::ShutDown() {
  LOG(INFO) << "RPCServer ShutDown ";
  ShutDownImpl();

  exit_flag_ = true;
  barrier_cond_.notify_all();
  rpc_cond_.notify_all();
}

void RPCServer::SavePort() const {
  auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
  std::ofstream port_file;
  port_file.open(file_path);
  port_file << selected_port_;
  port_file.close();
  VLOG(4) << "selected port written to " << file_path;
}

void RPCServer::WaitBarrier(const std::string& rpc_name) {
  std::unique_lock<std::mutex> lock(this->mutex_);
  barrier_cond_.wait(lock, [=] {
    return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
  });

  VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
}

void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
  VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
  int b = 0;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    b = ++barrier_counter_[rpc_name];
  }

  VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
          << ", barrier_count:" << b << ", fan_in" << client_num_;

  if (b >= client_num_) {
    barrier_cond_.notify_all();
  }
}

void RPCServer::ResetBarrierCounter() {
  VLOG(3) << "RPCServer ResetBarrierCounter ";
  std::unique_lock<std::mutex> lock(mutex_);
  for (auto& t : barrier_counter_) {
    t.second = 0;
  }
}

void RPCServer::RegisterRPC(const std::string& rpc_name,
                            RequestHandler* handler, int thread_num) {
  rpc_call_map_[rpc_name] = handler;
  rpc_thread_num_[rpc_name] = thread_num;

  static int cond = -1;
  rpc_cond_map_[rpc_name] = ++cond;
  VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler
          << ", cond:" << rpc_cond_map_[rpc_name];
}

void RPCServer::SetCond(const std::string& rpc_name) {
  VLOG(3) << "RPCServer SetCond " << rpc_name;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    cur_cond_ = rpc_cond_map_[rpc_name];
  }

  rpc_cond_.notify_all();
}

void RPCServer::WaitCond(const std::string& rpc_name) {
  VLOG(3) << "RPCServer WaitCond " << rpc_name;
  int cond = 0;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    cond = rpc_cond_map_[rpc_name];
  }

  std::unique_lock<std::mutex> lock(mutex_);
  rpc_cond_.wait(
      lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle
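WaitBarrier/IncreaseBatchBarrier implement a per-RPC counting barrier: each client bumps a counter keyed by rpc_name, and the server thread sleeps until the counter reaches client_num_ (or the server is shutting down). A compilable sketch of just that mechanism, with the shutdown flag omitted for brevity (NamedBarrier is an illustrative name, not the patch's class):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

// A per-RPC counting barrier, shaped like RPCServer::WaitBarrier /
// IncreaseBatchBarrier: names index counters, client_num is the fan-in.
class NamedBarrier {
 public:
  explicit NamedBarrier(int client_num) : client_num_(client_num) {}

  void Increase(const std::string& rpc_name) {
    int b = 0;
    {
      std::unique_lock<std::mutex> lock(mutex_);
      b = ++counter_[rpc_name];
    }
    if (b >= client_num_) cond_.notify_all();
  }

  void Wait(const std::string& rpc_name) {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [&] { return counter_[rpc_name] >= client_num_; });
  }

  void Reset() {
    std::unique_lock<std::mutex> lock(mutex_);
    for (auto& t : counter_) t.second = 0;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, int> counter_;
  std::condition_variable cond_;
  const int client_num_;
};

int main() {
  NamedBarrier barrier(3);  // pretend three trainers
  std::vector<std::thread> clients;
  for (int i = 0; i < 3; ++i)
    clients.emplace_back([&] { barrier.Increase("RequestSend"); });

  barrier.Wait("RequestSend");  // server side: all sends have arrived
  std::cout << "batch barrier reached\n";
  for (auto& c : clients) c.join();
  barrier.Reset();  // ready for the next batch
  return 0;
}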
paddle/fluid/operators/detail/rpc_server.h (new file, mode 100644; view file @ 4fb7cc7f)

(The file opens with the standard 2018 PaddlePaddle Apache-2.0 license header, identical to request_handler.h above.)

#pragma once

#include <set>
#include <string>
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "paddle/fluid/operators/detail/request_handler.h"

namespace paddle {
namespace operators {
namespace detail {

class RPCServer {
 public:
  explicit RPCServer(const std::string& address, int client_num)
      : cur_cond_(0),
        bind_address_(address),
        exit_flag_(false),
        selected_port_(0),
        client_num_(client_num) {}

  virtual ~RPCServer() {}
  virtual void StartServer() = 0;
  virtual void WaitServerReady() = 0;

  void ShutDown();

  bool IsExit() { return exit_flag_.load(); }

  int GetSelectedPort() const { return selected_port_; }
  void SavePort() const;

  // RegisterRPC, register the rpc method name to a handler
  // class, and auto generate a condition id for this call
  // to be used for the barrier.
  void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
                   int thread_num = 5);

  // Wait util all the clients have reached the barrier for one
  // rpc method. This function should be called in the
  // RequestHandler if you want to run the server/client in a
  // synchronous mode.
  void WaitBarrier(const std::string& rpc_name);

  void SetCond(const std::string& rpc_name);
  void WaitCond(const std::string& rpc_name);
  void IncreaseBatchBarrier(const std::string rpc_name);
  void ResetBarrierCounter();

 protected:
  virtual void ShutDownImpl() = 0;

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, int> barrier_counter_;
  std::condition_variable barrier_cond_;

  std::unordered_map<std::string, int> rpc_cond_map_;
  std::atomic<int> cur_cond_;
  std::condition_variable rpc_cond_;

 protected:
  std::string bind_address_;
  std::atomic<int> exit_flag_;
  int selected_port_;

  const int client_num_;

  std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
  std::unordered_map<std::string, int> rpc_thread_num_;
  friend class RequestHandler;
};

};  // namespace detail
};  // namespace operators
};  // namespace paddle
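SetCond/WaitCond give the server a single "current phase" knob: RegisterRPC assigns each rpc_name an integer condition id, and handler threads block in WaitCond(rpc_name) until the control loop makes that id current. (Note that in rpc_server.cc the id comes from a function-local `static int cond = -1;`, so ids appear to keep incrementing across RPCServer instances within one process — worth keeping in mind when reading the original.) A self-contained sketch of the phase switch (PhaseSwitch is an illustrative name):

#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>

// Sketch of RPCServer::SetCond/WaitCond: rpc names map to integer ids,
// and an atomic "current condition" selects which phase may proceed.
class PhaseSwitch {
 public:
  void Register(const std::string& rpc_name) {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_map_[rpc_name] = static_cast<int>(cond_map_.size());
  }

  void SetCond(const std::string& rpc_name) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      cur_cond_ = cond_map_[rpc_name];
    }
    cond_.notify_all();
  }

  void WaitCond(const std::string& rpc_name) {
    int cond = 0;
    {
      std::unique_lock<std::mutex> lock(mutex_);
      cond = cond_map_[rpc_name];
    }
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [&] { return cur_cond_.load() == cond; });
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, int> cond_map_;
  std::atomic<int> cur_cond_{0};
  std::condition_variable cond_;
};

int main() {
  PhaseSwitch phases;
  phases.Register("RequestSend");  // id 0, current by default
  phases.Register("RequestGet");   // id 1

  std::thread get_worker([&] {
    phases.WaitCond("RequestGet");  // blocks until the Get phase opens
    std::cout << "serving Get\n";
  });

  phases.SetCond("RequestGet");  // control loop flips the phase
  get_worker.join();
  return 0;
}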
paddle/fluid/operators/detail/variable_response.h (view file @ 4fb7cc7f)

@@ -67,8 +67,8 @@ class VariableResponse {
   framework::Scope* GetMutableLocalScope() const { return local_scope_; }
 
-  inline std::string Varname() { return meta_.varname(); }
-  inline std::string OutVarname() { return meta_.out_varname(); }
+  inline std::string Varname() const { return meta_.varname(); }
+  inline std::string OutVarname() const { return meta_.out_varname(); }
 
   // should call parse first.
   framework::Variable* GetVar() {
paddle/fluid/operators/gen_nccl_id_op.cc (view file @ 4fb7cc7f)

@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
 #include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/operators/detail/request_handler_impl.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 
 namespace paddle {

@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
     // NOTE: Can not use unique_ptr here because the default
     // deleter will call GRPC Server's base class's dtor and
     // that will cause a wired crash.
-    detail::AsyncGRPCServer rpc_service(endpoint, true);
+    detail::RequestSendHandler rpc_h(true);
+    detail::AsyncGRPCServer rpc_service(endpoint, 1);
+    rpc_service.RegisterRPC(detail::kRequestSend, &rpc_h);
+    rpc_h.SetRPCServer(&rpc_service);
+
     framework::ProgramDesc empty_program;
     framework::Executor executor(dev_ctx.GetPlace());
-    rpc_service.SetScope(scope);
-    rpc_service.SetDevCtx(&dev_ctx);
-    rpc_service.SetProgram(&empty_program);
-    rpc_service.SetExecutor(&executor);
+    rpc_h.SetScope(scope);
+    rpc_h.SetDevCtx(&dev_ctx);
+    rpc_h.SetProgram(&empty_program);
+    rpc_h.SetExecutor(&executor);
 
     std::thread server_thread(
-        std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
-    rpc_service.SetCond(0);
+        std::bind(&detail::AsyncGRPCServer::StartServer, &rpc_service));
+    rpc_service.SetCond(detail::kRequestSend);
     VLOG(3) << "start getting nccl id from trainer 0...";
-    auto recv = rpc_service.Get();
+    rpc_service.WaitBarrier(detail::kRequestSend);
     VLOG(3) << "got nccl id and stop server...";
     rpc_service.ShutDown();
     VLOG(3) << "rpc server stopped";
paddle/fluid/operators/listen_and_serv_op.cc (view file @ 4fb7cc7f)

@@ -19,14 +19,16 @@ limitations under the License. */
 #include <thread>  // NOLINT
 #include <vector>
 
 #include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/operators/detail/request_handler_impl.h"
 #include "paddle/fluid/operators/listen_and_serv_op.h"
 #include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
 
-void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
-  service->RunSyncUpdate();
+void RunServer(std::shared_ptr<detail::RPCServer> service) {
+  service->StartServer();
   VLOG(4) << "RunServer thread end";
 }
 
 static void split(const std::string &str, char sep,

@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
   for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
 }
 
-std::atomic_int ListenAndServOp::selected_port_{0};
-
 ListenAndServOp::ListenAndServOp(const std::string &type,
                                  const framework::VariableNameMap &inputs,
                                  const framework::VariableNameMap &outputs,

@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
 ListenAndServOp::~ListenAndServOp() { Stop(); }
 
 void ListenAndServOp::Stop() {
-  rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
   rpc_service_->ShutDown();
   server_thread_->join();
   auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());

@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
 void ListenAndServOp::SavePort() const {
   // NOTE: default write file to /tmp/paddle.selected_port
-  selected_port_ = rpc_service_->GetSelectedPort();
-
-  auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
-  std::ofstream port_file;
-  port_file.open(file_path);
-  port_file << selected_port_.load();
-  port_file.close();
-  VLOG(4) << "selected port written to " << file_path;
-}
-
-void ListenAndServOp::WaitServerReady() {
-  while (selected_port_.load() == 0) {
-  }
+  rpc_service_->SavePort();
 }
 
 void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
                                   framework::ProgramDesc *program,
                                   framework::Scope *recv_scope,
                                   framework::BlockDesc *prefetch_block) const {
-  auto fan_in = Attr<int>("Fanin");
-
   size_t num_blocks = program->Size();
   PADDLE_ENFORCE_GE(num_blocks, 2,
                     "server program should have at least 2 blocks");

@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
       optimize_prepared.begin(),
       std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
 
-  bool exit_flag = false;
+  rpc_service_->ResetBarrierCounter();
   // Record received sparse variables, so that
   // we could reset those after execute optimize program
   std::vector<framework::Variable *> sparse_vars;
-  while (!exit_flag && !SignalHandler::IsProgramExit()) {
+  while (true) {
     // Get from multiple trainers, we don't care about the order in which
     // the gradients arrives, just add suffix 0~n and merge the gradient.
-    rpc_service_->SetCond(0);
-    size_t recv_var_cnt = 0;
-    int batch_barrier = 0;
-    while (batch_barrier != fan_in) {
-      const detail::ReceivedMessage v = rpc_service_->Get();
-      auto recv_var_name = v.first;
-      if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
-        LOG(INFO) << "received terminate message and exit";
-        exit_flag = true;
-        break;
-      } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
-        VLOG(3) << "recv batch barrier message";
-        batch_barrier++;
-        continue;
-      } else {
-        VLOG(3) << "received grad: " << recv_var_name;
-        recv_var_cnt++;
-        auto var = v.second->GetVar();
-        if (var == nullptr) {
-          LOG(ERROR) << "Can not find server side var: " << recv_var_name;
-          PADDLE_THROW("Can not find server side var");
-        }
-        if (var->IsType<framework::SelectedRows>()) {
-          sparse_vars.push_back(var);
-        }
-      }
-    }
-    if (exit_flag) {
-      rpc_service_->SetCond(1);
-      rpc_service_->ShutDown();
+    rpc_service_->SetCond(detail::kRequestSend);
+    rpc_service_->WaitBarrier(detail::kRequestSend);
+
+    if (rpc_service_->IsExit()) {
+      LOG(WARNING) << "get exit!rpc_processor break!";
+      rpc_service_->SetCond(detail::kRequestGet);
       break;
     }
 
     // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
     // and this will still work.
-
     // The optimize blocks which have the same parent ID would run parallel
     // TODO(Yancey1989): need to use ParallelExecutor for future
     int32_t last_parent_blkid = program->Block(1).Parent();

@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
       var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
     }
-    rpc_service_->SetCond(1);
-    // FIXME(typhoonzero): use another condition to sync wait clients get.
-    rpc_service_->WaitClientGet(fan_in);
-    sparse_vars.clear();
+
+    rpc_service_->SetCond(detail::kRequestGet);
+    rpc_service_->WaitBarrier(detail::kRequestGet);
+    rpc_service_->ResetBarrierCounter();
   }  // while(true)
 }
 
-static void AsyncUpdateThread(
-    const std::string &var_name, const bool &exit_flag,
-    const std::shared_ptr<detail::ReceivedQueue> &queue,
-    framework::Executor *executor,
-    framework::ExecutorPrepareContext *prepared) {
-  VLOG(3) << "update thread for " << var_name << " started";
-  while (!exit_flag && !SignalHandler::IsProgramExit()) {
-    const detail::ReceivedMessage v = queue->Pop();
-    if (SignalHandler::IsProgramExit()) {
-      VLOG(3) << "update thread for " << var_name << " exit";
-      break;
-    }
-    auto recv_var_name = v.first;
-    VLOG(4) << "async update " << recv_var_name;
-    auto var = v.second->GetVar();
-    if (var == nullptr) {
-      LOG(ERROR) << "Can not find server side var: " << recv_var_name;
-      PADDLE_THROW("Can not find server side var");
-    }
-    auto fs = framework::Async([var_name, &executor, &v, prepared] {
-      try {
-        executor->RunPreparedContext(prepared,
-                                     v.second->GetMutableLocalScope());
-      } catch (const std::exception &e) {
-        LOG(ERROR) << "run sub program error " << e.what();
-      }
-    });
-    fs.wait();
-  }
-}
-
 void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                    framework::ProgramDesc *program) const {
   VLOG(3) << "RunAsyncLoop in";
   // grad name to block id
   std::unordered_map<std::string, int32_t> grad_to_block_id;
   std::unordered_map<int32_t, std::string> id_to_grad;
-  std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
-      grad_to_queue;
 
   auto grad_to_block_id_str =
       Attr<std::vector<std::string>>("grad_to_block_id");

@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
     VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
     PADDLE_ENFORCE_EQ(pieces.size(), 2);
     PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
+
     int block_id = std::stoi(pieces[1]);
     grad_to_block_id[pieces[0]] = block_id;
-    std::shared_ptr<detail::ReceivedQueue> queue =
-        std::make_shared<detail::ReceivedQueue>();
-    grad_to_queue[pieces[0]] = queue;
-    // record blocking queue in SignalHandler
-    SignalHandler::RegisterBlockingQueue(queue);
     id_to_grad[block_id] = pieces[0];
   }
   size_t num_blocks = program->Size();

@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
     grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
   }
 
-  bool exit_flag = false;
+  request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
+  request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
+  request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
 
-  VLOG(3) << "start async optimize threads";
-  std::vector<std::future<void>> fs;
-  for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
-    std::string grad_name = iter->first;
-    VLOG(3) << "create async update thread for " << grad_name;
-    fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
-                                     &grad_to_queue, &grad_to_prepared_ctx]() {
-      AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
-                        executor, grad_to_prepared_ctx[grad_name].get());
-    }));
-  }
-
   VLOG(3) << "RunAsyncLoop into while";
-  while (!exit_flag && !SignalHandler::IsProgramExit()) {
-    const detail::ReceivedMessage v = rpc_service_->Get();
-    auto recv_var_name = v.first;
-    if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
-      LOG(INFO) << "received terminate message and exit";
-      exit_flag = true;
+  while (true) {
+    if (rpc_service_->IsExit()) {
+      LOG(INFO) << "get exit!rpc_processor break!";
       break;
-    } else {
-      VLOG(3) << "received grad: " << recv_var_name;
-      grad_to_queue[recv_var_name]->Push(v);
-    }
-
-    if (exit_flag) {
-      rpc_service_->ShutDown();
-      break;
     }
+
+    sleep(1);
   }  // while(true)
 }
 
+static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
+                           platform::DeviceContext *dev_ctx,
+                           framework::Executor *executor,
+                           framework::ProgramDesc *program,
+                           framework::ExecutorPrepareContext *prefetch_ctx,
+                           detail::RPCServer *rpc_server) {
+  h->SetScope(scope);
+  h->SetDevCtx(dev_ctx);
+  h->SetExecutor(executor);
+  h->SetProgram(program);
+  h->SetPrefetchPreparedCtx(std::move(
+      std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx)));
+  h->SetRPCServer(rpc_server);
+}
+
 void ListenAndServOp::RunImpl(const framework::Scope &scope,
                               const platform::Place &dev_place) const {
   // Mark this as PS that it should decide profiling by listening from trainer.

@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   framework::Scope &recv_scope = scope.NewScope();
 
   bool sync_mode = Attr<bool>("sync_mode");
+  auto fan_in = Attr<int>("Fanin");
 
   PADDLE_ENFORCE(!rpc_service_);
   std::string endpoint = Attr<std::string>("endpoint");
-  rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode));
+
+  LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
+            << ", end_point:" << endpoint;
+
+  // request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode));
+  rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, fan_in));
+  request_send_handler_.reset(new detail::RequestSendHandler(sync_mode));
+  request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
+  request_prefetch_handler_.reset(
+      new detail::RequestPrefetchHandler(sync_mode));
+
+  rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
+  rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
+  rpc_service_->RegisterRPC(detail::kRequestPrefetch,
+                            request_prefetch_handler_.get());
 
   auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
   auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
   auto *program = optimize_block->Program();
   framework::Executor executor(dev_place);
 
-  // prepare rpc_service
-  rpc_service_->SetScope(&recv_scope);
-  rpc_service_->SetDevCtx(&dev_ctx);
-  rpc_service_->SetProgram(program);
-  rpc_service_->SetExecutor(&executor);
-
   // prepare for prefetch
   VLOG(3) << "prefetch block id is " << prefetch_block->ID();
   auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
-  rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared));
+
+  auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
+                     &dev_ctx, &executor, program, prefetch_prepared.release(),
+                     rpc_service_.get());
+
+  f(request_send_handler_.get());
+  f(request_get_handler_.get());
+  f(request_prefetch_handler_.get());
 
   // start the server listening after all member initialized.
   server_thread_.reset(new std::thread(RunServer, rpc_service_));

@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   signal(SIGTERM, SignalHandler::StopAndExit);
 
   // Write to a file of server selected port for python use.
-  std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port",
-                                          static_cast<int>(::getpid()));
   SavePort();
   if (sync_mode) {
     RunSyncLoop(&executor, program, &recv_scope, prefetch_block);

@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-bool SignalHandler::program_exit_flag_ = false;
-
-SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{};
-
 void SignalHandler::StopAndExit(int signal_num) {
   VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
-
-  program_exit_flag_ = true;
-
-  // awake all blocking queues
-  for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin();
-       iter != blocking_queue_set_.end(); iter++) {
-    iter->get()->Push(
-        std::make_pair(std::string(LISTEN_TERMINATE_MESSAGE), nullptr));
-  }
-
-  exit(EXIT_SUCCESS);
-}
-
-void SignalHandler::RegisterBlockingQueue(BlockingQueue &queue) {
-  blocking_queue_set_.insert(queue);
+  exit(0);
 }
 
 }  // namespace operators
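After this rewrite, one server-side training step reduces to a fixed barrier sequence: open the Send phase, wait for every trainer's gradients, run the optimize blocks, open the Get phase, wait for every trainer's fetch, reset the counters. A compact runnable simulation of that handshake, with the "optimize" step reduced to a print (all names here are illustrative; client_num plays the role of the Fanin attribute):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Simulates one RunSyncLoop iteration: trainers hit a send barrier,
// the server "optimizes", then trainers pass a get barrier.
std::mutex mu;
std::condition_variable cv;
int send_count = 0, get_count = 0;
const int client_num = 2;  // trainer fan-in

void Trainer(int /*id*/) {
  {
    std::lock_guard<std::mutex> g(mu);
    ++send_count;  // AsyncSendVariable + batch-barrier message
  }
  cv.notify_all();
  {
    std::unique_lock<std::mutex> lock(mu);
    cv.wait(lock, [] { return send_count < 0; });  // Get phase opened
    ++get_count;  // fetch updated params + fetch-barrier message
  }
  cv.notify_all();
}

int main() {
  std::vector<std::thread> trainers;
  for (int i = 0; i < client_num; ++i) trainers.emplace_back(Trainer, i);

  std::unique_lock<std::mutex> lock(mu);
  cv.wait(lock, [] { return send_count >= client_num; });  // WaitBarrier(Send)
  std::cout << "all gradients in; running optimize blocks\n";
  send_count = -1;  // SetCond(kRequestGet): open the Get phase
  lock.unlock();
  cv.notify_all();

  lock.lock();
  cv.wait(lock, [] { return get_count >= client_num; });  // WaitBarrier(Get)
  std::cout << "all trainers fetched; batch done\n";
  lock.unlock();

  for (auto& t : trainers) t.join();
  return 0;
}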
paddle/fluid/operators/listen_and_serv_op.h (view file @ 4fb7cc7f)

@@ -23,7 +23,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/threadpool.h"
-#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/operators/detail/request_handler.h"
+#include "paddle/fluid/operators/detail/rpc_server.h"
 
 namespace paddle {
 namespace operators {

@@ -31,7 +32,7 @@ namespace operators {
 constexpr char kOptimizeBlock[] = "OptimizeBlock";
 constexpr char kPrefetchBlock[] = "PrefetchBlock";
 
-void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
+void RunServer(std::shared_ptr<detail::RPCServer> service);
 
 class ListenAndServOp : public framework::OperatorBase {
  public:

@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
   void SavePort() const;
 
-  void WaitServerReady();
-
-  int GetSelectedPort() { return selected_port_; }
+  int GetSelectedPort() { return rpc_service_->GetSelectedPort(); }
 
   void Stop() override;
 
   void RunImpl(const framework::Scope &scope,
                const platform::Place &dev_place) const override;
 
-  static void ResetPort() { selected_port_ = 0; }
-
 protected:
-  mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
+  mutable std::shared_ptr<detail::RPCServer> rpc_service_;
+  mutable std::shared_ptr<detail::RequestHandler> request_send_handler_;
+  mutable std::shared_ptr<detail::RequestHandler> request_get_handler_;
+  mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_;
+
   mutable std::shared_ptr<std::thread> server_thread_;
-  // FIXME(wuyi): it's static so that the operator can be cloned.
-  static std::atomic_int selected_port_;
 };
 
 class SignalHandler {
- public:
-  typedef std::shared_ptr<detail::ReceivedQueue> BlockingQueue;
-  typedef std::unordered_set<BlockingQueue> BlockingQueueSet;
-
  public:
   static void StopAndExit(int signal_num);
 
-  static void RegisterBlockingQueue(BlockingQueue &);
-
-  static inline bool IsProgramExit() { return program_exit_flag_; }
-
- private:
-  static bool program_exit_flag_;
-
-  static BlockingQueueSet blocking_queue_set_;
-
   DISABLE_COPY_AND_ASSIGN(SignalHandler);
 };
paddle/fluid/operators/send_barrier_op.cc (view file @ 4fb7cc7f)

@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
     auto rpc_client = detail::RPCClient::GetInstance();
 
+    VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
+
     // need to wait before sending send_barrier message
     PADDLE_ENFORCE(rpc_client->Wait());
     if (sync_mode) {
paddle/fluid/operators/test_send_nccl_id.cc (view file @ 4fb7cc7f)

@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/operators/detail/request_handler_impl.h"
 #include "paddle/fluid/operators/listen_and_serv_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"

@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
 namespace detail = paddle::operators::detail;
 namespace string = paddle::string;
 
-std::unique_ptr<detail::AsyncGRPCServer> rpc_service;
+std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
+std::unique_ptr<detail::RequestHandler> g_req_handler;
 
-void StartServer(std::atomic<bool>* initialized) {
+void StartServer() {
   f::Scope scope;
   p::CPUPlace place;
   scope.Var(NCCL_ID_VARNAME);
   p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
   auto& dev_ctx = *pool.Get(p::CPUPlace());
 
-  rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
-
   f::ProgramDesc empty_program;
   f::Executor executor(dev_ctx.GetPlace());
-  rpc_service->SetScope(&scope);
-  rpc_service->SetDevCtx(&dev_ctx);
-  rpc_service->SetProgram(&empty_program);
-  rpc_service->SetExecutor(&executor);
+  g_req_handler->SetScope(&scope);
+  g_req_handler->SetDevCtx(&dev_ctx);
+  g_req_handler->SetProgram(&empty_program);
+  g_req_handler->SetExecutor(&executor);
+
+  g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get());
+  g_req_handler->SetRPCServer(g_rpc_service.get());
 
   std::thread server_thread(
-      std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get()));
-  *initialized = true;
-  rpc_service->SetCond(0);
-  auto recv = rpc_service->Get();
+      std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
+
+  g_rpc_service->SetCond(detail::kRequestSend);
+  std::cout << "before WaitFanInOfSend" << std::endl;
+  g_rpc_service->WaitBarrier(detail::kRequestSend);
+
   LOG(INFO) << "got nccl id and stop server...";
-  rpc_service->ShutDown();
+  g_rpc_service->ShutDown();
   server_thread.join();
 }
 
-TEST(SendNcclId, DISABLED_Normal) {
-  std::atomic<bool> initialized{false};
-  std::thread server_thread(StartServer, &initialized);
-  while (!initialized) {
-  }
-  // wait server to start
-  // sleep(2);
-  rpc_service->WaitServerReady();
+TEST(SendNcclId, GrpcServer) {
+  g_req_handler.reset(new detail::RequestSendHandler(true));
+  g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
+
+  std::thread server_thread(StartServer);
+  g_rpc_service->WaitServerReady();
 
   f::Scope scope;
   p::CPUPlace place;

@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
   auto& dev_ctx = *pool.Get(p::CPUPlace());
   auto var = scope.Var(NCCL_ID_VARNAME);
-  // var->SetType(f::proto::VarType_Type_RAW);
   auto id = var->GetMutable<ncclUniqueId>();
   p::dynload::ncclGetUniqueId(id);
 
-  int port = rpc_service->GetSelectedPort();
+  int port = g_rpc_service->GetSelectedPort();
+
   std::string ep = string::Sprintf("127.0.0.1:%d", port);
   detail::RPCClient client;
+  LOG(INFO) << "connect to server" << ep;
 
   client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
   client.Wait();
+  client.AsyncSendBatchBarrier(ep);
+  client.Wait();
+
   server_thread.join();
-  auto* ptr = rpc_service.release();
-  delete ptr;
+  g_rpc_service.reset(nullptr);
+  g_req_handler.reset(nullptr);
 }
paddle/fluid/platform/nccl_helper.h (view file @ 4fb7cc7f)

@@ -15,6 +15,7 @@
 #pragma once
 
 #include <stdio.h>
 #include <string>
+#include <thread>  // NOLINT
 #include <typeindex>
 #include <vector>