Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4fb7cc7f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4fb7cc7f
编写于
5月 31, 2018
作者:
G
gongweibao
提交者:
GitHub
5月 31, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move sync_mode device ctx from grpc server (#10881)
上级
5870a6b4
变更
28
隐藏空白更改
内联
并排
Showing
28 changed file
with
886 addition
and
553 deletion
+886
-553
benchmark/fluid/kube_gen_job.py
benchmark/fluid/kube_gen_job.py
+1
-1
paddle/fluid/inference/analysis/data_flow_graph.h
paddle/fluid/inference/analysis/data_flow_graph.h
+3
-0
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc
...nference/analysis/data_flow_graph_to_fluid_pass_tester.cc
+3
-3
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
...fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
+3
-1
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
.../fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
+2
-0
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
...nference/analysis/fluid_to_data_flow_graph_pass_tester.cc
+3
-3
paddle/fluid/inference/analysis/helper.h
paddle/fluid/inference/analysis/helper.h
+4
-2
paddle/fluid/inference/analysis/pass.h
paddle/fluid/inference/analysis/pass.h
+1
-0
paddle/fluid/inference/analysis/subgraph_splitter.h
paddle/fluid/inference/analysis/subgraph_splitter.h
+2
-0
paddle/fluid/inference/analysis/ut_helper.h
paddle/fluid/inference/analysis/ut_helper.h
+1
-0
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+4
-1
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+2
-1
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+2
-0
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+147
-225
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+22
-76
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+53
-34
paddle/fluid/operators/detail/request_handler.h
paddle/fluid/operators/detail/request_handler.h
+127
-0
paddle/fluid/operators/detail/request_handler_impl.cc
paddle/fluid/operators/detail/request_handler_impl.cc
+115
-0
paddle/fluid/operators/detail/request_handler_impl.h
paddle/fluid/operators/detail/request_handler_impl.h
+64
-0
paddle/fluid/operators/detail/rpc_server.cc
paddle/fluid/operators/detail/rpc_server.cc
+113
-0
paddle/fluid/operators/detail/rpc_server.h
paddle/fluid/operators/detail/rpc_server.h
+91
-0
paddle/fluid/operators/detail/variable_response.h
paddle/fluid/operators/detail/variable_response.h
+2
-2
paddle/fluid/operators/gen_nccl_id_op.cc
paddle/fluid/operators/gen_nccl_id_op.cc
+13
-8
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+63
-148
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+9
-22
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+2
-0
paddle/fluid/operators/test_send_nccl_id.cc
paddle/fluid/operators/test_send_nccl_id.cc
+33
-26
paddle/fluid/platform/nccl_helper.h
paddle/fluid/platform/nccl_helper.h
+1
-0
未找到文件。
benchmark/fluid/kube_gen_job.py
浏览文件 @
4fb7cc7f
...
@@ -49,7 +49,7 @@ def parse_args():
...
@@ -49,7 +49,7 @@ def parse_args():
parser
.
add_argument
(
parser
.
add_argument
(
'--fluid'
,
default
=
1
,
type
=
int
,
help
=
'whether is fluid job'
)
'--fluid'
,
default
=
1
,
type
=
int
,
help
=
'whether is fluid job'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--rdma'
,
action
=
'store_t
ur
e'
,
help
=
'whether mount rdma libs'
)
'--rdma'
,
action
=
'store_t
ru
e'
,
help
=
'whether mount rdma libs'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--disttype'
,
'--disttype'
,
default
=
"pserver"
,
default
=
"pserver"
,
...
...
paddle/fluid/inference/analysis/data_flow_graph.h
浏览文件 @
4fb7cc7f
...
@@ -21,7 +21,10 @@ limitations under the License. */
...
@@ -21,7 +21,10 @@ limitations under the License. */
#include <deque>
#include <deque>
#include <stack>
#include <stack>
#include <string>
#include <unordered_set>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/inference/analysis/node.h"
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc
浏览文件 @
4fb7cc7f
...
@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
...
@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
LOG
(
INFO
)
<<
graph
.
nodes
.
size
();
LOG
(
INFO
)
<<
graph
.
nodes
.
size
();
}
}
}
//
analysis
}
;
// namespace
analysis
}
//
inference
}
;
// namespace
inference
}
//
paddle
}
;
// namespace
paddle
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
浏览文件 @
4fb7cc7f
...
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include
"paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include
<string>
#include <vector>
#include <vector>
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
浏览文件 @
4fb7cc7f
...
@@ -19,6 +19,8 @@
...
@@ -19,6 +19,8 @@
#pragma once
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass.h"
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
浏览文件 @
4fb7cc7f
...
@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
...
@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
LOG
(
INFO
)
<<
'\n'
<<
graph
.
DotString
();
LOG
(
INFO
)
<<
'\n'
<<
graph
.
DotString
();
}
}
}
// analysis
}
//
namespace
analysis
}
// inference
}
//
namespace
inference
}
// paddle
}
//
namespace
paddle
paddle/fluid/inference/analysis/helper.h
浏览文件 @
4fb7cc7f
...
@@ -50,7 +50,7 @@ struct DataTypeNamer {
...
@@ -50,7 +50,7 @@ struct DataTypeNamer {
return
dic_
.
at
(
x
);
return
dic_
.
at
(
x
);
}
}
const
std
::
string
&
repr
(
size_t
&
hash
)
const
{
const
std
::
string
&
repr
(
size_t
&
hash
)
const
{
// NOLINT
PADDLE_ENFORCE
(
dic_
.
count
(
hash
),
"unknown type for representation"
);
PADDLE_ENFORCE
(
dic_
.
count
(
hash
),
"unknown type for representation"
);
return
dic_
.
at
(
hash
);
return
dic_
.
at
(
hash
);
}
}
...
@@ -62,7 +62,9 @@ struct DataTypeNamer {
...
@@ -62,7 +62,9 @@ struct DataTypeNamer {
SET_TYPE
(
float
);
SET_TYPE
(
float
);
}
}
std
::
unordered_map
<
decltype
(
typeid
(
int
).
hash_code
()),
std
::
string
>
dic_
;
std
::
unordered_map
<
decltype
(
typeid
(
int
).
hash_code
()),
// NOLINT
std
::
string
>
dic_
;
};
};
#undef SET_TYPE
#undef SET_TYPE
...
...
paddle/fluid/inference/analysis/pass.h
浏览文件 @
4fb7cc7f
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <glog/logging.h>
#include <iosfwd>
#include <iosfwd>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
...
...
paddle/fluid/inference/analysis/subgraph_splitter.h
浏览文件 @
4fb7cc7f
...
@@ -18,6 +18,8 @@ limitations under the License. */
...
@@ -18,6 +18,8 @@ limitations under the License. */
#pragma once
#pragma once
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/inference/analysis/node.h"
...
...
paddle/fluid/inference/analysis/ut_helper.h
浏览文件 @
4fb7cc7f
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
...
...
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
4fb7cc7f
...
@@ -19,6 +19,9 @@ limitations under the License. */
...
@@ -19,6 +19,9 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/helper.h"
...
@@ -58,7 +61,7 @@ class TRTConvertValidation {
...
@@ -58,7 +61,7 @@ class TRTConvertValidation {
public:
public:
TRTConvertValidation
()
=
delete
;
TRTConvertValidation
()
=
delete
;
TRTConvertValidation
(
int
batch_size
,
int
workspace_size
=
1
<<
10
)
{
explicit
TRTConvertValidation
(
int
batch_size
,
int
workspace_size
=
1024
)
{
// create engine.
// create engine.
engine_
.
reset
(
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
));
engine_
.
reset
(
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
));
engine_
->
InitNetwork
();
engine_
->
InitNetwork
();
...
...
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
4fb7cc7f
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cc_test
(
serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
4fb7cc7f
...
@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
...
@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
}
}
bool
RPCClient
::
Wait
()
{
bool
RPCClient
::
Wait
()
{
VLOG
(
3
)
<<
"RPCClient begin Wait()"
<<
" req_count_:"
<<
req_count_
;
if
(
req_count_
<=
0
)
{
if
(
req_count_
<=
0
)
{
return
true
;
return
true
;
}
}
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
4fb7cc7f
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits>
#include <limits>
#include <string>
#include <string>
using
::
grpc
::
ServerAsyncResponseWriter
;
#include "paddle/fluid/operators/detail/grpc_server.h"
DEFINE_int32
(
rpc_server_handle_send_threads
,
20
,
using
::
grpc
::
ServerAsyncResponseWriter
;
"Number of threads used to handle send at rpc server."
);
DEFINE_int32
(
rpc_server_handle_get_threads
,
20
,
"Number of threads used to handle get at rpc server."
);
DEFINE_int32
(
rpc_server_handle_prefetch_threads
,
1
,
"Number of threads used to handle prefetch at rpc server."
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH };
...
@@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH };
class
RequestBase
{
class
RequestBase
{
public:
public:
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
::
grpc
::
ServerCompletionQueue
*
cq
,
const
platform
::
DeviceContext
*
dev_ctx
)
RequestHandler
*
request_handler
,
int
req_id
)
:
service_
(
service
),
:
service_
(
service
),
cq_
(
cq
),
cq_
(
cq
),
sync_mode_
(
sync_mode
),
status_
(
PROCESS
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
request_handler_
(
request_handler
),
req_id_
(
req_id
)
{
PADDLE_ENFORCE
(
cq_
);
PADDLE_ENFORCE
(
cq_
);
}
}
virtual
~
RequestBase
()
{}
virtual
~
RequestBase
()
{}
virtual
void
Process
()
{
assert
(
false
);
}
virtual
void
Process
()
=
0
;
CallStatus
Status
()
{
return
status_
;
}
CallStatus
Status
()
{
return
status_
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
virtual
std
::
string
GetReqName
()
{
virtual
std
::
string
GetReqName
()
=
0
;
assert
(
false
);
return
""
;
}
protected:
protected:
::
grpc
::
ServerContext
ctx_
;
::
grpc
::
ServerContext
ctx_
;
GrpcService
::
AsyncService
*
service_
;
GrpcService
::
AsyncService
*
service_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
const
bool
sync_mode_
;
CallStatus
status_
;
CallStatus
status_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
RequestHandler
*
request_handler_
;
int
req_id_
;
};
};
class
RequestSend
final
:
public
RequestBase
{
class
RequestSend
final
:
public
RequestBase
{
public:
public:
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
RequestHandler
*
request_handler
,
int
req_id
)
const
platform
::
DeviceContext
*
dev_ctx
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
queue_
(
queue
),
request_handler
->
dev_ctx
(),
responder_
(
&
ctx_
),
!
request_handler
->
sync_mode
()));
req_id_
(
req_id
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
...
@@ -87,12 +71,17 @@ class RequestSend final : public RequestBase {
...
@@ -87,12 +71,17 @@ class RequestSend final : public RequestBase {
virtual
~
RequestSend
()
{}
virtual
~
RequestSend
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
std
::
string
varname
=
GetReqName
();
VLOG
(
3
)
<<
"RequestSend var_name:"
<<
varname
;
virtual
void
Process
()
{
auto
scope
=
request_
->
GetMutableLocalScope
();
std
::
string
var_name
=
GetReqName
();
auto
invar
=
request_
->
GetVar
();
VLOG
(
3
)
<<
"RequestSend "
<<
var_name
;
framework
::
Variable
*
outvar
=
nullptr
;
queue_
->
Push
(
std
::
make_pair
(
var_name
,
request_
));
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
status_
=
FINISH
;
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
...
@@ -102,105 +91,85 @@ class RequestSend final : public RequestBase {
...
@@ -102,105 +91,85 @@ class RequestSend final : public RequestBase {
protected:
protected:
sendrecv
::
VoidMessage
reply_
;
sendrecv
::
VoidMessage
reply_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
ReceivedQueue
*
queue_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
int
req_id_
;
};
};
class
RequestGet
final
:
public
RequestBase
{
class
RequestGet
final
:
public
RequestBase
{
public:
public:
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
RequestHandler
*
request_handler
,
int
req_id
)
const
platform
::
DeviceContext
*
dev_ctx
,
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
framework
::
BlockingQueue
<
MessageWithName
>*
queue
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
queue_
(
queue
),
req_id_
(
req_id
)
{
auto
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
auto
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
}
virtual
~
RequestGet
()
{}
virtual
~
RequestGet
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
.
varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
v
irtual
void
Process
()
{
v
oid
Process
()
override
{
// proc request.
// proc request.
std
::
string
var_name
=
request_
.
varname
();
std
::
string
varname
=
request_
.
varname
();
VLOG
(
3
)
<<
"RequestGet "
<<
var_name
;
VLOG
(
3
)
<<
"RequestGet "
<<
varname
;
auto
*
var
=
scope_
->
FindVar
(
var_name
);
auto
scope
=
request_handler_
->
scope
();
auto
invar
=
scope
->
FindVar
(
varname
);
framework
::
Variable
*
outvar
=
nullptr
;
if
(
var_name
!=
FETCH_BARRIER_MESSAGE
)
{
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
if
(
outvar
)
{
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
}
}
status_
=
FINISH
;
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
if
(
var_name
==
FETCH_BARRIER_MESSAGE
)
{
sendrecv
::
VariableMessage
msg
;
MessageWithName
msg_with_name
=
std
::
make_pair
(
var_name
,
msg
);
queue_
->
Push
(
msg_with_name
);
}
}
}
protected:
protected:
sendrecv
::
VariableMessage
request_
;
sendrecv
::
VariableMessage
request_
;
::
grpc
::
ByteBuffer
reply_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
BlockingQueue
<
MessageWithName
>*
queue_
;
int
req_id_
;
};
};
class
RequestPrefetch
final
:
public
RequestBase
{
class
RequestPrefetch
final
:
public
RequestBase
{
public:
public:
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
::
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
RequestHandler
*
request_handler
,
int
req_id
)
const
platform
::
DeviceContext
*
dev_ctx
,
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
responder_
(
&
ctx_
),
scope_
(
scope
),
local_scope_
(
nullptr
)
{
executor_
(
executor
),
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
program_
(
program
),
request_handler
->
dev_ctx
(),
true
));
prefetch_ctx_
(
prefetch_ctx
),
req_id_
(
req_id
)
{
// prefetch always create a new sub scope
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
}
virtual
~
RequestPrefetch
()
{}
virtual
~
RequestPrefetch
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
v
irtual
void
Process
()
{
v
oid
Process
()
override
{
// prefetch process...
// prefetch process...
std
::
string
varname
=
request_
->
OutVarname
();
VLOG
(
3
)
<<
"RequestPrefetch "
<<
varname
;
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
invar
=
scope
->
FindVar
(
varname
);
framework
::
Variable
*
outvar
=
nullptr
;
std
::
string
var_name
=
request_
->
OutVarname
();
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
VLOG
(
3
)
<<
"RequestPrefetch "
<<
var_name
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
var_name
);
framework
::
Scope
*
local_scope
=
request_
->
GetMutableLocalScope
();
auto
*
var
=
local_scope
->
FindVar
(
var_name
);
InitializeVariable
(
var
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
,
local_scope
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
status_
=
FINISH
;
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
...
@@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase {
...
@@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase {
std
::
shared_ptr
<
VariableResponse
>
request_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
::
grpc
::
ByteBuffer
reply_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
Scope
*
local_scope_
;
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
int
req_id_
;
};
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
int
fetch_barriers
=
0
;
while
(
fetch_barriers
<
count
)
{
auto
msg
=
var_get_queue_
.
Pop
();
if
(
msg
.
first
==
FETCH_BARRIER_MESSAGE
)
{
fetch_barriers
++
;
}
}
}
void
AsyncGRPCServer
::
WaitServerReady
()
{
void
AsyncGRPCServer
::
WaitServerReady
()
{
VLOG
(
3
)
<<
"AsyncGRPCServer is wait server ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
condition_ready_
.
wait
(
lock
,
[
=
]
{
return
this
->
ready_
==
1
;
});
condition_ready_
.
wait
(
lock
,
[
=
]
{
return
this
->
ready_
==
1
;
});
VLOG
(
3
)
<<
"AsyncGRPCServer WaitSeverReady"
;
}
}
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
void
AsyncGRPCServer
::
StartServer
()
{
::
grpc
::
ServerBuilder
builder
;
::
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
::
grpc
::
InsecureServerCredentials
(),
builder
.
AddListeningPort
(
bind_
address_
,
::
grpc
::
InsecureServerCredentials
(),
&
selected_port_
);
&
selected_port_
);
builder
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
RegisterService
(
&
service_
);
builder
.
RegisterService
(
&
service_
);
cq_send_
=
builder
.
AddCompletionQueue
();
for
(
auto
t
:
rpc_call_map_
)
{
cq_get_
=
builder
.
AddCompletionQueue
(
);
rpc_cq_
[
t
.
first
].
reset
(
builder
.
AddCompletionQueue
().
release
()
);
cq_prefetch_
=
builder
.
AddCompletionQueue
();
}
server_
=
builder
.
BuildAndStart
();
server_
=
builder
.
BuildAndStart
();
LOG
(
INFO
)
<<
"Server listening on "
<<
address_
LOG
(
INFO
)
<<
"Server listening on "
<<
bind_
address_
<<
" selected port: "
<<
selected_port_
;
<<
" selected port: "
<<
selected_port_
;
std
::
function
<
void
(
int
)
>
send_register
=
std
::
bind
(
std
::
function
<
void
(
const
std
::
string
&
,
int
)
>
f
=
&
AsyncGRPCServer
::
TryToRegisterNewSendOne
,
this
,
std
::
placeholders
::
_1
);
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewOne
,
this
,
std
::
function
<
void
(
int
)
>
get_register
=
std
::
bind
(
std
::
placeholders
::
_1
,
std
::
placeholders
::
_2
);
&
AsyncGRPCServer
::
TryToRegisterNewGetOne
,
this
,
std
::
placeholders
::
_1
);
std
::
function
<
void
(
int
)
>
prefetch_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
,
this
,
std
::
placeholders
::
_1
);
for
(
int
i
=
0
;
i
<
kSendReqsBufSize
;
++
i
)
{
for
(
auto
&
t
:
rpc_call_map_
)
{
TryToRegisterNewSendOne
(
i
);
auto
&
rpc_name
=
t
.
first
;
}
auto
&
cq
=
rpc_cq_
[
rpc_name
];
for
(
int
i
=
0
;
i
<
kGetReqsBufSize
;
++
i
)
{
auto
threadnum
=
rpc_thread_num_
[
rpc_name
];
TryToRegisterNewGetOne
(
i
);
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
}
for
(
int
i
=
0
;
i
<
kPrefetchReqsBufSize
;
++
i
)
{
TryToRegisterNewPrefetchOne
(
i
);
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_send_threads
;
++
i
)
{
reqs
.
reserve
(
kRequestBufSize
);
t_sends_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
for
(
int
i
=
0
;
i
<
kRequestBufSize
;
i
++
)
{
cq_send_
.
get
(),
"cq_send"
,
send_register
)));
TryToRegisterNewOne
(
rpc_name
,
i
);
}
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_get_threads
;
++
i
)
{
t_gets_
.
emplace_back
(
for
(
int
i
=
0
;
i
<
threadnum
;
i
++
)
{
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
rpc_threads_
[
rpc_name
].
emplace_back
(
new
std
::
thread
(
std
::
bind
(
cq_get_
.
get
(),
"cq_get"
,
get_register
)));
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq
.
get
(),
rpc_name
,
f
)));
}
VLOG
(
3
)
<<
t
.
first
<<
" creates threads!"
;
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_prefetch_threads
;
++
i
)
{
}
t_prefetchs_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_prefetch_
.
get
(),
"cq_prefetch"
,
prefetch_register
)));
}
}
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
ready_
=
1
;
ready_
=
1
;
}
}
condition_ready_
.
notify_all
();
condition_ready_
.
notify_all
();
// wait server
// wait server
server_
->
Wait
();
server_
->
Wait
();
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_send_threads
;
++
i
)
{
t_sends_
[
i
]
->
join
();
for
(
auto
&
t
:
rpc_threads_
)
{
}
auto
&
threads
=
t
.
second
;
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_get_threads
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
++
i
)
{
t_gets_
[
i
]
->
join
();
threads
[
i
]
->
join
();
}
VLOG
(
3
)
<<
t
.
first
<<
" threads ends!"
;
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_prefetch_threads
;
++
i
)
{
}
t_prefetchs_
[
i
]
->
join
();
}
}
}
}
void
AsyncGRPCServer
::
ShutdownQueue
()
{
void
AsyncGRPCServer
::
ShutdownQueue
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
for
(
auto
&
t
:
rpc_cq_
)
{
cq_send_
->
Shutdown
();
t
.
second
->
Shutdown
();
cq_get_
->
Shutdown
()
;
VLOG
(
3
)
<<
t
.
first
<<
" shutdown!"
;
cq_prefetch_
->
Shutdown
();
}
}
}
// This URL explains why shutdown is complicate:
void
AsyncGRPCServer
::
ShutDownImpl
()
{
void
AsyncGRPCServer
::
ShutDown
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
is_shut_down_
=
true
;
is_shut_down_
=
true
;
ShutdownQueue
();
ShutdownQueue
();
VLOG
(
3
)
<<
"server_ shutdown!"
;
server_
->
Shutdown
();
server_
->
Shutdown
();
}
}
void
AsyncGRPCServer
::
TryToRegisterNewSendOne
(
int
i
)
{
void
AsyncGRPCServer
::
TryToRegisterNewOne
(
const
std
::
string
&
rpc_name
,
int
req_id
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
return
;
return
;
}
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
sync_mode_
,
scope_
,
&
var_recv_queue_
,
dev_ctx_
,
i
);
send_reqs_
[
i
]
=
static_cast
<
RequestBase
*>
(
send
);
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewGetOne
(
int
req_id
)
{
VLOG
(
4
)
<<
"register send rpc_name:"
<<
rpc_name
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
<<
", handler:"
<<
rpc_call_map_
[
kRequestSend
];
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
return
;
auto
&
handler
=
rpc_call_map_
[
rpc_name
];
auto
&
cq
=
rpc_cq_
[
rpc_name
];
RequestBase
*
b
=
nullptr
;
if
(
rpc_name
==
kRequestSend
)
{
b
=
new
RequestSend
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestGet
)
{
b
=
new
RequestGet
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestPrefetch
)
{
b
=
new
RequestPrefetch
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
{
PADDLE_ENFORCE
(
false
,
"not surpported rpc"
);
}
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
&
var_get_queue_
,
req_id
);
get_reqs_
[
req_id
]
=
static_cast
<
RequestBase
*>
(
get
);
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
(
int
req_id
)
{
reqs
[
req_id
]
=
b
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewPrefetchOne"
;
return
;
}
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
.
get
(),
req_id
);
prefetch_reqs_
[
req_id
]
=
static_cast
<
RequestBase
*>
(
prefetch
);
VLOG
(
4
)
<<
"Create Request
Prefetch status:"
<<
prefetch
->
Status
();
VLOG
(
4
)
<<
"Create Request
Send status:"
<<
b
->
Status
();
}
}
// FIXME(typhoonzero): change cq_name to enum.
void
AsyncGRPCServer
::
HandleRequest
(
void
AsyncGRPCServer
::
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq
_name
,
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
rpc
_name
,
std
::
function
<
void
(
int
)
>
TryToRegisterNewOne
)
{
std
::
function
<
void
(
const
std
::
string
&
,
int
)
>
TryToRegisterNewOne
)
{
void
*
tag
=
NULL
;
void
*
tag
=
NULL
;
bool
ok
=
false
;
bool
ok
=
false
;
while
(
true
)
{
while
(
true
)
{
VLOG
(
3
)
<<
"HandleRequest
for "
<<
cq_name
<<
" wait N
ext"
;
VLOG
(
3
)
<<
"HandleRequest
"
<<
rpc_name
<<
" wait n
ext"
;
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
LOG
(
INFO
)
<<
cq_name
<<
" CompletionQueue
shutdown!"
;
LOG
(
INFO
)
<<
"CompletionQueue "
<<
rpc_name
<<
"
shutdown!"
;
break
;
break
;
}
}
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" get Next"
;
int
req_id
=
static_cast
<
int
>
(
reinterpret_cast
<
intptr_t
>
(
tag
));
if
(
sync_mode_
)
{
int
req_id
=
static_cast
<
int
>
(
reinterpret_cast
<
intptr_t
>
(
tag
));
// FIXME(typhoonzero): de-couple the barriers with recv_op
VLOG
(
3
)
<<
"HandleRequest "
<<
rpc_name
<<
", req_id:"
<<
req_id
if
(
!
is_shut_down_
&&
cq_name
==
"cq_get"
)
WaitCond
(
1
);
<<
" get next"
;
if
(
!
is_shut_down_
&&
cq_name
==
"cq_send"
)
WaitCond
(
0
);
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" after WaitCond"
;
}
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
RequestBase
*
base
=
nullptr
;
RequestBase
*
base
=
nullptr
;
{
{
std
::
lock_guard
<
std
::
mutex
>
l
(
cq_mutex_
);
PADDLE_ENFORCE
(
req_id
>=
0
&&
req_id
<
kRequestBufSize
);
if
(
cq_name
==
"cq_get"
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
base
=
get_reqs_
[
req_id
];
base
=
reqs
[
req_id
];
}
else
if
(
cq_name
==
"cq_send"
)
{
base
=
send_reqs_
[
req_id
];
}
else
if
(
cq_name
==
"cq_prefetch"
)
{
base
=
prefetch_reqs_
[
req_id
];
}
}
}
// reference:
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if
(
!
ok
)
{
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
cq_name
<<
" recv no regular event:argument name["
LOG
(
WARNING
)
<<
"completion queue:"
<<
rpc_name
<<
" recv no regular event:argument name["
<<
base
->
GetReqName
()
<<
"]"
;
<<
base
->
GetReqName
()
<<
"]"
;
TryToRegisterNewOne
(
req_id
);
TryToRegisterNewOne
(
r
pc_name
,
r
eq_id
);
delete
base
;
delete
base
;
continue
;
continue
;
}
}
VLOG
(
3
)
<<
"queue id:"
<<
rpc_name
<<
", req_id:"
<<
req_id
<<
", status:"
<<
base
->
Status
();
switch
(
base
->
Status
())
{
switch
(
base
->
Status
())
{
case
PROCESS
:
{
case
PROCESS
:
{
base
->
Process
();
base
->
Process
();
VLOG
(
4
)
<<
cq_name
<<
" PROCESS status:"
<<
base
->
Status
();
break
;
break
;
}
}
case
FINISH
:
{
case
FINISH
:
{
TryToRegisterNewOne
(
req_id
);
TryToRegisterNewOne
(
rpc_name
,
req_id
);
VLOG
(
4
)
<<
cq_name
<<
" FINISH status:"
<<
base
->
Status
();
delete
base
;
delete
base
;
break
;
break
;
}
}
...
@@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest(
...
@@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest(
}
}
}
}
void
AsyncGRPCServer
::
WaitCond
(
int
cond
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
barrier_mutex_
);
barrier_condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
barrier_cond_step_
==
cond
;
});
}
void
AsyncGRPCServer
::
SetCond
(
int
cond
)
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
barrier_mutex_
);
barrier_cond_step_
=
cond
;
}
barrier_condition_
.
notify_all
();
}
}
// namespace detail
}
// namespace detail
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
4fb7cc7f
...
@@ -14,6 +14,8 @@ limitations under the License. */
...
@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#pragma once
#include <map>
#include <set>
#include <string>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <utility>
#include <utility>
...
@@ -28,6 +30,8 @@ limitations under the License. */
...
@@ -28,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
...
@@ -37,106 +41,48 @@ namespace paddle {
...
@@ -37,106 +41,48 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
detail
{
namespace
detail
{
typedef
std
::
pair
<
std
::
string
,
std
::
shared_ptr
<
VariableResponse
>>
ReceivedMessage
;
typedef
framework
::
BlockingQueue
<
ReceivedMessage
>
ReceivedQueue
;
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
class
RequestBase
;
class
RequestBase
;
class
AsyncGRPCServer
final
{
class
AsyncGRPCServer
final
:
public
RPCServer
{
public:
public:
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
bool
sync_mode
)
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
int
client_num
)
:
address_
(
address
),
sync_mode_
(
sync_mode
),
ready_
(
0
)
{}
:
RPCServer
(
address
,
client_num
),
ready_
(
0
)
{}
~
AsyncGRPCServer
()
{}
void
WaitServerReady
();
void
RunSyncUpdate
();
// functions to sync server barrier status.
void
WaitCond
(
int
cond
);
void
SetCond
(
int
cond
);
void
WaitClientGet
(
int
count
);
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
void
SetProgram
(
framework
::
ProgramDesc
*
program
)
{
program_
=
program
;
}
void
SetExecutor
(
framework
::
Executor
*
executor
)
{
executor_
=
executor
;
}
void
SetPrefetchPreparedCtx
(
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepared
)
{
prefetch_ctx_
.
reset
(
prepared
.
release
());
}
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
const
ReceivedMessage
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
v
oid
Push
(
const
std
::
string
&
msg_name
)
{
v
irtual
~
AsyncGRPCServer
()
{}
this
->
var_recv_queue_
.
Push
(
std
::
make_pair
(
msg_name
,
nullptr
))
;
void
WaitServerReady
()
override
;
}
void
StartServer
()
override
;
void
ShutDown
();
private:
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
rpc_name
,
std
::
function
<
void
(
const
std
::
string
&
,
int
)
>
TryToRegisterNewOne
);
protected:
void
TryToRegisterNewOne
(
const
std
::
string
&
rpc_name
,
int
req_id
);
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
(
int
)
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
(
int
req_id
);
void
TryToRegisterNewGetOne
(
int
req_id
);
void
TryToRegisterNewPrefetchOne
(
int
req_id
);
void
ShutdownQueue
();
void
ShutdownQueue
();
void
ShutDownImpl
()
override
;
private:
private:
static
const
int
kSendReqsBufSize
=
100
;
static
const
int
kRequestBufSize
=
100
;
static
const
int
kGetReqsBufSize
=
100
;
static
const
int
kPrefetchReqsBufSize
=
10
;
std
::
mutex
cq_mutex_
;
std
::
mutex
cq_mutex_
;
volatile
bool
is_shut_down_
=
false
;
volatile
bool
is_shut_down_
=
false
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_send_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_get_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_prefetch_
;
RequestBase
*
send_reqs_
[
kSendReqsBufSize
];
RequestBase
*
get_reqs_
[
kGetReqsBufSize
];
RequestBase
*
prefetch_reqs_
[
kPrefetchReqsBufSize
];
GrpcService
::
AsyncService
service_
;
GrpcService
::
AsyncService
service_
;
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
string
address_
;
const
bool
sync_mode_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
// received variable from RPC, operators fetch variable from this queue.
framework
::
BlockingQueue
<
MessageWithName
>
var_get_queue_
;
// client send variable to this queue.
ReceivedQueue
var_recv_queue_
;
// condition of the sub program
// condition of the sub program
std
::
mutex
barrier_mutex_
;
std
::
mutex
barrier_mutex_
;
mutable
int
barrier_cond_step_
;
mutable
int
barrier_cond_step_
;
std
::
condition_variable
barrier_condition_
;
std
::
condition_variable
barrier_condition_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_sends_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_gets_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_prefetchs_
;
std
::
unique_ptr
<
std
::
thread
>
t_prefetch_
;
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prefetch_ctx_
;
framework
::
ProgramDesc
*
program_
;
framework
::
Executor
*
executor_
;
int
selected_port_
;
std
::
mutex
mutex_ready_
;
std
::
mutex
mutex_ready_
;
std
::
condition_variable
condition_ready_
;
std
::
condition_variable
condition_ready_
;
int
ready_
;
int
ready_
;
std
::
map
<
std
::
string
,
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>>
rpc_cq_
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>>
rpc_threads_
;
std
::
map
<
std
::
string
,
std
::
vector
<
RequestBase
*>>
rpc_reqs_
;
};
};
};
// namespace detail
};
// namespace detail
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
4fb7cc7f
...
@@ -24,13 +24,16 @@ limitations under the License. */
...
@@ -24,13 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
namespace
framework
=
paddle
::
framework
;
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
platform
=
paddle
::
platform
;
namespace
detail
=
paddle
::
operators
::
detail
;
namespace
detail
=
paddle
::
operators
::
detail
;
USE_OP
(
lookup_table
);
USE_OP
(
lookup_table
);
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RequestHandler
>
g_req_handler
;
framework
::
BlockDesc
*
AppendPrefetchBlcok
(
framework
::
ProgramDesc
*
program
)
{
framework
::
BlockDesc
*
AppendPrefetchBlcok
(
framework
::
ProgramDesc
*
program
)
{
auto
root_block
=
program
->
MutableBlock
(
0
);
auto
root_block
=
program
->
MutableBlock
(
0
);
...
@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
...
@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
}
}
}
void
StartServer
(
const
std
::
string
&
endpoint
)
{
void
StartServer
()
{
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
true
));
framework
::
ProgramDesc
program
;
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
...
@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
...
@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
auto
prepared
=
exe
.
Prepare
(
program
,
block
->
ID
());
auto
prepared
=
exe
.
Prepare
(
program
,
block
->
ID
());
InitTensorsOnServer
(
&
scope
,
&
place
,
10
);
InitTensorsOnServer
(
&
scope
,
&
place
,
10
);
rpc_service_
->
SetProgram
(
&
program
);
g_req_handler
->
SetProgram
(
&
program
);
rpc_service_
->
SetPrefetchPreparedCtx
(
std
::
move
(
prepared
));
g_req_handler
->
SetPrefetchPreparedCtx
(
std
::
move
(
prepared
));
rpc_service_
->
SetDevCtx
(
&
ctx
);
g_req_handler
->
SetDevCtx
(
&
ctx
);
rpc_service_
->
SetScope
(
&
scope
);
g_req_handler
->
SetScope
(
&
scope
);
rpc_service_
->
SetExecutor
(
&
exe
);
g_req_handler
->
SetExecutor
(
&
exe
);
g_rpc_service
->
RegisterRPC
(
detail
::
kRequestPrefetch
,
g_req_handler
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
rpc_service_
->
RunSyncUpdate
();
// FIXME(gongwb): don't use hard time.
sleep
(
10
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
}
}
TEST
(
PREFETCH
,
DISABLED_CPU
)
{
TEST
(
PREFETCH
,
CPU
)
{
// start up a server instance backend
g_req_handler
.
reset
(
new
detail
::
RequestPrefetchHandler
(
true
));
std
::
thread
server_thread
(
StartServer
,
"127.0.0.1:8889"
);
g_rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
1
));
sleep
(
2
);
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
detail
::
RPCClient
client
;
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
framework
::
Scope
scope
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
platform
::
CPUDeviceContext
ctx
(
place
);
// create var on local scope
{
int64_t
rows_numel
=
5
;
// create var on local scope
InitTensorsOnClient
(
&
scope
,
&
place
,
rows_numel
);
int64_t
rows_numel
=
5
;
std
::
string
in_var_name
(
"ids"
);
InitTensorsOnClient
(
&
scope
,
&
place
,
rows_numel
);
std
::
string
out_var_name
(
"out"
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
auto
client
=
detail
::
RPCClient
::
GetInstance
();
client
->
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
in_var_name
,
client
.
AsyncPrefetchVariable
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
out_var_name
);
client
.
Wait
();
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
EXPECT_EQ
(
ptr
[
0
+
i
*
value
.
dims
()[
1
]],
static_cast
<
float
>
(
i
*
2
));
rpc_service_
->
ShutDown
();
}
server_thread
.
join
();
rpc_service_
.
reset
(
nullptr
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
EXPECT_EQ
(
ptr
[
0
+
i
*
value
.
dims
()[
1
]],
static_cast
<
float
>
(
i
*
2
));
}
}
server_thread
.
join
();
LOG
(
INFO
)
<<
"begin reset"
;
g_rpc_service
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
}
}
paddle/fluid/operators/detail/request_handler.h
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
constexpr
char
kRequestSend
[]
=
"RequestSend"
;
constexpr
char
kRequestGet
[]
=
"RequestGet"
;
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
class
RPCServer
;
class
RequestHandler
{
public:
explicit
RequestHandler
(
bool
sync_mode
)
:
sync_mode_
(
sync_mode
),
dev_ctx_
(
nullptr
),
executor_
(
nullptr
),
scope_
(
nullptr
),
program_
(
nullptr
),
rpc_server_
(
nullptr
)
{}
virtual
~
RequestHandler
()
{}
// Set attributes.
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
void
SetProgram
(
framework
::
ProgramDesc
*
program
)
{
program_
=
program
;
}
void
SetExecutor
(
framework
::
Executor
*
executor
)
{
executor_
=
executor
;
}
void
SetPrefetchPreparedCtx
(
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepared
)
{
prefetch_ctx_
.
reset
(
prepared
.
release
());
}
// Used for async.
void
SetGradToPreparedCtx
(
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>*
g
)
{
grad_to_prepared_ctx_
=
g
;
}
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
// Get attributes.
bool
sync_mode
()
{
return
sync_mode_
;
}
framework
::
Scope
*
scope
()
{
return
scope_
;
}
const
platform
::
DeviceContext
*
dev_ctx
()
{
return
dev_ctx_
;
}
framework
::
ExecutorPrepareContext
*
prefetch_ctx
()
{
return
prefetch_ctx_
.
get
();
}
framework
::
ProgramDesc
*
program
()
{
return
program_
;
}
framework
::
Executor
*
executor
()
{
return
executor_
;
}
std
::
vector
<
framework
::
Variable
*>&
sparse_vars
()
{
return
sparse_vars_
;
}
// This function processes user's rpc request.
// The implemention is in request_handler_impl.
// example:
// std::string varname = request_.varname();
//
// auto scope = request_handler_->scope();
// auto invar = scope->FindVar(varname);
// framework::Variable* outvar = nullptr;
//
// request_handler_->Handle(varname, scope, invar, &outvar);
// if (outvar) {
// SerializeToByteBuffer(varname, outvar,
// *request_handler_->dev_ctx(), &reply_);
// }
virtual
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
=
0
;
protected:
const
bool
sync_mode_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
framework
::
Executor
*
executor_
;
framework
::
Scope
*
scope_
;
framework
::
ProgramDesc
*
program_
;
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prefetch_ctx_
;
// Used for async.
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>*
grad_to_prepared_ctx_
;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std
::
vector
<
framework
::
Variable
*>
sparse_vars_
;
RPCServer
*
rpc_server_
;
std
::
mutex
sparse_var_mutex_
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/request_handler_impl.cc
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
bool
RequestSendHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
)
{
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
// Async
if
(
!
sync_mode_
)
{
try
{
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"async: run sub program error "
<<
e
.
what
();
return
false
;
}
return
true
;
}
// Sync
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv batch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
{
VLOG
(
3
)
<<
"sync: received var_name: "
<<
varname
;
if
(
sync_mode_
)
{
rpc_server_
->
WaitCond
(
kRequestSend
);
}
if
(
invar
==
nullptr
)
{
LOG
(
ERROR
)
<<
"sync: Can not find server side var: "
<<
varname
;
PADDLE_THROW
(
"sync: Can not find server side var"
);
return
false
;
}
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
sparse_var_mutex_
);
sparse_vars_
.
push_back
(
invar
);
}
}
return
true
;
}
bool
RequestGetHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
)
{
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
;
if
(
varname
!=
FETCH_BARRIER_MESSAGE
)
{
if
(
sync_mode_
)
{
rpc_server_
->
WaitCond
(
kRequestGet
);
}
*
outvar
=
scope_
->
FindVar
(
varname
);
return
true
;
}
// FETCH_BARRIER_MESSAGE
if
(
sync_mode_
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
}
return
true
;
}
bool
RequestPrefetchHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
)
{
VLOG
(
4
)
<<
"RequestPrefetchHandler "
<<
varname
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
varname
);
*
outvar
=
scope
->
FindVar
(
varname
);
InitializeVariable
(
*
outvar
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
.
get
(),
scope
);
return
true
;
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/request_handler_impl.h
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
RequestSendHandler
final
:
public
RequestHandler
{
public:
explicit
RequestSendHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestSendHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
override
;
};
class
RequestGetHandler
final
:
public
RequestHandler
{
public:
explicit
RequestGetHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestGetHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
override
;
};
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
public:
explicit
RequestPrefetchHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestPrefetchHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
override
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/rpc_server.cc
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
void
RPCServer
::
ShutDown
()
{
LOG
(
INFO
)
<<
"RPCServer ShutDown "
;
ShutDownImpl
();
exit_flag_
=
true
;
barrier_cond_
.
notify_all
();
rpc_cond_
.
notify_all
();
}
void
RPCServer
::
SavePort
()
const
{
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
std
::
ofstream
port_file
;
port_file
.
open
(
file_path
);
port_file
<<
selected_port_
;
port_file
.
close
();
VLOG
(
4
)
<<
"selected port written to "
<<
file_path
;
}
void
RPCServer
::
WaitBarrier
(
const
std
::
string
&
rpc_name
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
barrier_cond_
.
wait
(
lock
,
[
=
]
{
return
(
barrier_counter_
[
rpc_name
]
>=
client_num_
||
exit_flag_
.
load
());
});
VLOG
(
3
)
<<
"batch_barrier_:"
<<
barrier_counter_
[
rpc_name
];
}
void
RPCServer
::
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
)
{
VLOG
(
3
)
<<
"RPCServer begin IncreaseBatchBarrier "
<<
rpc_name
;
int
b
=
0
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
b
=
++
barrier_counter_
[
rpc_name
];
}
VLOG
(
3
)
<<
"RPCServer IncreaseBatchBarrier "
<<
rpc_name
<<
", barrier_count:"
<<
b
<<
", fan_in"
<<
client_num_
;
if
(
b
>=
client_num_
)
{
barrier_cond_
.
notify_all
();
}
}
void
RPCServer
::
ResetBarrierCounter
()
{
VLOG
(
3
)
<<
"RPCServer ResetBarrierCounter "
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
for
(
auto
&
t
:
barrier_counter_
)
{
t
.
second
=
0
;
}
}
void
RPCServer
::
RegisterRPC
(
const
std
::
string
&
rpc_name
,
RequestHandler
*
handler
,
int
thread_num
)
{
rpc_call_map_
[
rpc_name
]
=
handler
;
rpc_thread_num_
[
rpc_name
]
=
thread_num
;
static
int
cond
=
-
1
;
rpc_cond_map_
[
rpc_name
]
=
++
cond
;
VLOG
(
4
)
<<
"RegisterRPC rpc_name:"
<<
rpc_name
<<
", handler:"
<<
handler
<<
", cond:"
<<
rpc_cond_map_
[
rpc_name
];
}
void
RPCServer
::
SetCond
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
3
)
<<
"RPCServer SetCond "
<<
rpc_name
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cur_cond_
=
rpc_cond_map_
[
rpc_name
];
}
rpc_cond_
.
notify_all
();
}
void
RPCServer
::
WaitCond
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
3
)
<<
"RPCServer WaitCond "
<<
rpc_name
;
int
cond
=
0
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cond
=
rpc_cond_map_
[
rpc_name
];
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
rpc_cond_
.
wait
(
lock
,
[
=
]
{
return
(
cur_cond_
.
load
()
==
cond
||
exit_flag_
.
load
());
});
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/rpc_server.h
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/request_handler.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
RPCServer
{
public:
explicit
RPCServer
(
const
std
::
string
&
address
,
int
client_num
)
:
cur_cond_
(
0
),
bind_address_
(
address
),
exit_flag_
(
false
),
selected_port_
(
0
),
client_num_
(
client_num
)
{}
virtual
~
RPCServer
()
{}
virtual
void
StartServer
()
=
0
;
virtual
void
WaitServerReady
()
=
0
;
void
ShutDown
();
bool
IsExit
()
{
return
exit_flag_
.
load
();
}
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
void
SavePort
()
const
;
// RegisterRPC, register the rpc method name to a handler
// class, and auto generate a condition id for this call
// to be used for the barrier.
void
RegisterRPC
(
const
std
::
string
&
rpc_name
,
RequestHandler
*
handler
,
int
thread_num
=
5
);
// Wait util all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
// synchronous mode.
void
WaitBarrier
(
const
std
::
string
&
rpc_name
);
void
SetCond
(
const
std
::
string
&
rpc_name
);
void
WaitCond
(
const
std
::
string
&
rpc_name
);
void
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
);
void
ResetBarrierCounter
();
protected:
virtual
void
ShutDownImpl
()
=
0
;
private:
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
string
,
int
>
barrier_counter_
;
std
::
condition_variable
barrier_cond_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_cond_map_
;
std
::
atomic
<
int
>
cur_cond_
;
std
::
condition_variable
rpc_cond_
;
protected:
std
::
string
bind_address_
;
std
::
atomic
<
int
>
exit_flag_
;
int
selected_port_
;
const
int
client_num_
;
std
::
unordered_map
<
std
::
string
,
RequestHandler
*>
rpc_call_map_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_thread_num_
;
friend
class
RequestHandler
;
};
};
// namespace detail
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/detail/variable_response.h
浏览文件 @
4fb7cc7f
...
@@ -67,8 +67,8 @@ class VariableResponse {
...
@@ -67,8 +67,8 @@ class VariableResponse {
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
inline
std
::
string
Varname
()
{
return
meta_
.
varname
();
}
inline
std
::
string
Varname
()
const
{
return
meta_
.
varname
();
}
inline
std
::
string
OutVarname
()
{
return
meta_
.
out_varname
();
}
inline
std
::
string
OutVarname
()
const
{
return
meta_
.
out_varname
();
}
// should call parse first.
// should call parse first.
framework
::
Variable
*
GetVar
()
{
framework
::
Variable
*
GetVar
()
{
...
...
paddle/fluid/operators/gen_nccl_id_op.cc
浏览文件 @
4fb7cc7f
...
@@ -23,6 +23,7 @@ limitations under the License. */
...
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
// that will cause a wired crash.
detail
::
AsyncGRPCServer
rpc_service
(
endpoint
,
true
);
detail
::
RequestSendHandler
rpc_h
(
true
);
detail
::
AsyncGRPCServer
rpc_service
(
endpoint
,
1
);
rpc_service
.
RegisterRPC
(
detail
::
kRequestSend
,
&
rpc_h
);
rpc_h
.
SetRPCServer
(
&
rpc_service
);
framework
::
ProgramDesc
empty_program
;
framework
::
ProgramDesc
empty_program
;
framework
::
Executor
executor
(
dev_ctx
.
GetPlace
());
framework
::
Executor
executor
(
dev_ctx
.
GetPlace
());
rpc_
service
.
SetScope
(
scope
);
rpc_
h
.
SetScope
(
scope
);
rpc_
service
.
SetDevCtx
(
&
dev_ctx
);
rpc_
h
.
SetDevCtx
(
&
dev_ctx
);
rpc_
service
.
SetProgram
(
&
empty_program
);
rpc_
h
.
SetProgram
(
&
empty_program
);
rpc_
service
.
SetExecutor
(
&
executor
);
rpc_
h
.
SetExecutor
(
&
executor
);
std
::
thread
server_thread
(
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
RunSyncUpdate
,
&
rpc_service
));
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
&
rpc_service
));
rpc_service
.
SetCond
(
0
);
rpc_service
.
SetCond
(
detail
::
kRequestSend
);
VLOG
(
3
)
<<
"start getting nccl id from trainer 0..."
;
VLOG
(
3
)
<<
"start getting nccl id from trainer 0..."
;
auto
recv
=
rpc_service
.
Get
(
);
rpc_service
.
WaitBarrier
(
detail
::
kRequestSend
);
VLOG
(
3
)
<<
"got nccl id and stop server..."
;
VLOG
(
3
)
<<
"got nccl id and stop server..."
;
rpc_service
.
ShutDown
();
rpc_service
.
ShutDown
();
VLOG
(
3
)
<<
"rpc server stopped"
;
VLOG
(
3
)
<<
"rpc server stopped"
;
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
4fb7cc7f
...
@@ -19,14 +19,16 @@ limitations under the License. */
...
@@ -19,14 +19,16 @@ limitations under the License. */
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncG
RPCServer
>
service
)
{
void
RunServer
(
std
::
shared_ptr
<
detail
::
RPCServer
>
service
)
{
service
->
RunSyncUpdate
();
service
->
StartServer
();
VLOG
(
4
)
<<
"RunServer thread end"
;
VLOG
(
4
)
<<
"RunServer thread end"
;
}
}
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
...
@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
...
@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
for
(
size_t
i
=
0
;
i
<
fs
.
size
();
++
i
)
fs
[
i
].
wait
();
for
(
size_t
i
=
0
;
i
<
fs
.
size
();
++
i
)
fs
[
i
].
wait
();
}
}
std
::
atomic_int
ListenAndServOp
::
selected_port_
{
0
};
ListenAndServOp
::
ListenAndServOp
(
const
std
::
string
&
type
,
ListenAndServOp
::
ListenAndServOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
...
@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
...
@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
ListenAndServOp
::~
ListenAndServOp
()
{
Stop
();
}
ListenAndServOp
::~
ListenAndServOp
()
{
Stop
();
}
void
ListenAndServOp
::
Stop
()
{
void
ListenAndServOp
::
Stop
()
{
rpc_service_
->
Push
(
LISTEN_TERMINATE_MESSAGE
);
rpc_service_
->
ShutDown
();
rpc_service_
->
ShutDown
();
server_thread_
->
join
();
server_thread_
->
join
();
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
...
@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
...
@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
void
ListenAndServOp
::
SavePort
()
const
{
void
ListenAndServOp
::
SavePort
()
const
{
// NOTE: default write file to /tmp/paddle.selected_port
// NOTE: default write file to /tmp/paddle.selected_port
selected_port_
=
rpc_service_
->
GetSelectedPort
();
rpc_service_
->
SavePort
();
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
std
::
ofstream
port_file
;
port_file
.
open
(
file_path
);
port_file
<<
selected_port_
.
load
();
port_file
.
close
();
VLOG
(
4
)
<<
"selected port written to "
<<
file_path
;
}
void
ListenAndServOp
::
WaitServerReady
()
{
while
(
selected_port_
.
load
()
==
0
)
{
}
}
}
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
{
framework
::
BlockDesc
*
prefetch_block
)
const
{
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
size_t
num_blocks
=
program
->
Size
();
size_t
num_blocks
=
program
->
Size
();
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
"server program should have at least 2 blocks"
);
"server program should have at least 2 blocks"
);
...
@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
...
@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
optimize_prepared
.
begin
(),
optimize_prepared
.
begin
(),
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
(
nullptr
));
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
(
nullptr
));
bool
exit_flag
=
false
;
rpc_service_
->
ResetBarrierCounter
()
;
// Record received sparse variables, so that
// Record received sparse variables, so that
// we could reset those after execute optimize program
// we could reset those after execute optimize program
std
::
vector
<
framework
::
Variable
*>
sparse_vars
;
std
::
vector
<
framework
::
Variable
*>
sparse_vars
;
while
(
!
exit_flag
&&
!
SignalHandler
::
IsProgramExit
()
)
{
while
(
true
)
{
// Get from multiple trainers, we don't care about the order in which
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_
->
SetCond
(
0
);
rpc_service_
->
SetCond
(
detail
::
kRequestSend
);
size_t
recv_var_cnt
=
0
;
rpc_service_
->
WaitBarrier
(
detail
::
kRequestSend
);
int
batch_barrier
=
0
;
while
(
batch_barrier
!=
fan_in
)
{
if
(
rpc_service_
->
IsExit
())
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
LOG
(
WARNING
)
<<
"get exit!rpc_processor break!"
;
auto
recv_var_name
=
v
.
first
;
rpc_service_
->
SetCond
(
detail
::
kRequestGet
);
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
}
else
if
(
recv_var_name
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"recv batch barrier message"
;
batch_barrier
++
;
continue
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
recv_var_cnt
++
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
sparse_vars
.
push_back
(
var
);
}
}
}
if
(
exit_flag
)
{
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
ShutDown
();
break
;
break
;
}
}
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.
// and this will still work.
// The optimize blocks which have the same parent ID would run parallel
// The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future
// TODO(Yancey1989): need to use ParallelExecutor for future
int32_t
last_parent_blkid
=
program
->
Block
(
1
).
Parent
();
int32_t
last_parent_blkid
=
program
->
Block
(
1
).
Parent
();
...
@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
...
@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
}
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
SetCond
(
detail
::
kRequestGet
);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_
->
WaitBarrier
(
detail
::
kRequestGet
);
rpc_service_
->
WaitClientGet
(
fan_in
);
rpc_service_
->
ResetBarrierCounter
();
sparse_vars
.
clear
();
}
// while(true)
}
// while(true)
}
}
static
void
AsyncUpdateThread
(
const
std
::
string
&
var_name
,
const
bool
&
exit_flag
,
const
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
&
queue
,
framework
::
Executor
*
executor
,
framework
::
ExecutorPrepareContext
*
prepared
)
{
VLOG
(
3
)
<<
"update thread for "
<<
var_name
<<
" started"
;
while
(
!
exit_flag
&&
!
SignalHandler
::
IsProgramExit
())
{
const
detail
::
ReceivedMessage
v
=
queue
->
Pop
();
if
(
SignalHandler
::
IsProgramExit
())
{
VLOG
(
3
)
<<
"update thread for "
<<
var_name
<<
" exit"
;
break
;
}
auto
recv_var_name
=
v
.
first
;
VLOG
(
4
)
<<
"async update "
<<
recv_var_name
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
auto
fs
=
framework
::
Async
([
var_name
,
&
executor
,
&
v
,
prepared
]
{
try
{
executor
->
RunPreparedContext
(
prepared
,
v
.
second
->
GetMutableLocalScope
());
}
catch
(
const
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
});
fs
.
wait
();
}
}
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
)
const
{
framework
::
ProgramDesc
*
program
)
const
{
VLOG
(
3
)
<<
"RunAsyncLoop in"
;
VLOG
(
3
)
<<
"RunAsyncLoop in"
;
// grad name to block id
// grad name to block id
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_block_id
;
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_block_id
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
detail
::
ReceivedQueue
>>
grad_to_queue
;
auto
grad_to_block_id_str
=
auto
grad_to_block_id_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
);
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
);
...
@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
...
@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
VLOG
(
3
)
<<
"after split, grad = "
<<
pieces
[
0
]
<<
", id="
<<
pieces
[
1
];
VLOG
(
3
)
<<
"after split, grad = "
<<
pieces
[
0
]
<<
", id="
<<
pieces
[
1
];
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
grad_to_block_id
.
count
(
pieces
[
0
]),
0
);
PADDLE_ENFORCE_EQ
(
grad_to_block_id
.
count
(
pieces
[
0
]),
0
);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
grad_to_block_id
[
pieces
[
0
]]
=
block_id
;
grad_to_block_id
[
pieces
[
0
]]
=
block_id
;
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
queue
=
std
::
make_shared
<
detail
::
ReceivedQueue
>
();
grad_to_queue
[
pieces
[
0
]]
=
queue
;
// record blocking queue in SignalHandler
SignalHandler
::
RegisterBlockingQueue
(
queue
);
id_to_grad
[
block_id
]
=
pieces
[
0
];
id_to_grad
[
block_id
]
=
pieces
[
0
];
}
}
size_t
num_blocks
=
program
->
Size
();
size_t
num_blocks
=
program
->
Size
();
...
@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
...
@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
grad_to_prepared_ctx
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
}
}
bool
exit_flag
=
false
;
request_send_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_get_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_prefetch_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
VLOG
(
3
)
<<
"start async optimize threads"
;
std
::
vector
<
std
::
future
<
void
>>
fs
;
for
(
auto
iter
=
grad_to_queue
.
begin
();
iter
!=
grad_to_queue
.
end
();
iter
++
)
{
std
::
string
grad_name
=
iter
->
first
;
VLOG
(
3
)
<<
"create async update thread for "
<<
grad_name
;
fs
.
push_back
(
framework
::
AsyncIO
([
grad_name
,
&
exit_flag
,
&
executor
,
&
grad_to_queue
,
&
grad_to_prepared_ctx
]()
{
AsyncUpdateThread
(
grad_name
,
exit_flag
,
grad_to_queue
[
grad_name
],
executor
,
grad_to_prepared_ctx
[
grad_name
].
get
());
}));
}
VLOG
(
3
)
<<
"RunAsyncLoop into while"
;
VLOG
(
3
)
<<
"RunAsyncLoop into while"
;
while
(
!
exit_flag
&&
!
SignalHandler
::
IsProgramExit
())
{
while
(
true
)
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
if
(
rpc_service_
->
IsExit
())
{
auto
recv_var_name
=
v
.
first
;
LOG
(
INFO
)
<<
"get exit!rpc_processor break!"
;
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
break
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
grad_to_queue
[
recv_var_name
]
->
Push
(
v
);
}
}
if
(
exit_flag
)
{
sleep
(
1
);
rpc_service_
->
ShutDown
();
break
;
}
}
// while(true)
}
// while(true)
}
}
static
void
FillRequestCtx
(
detail
::
RequestHandler
*
h
,
framework
::
Scope
*
scope
,
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
,
detail
::
RPCServer
*
rpc_server
)
{
h
->
SetScope
(
scope
);
h
->
SetDevCtx
(
dev_ctx
);
h
->
SetExecutor
(
executor
);
h
->
SetProgram
(
program
);
h
->
SetPrefetchPreparedCtx
(
std
::
move
(
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
(
prefetch_ctx
)));
h
->
SetRPCServer
(
rpc_server
);
}
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
const
platform
::
Place
&
dev_place
)
const
{
// Mark this as PS that it should decide profiling by listening from trainer.
// Mark this as PS that it should decide profiling by listening from trainer.
...
@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
PADDLE_ENFORCE
(
!
rpc_service_
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
sync_mode
));
LOG
(
INFO
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
fan_in
));
request_send_handler_
.
reset
(
new
detail
::
RequestSendHandler
(
sync_mode
));
request_get_handler_
.
reset
(
new
detail
::
RequestGetHandler
(
sync_mode
));
request_prefetch_handler_
.
reset
(
new
detail
::
RequestPrefetchHandler
(
sync_mode
));
rpc_service_
->
RegisterRPC
(
detail
::
kRequestSend
,
request_send_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestGet
,
request_get_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestPrefetch
,
request_prefetch_handler_
.
get
());
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
auto
*
program
=
optimize_block
->
Program
();
auto
*
program
=
optimize_block
->
Program
();
framework
::
Executor
executor
(
dev_place
);
framework
::
Executor
executor
(
dev_place
);
// prepare rpc_service
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
rpc_service_
->
SetProgram
(
program
);
rpc_service_
->
SetExecutor
(
&
executor
);
// prepare for prefetch
// prepare for prefetch
VLOG
(
3
)
<<
"prefetch block id is "
<<
prefetch_block
->
ID
();
VLOG
(
3
)
<<
"prefetch block id is "
<<
prefetch_block
->
ID
();
auto
prefetch_prepared
=
executor
.
Prepare
(
*
program
,
prefetch_block
->
ID
());
auto
prefetch_prepared
=
executor
.
Prepare
(
*
program
,
prefetch_block
->
ID
());
rpc_service_
->
SetPrefetchPreparedCtx
(
std
::
move
(
prefetch_prepared
));
auto
f
=
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
&
executor
,
program
,
prefetch_prepared
.
release
(),
rpc_service_
.
get
());
f
(
request_send_handler_
.
get
());
f
(
request_get_handler_
.
get
());
f
(
request_prefetch_handler_
.
get
());
// start the server listening after all member initialized.
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
...
@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
// Write to a file of server selected port for python use.
// Write to a file of server selected port for python use.
std
::
string
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.selected_port"
,
static_cast
<
int
>
(
::
getpid
()));
SavePort
();
SavePort
();
if
(
sync_mode
)
{
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
...
@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
}
}
};
};
bool
SignalHandler
::
program_exit_flag_
=
false
;
SignalHandler
::
BlockingQueueSet
SignalHandler
::
blocking_queue_set_
{};
void
SignalHandler
::
StopAndExit
(
int
signal_num
)
{
void
SignalHandler
::
StopAndExit
(
int
signal_num
)
{
VLOG
(
3
)
<<
"Catch interrupt signal: "
<<
signal_num
<<
", program will exit"
;
VLOG
(
3
)
<<
"Catch interrupt signal: "
<<
signal_num
<<
", program will exit"
;
exit
(
0
);
program_exit_flag_
=
true
;
// awake all blocking queues
for
(
BlockingQueueSet
::
iterator
iter
=
blocking_queue_set_
.
begin
();
iter
!=
blocking_queue_set_
.
end
();
iter
++
)
{
iter
->
get
()
->
Push
(
std
::
make_pair
(
std
::
string
(
LISTEN_TERMINATE_MESSAGE
),
nullptr
));
}
exit
(
EXIT_SUCCESS
);
}
void
SignalHandler
::
RegisterBlockingQueue
(
BlockingQueue
&
queue
)
{
blocking_queue_set_
.
insert
(
queue
);
}
}
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
4fb7cc7f
...
@@ -23,7 +23,8 @@ limitations under the License. */
...
@@ -23,7 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -31,7 +32,7 @@ namespace operators {
...
@@ -31,7 +32,7 @@ namespace operators {
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
constexpr
char
kPrefetchBlock
[]
=
"PrefetchBlock"
;
constexpr
char
kPrefetchBlock
[]
=
"PrefetchBlock"
;
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncG
RPCServer
>
service
);
void
RunServer
(
std
::
shared_ptr
<
detail
::
RPCServer
>
service
);
class
ListenAndServOp
:
public
framework
::
OperatorBase
{
class
ListenAndServOp
:
public
framework
::
OperatorBase
{
public:
public:
...
@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
...
@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
void
SavePort
()
const
;
void
SavePort
()
const
;
void
WaitServerReady
();
int
GetSelectedPort
()
{
return
rpc_service_
->
GetSelectedPort
();
}
int
GetSelectedPort
()
{
return
selected_port_
;
}
void
Stop
()
override
;
void
Stop
()
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
;
const
platform
::
Place
&
dev_place
)
const
override
;
static
void
ResetPort
()
{
selected_port_
=
0
;
}
protected:
protected:
mutable
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
detail
::
RPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_send_handler_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_get_handler_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_prefetch_handler_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
// FIXME(wuyi): it's static so that the operator can be cloned.
static
std
::
atomic_int
selected_port_
;
};
};
class
SignalHandler
{
class
SignalHandler
{
public:
typedef
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
BlockingQueue
;
typedef
std
::
unordered_set
<
BlockingQueue
>
BlockingQueueSet
;
public:
public:
static
void
StopAndExit
(
int
signal_num
);
static
void
StopAndExit
(
int
signal_num
);
static
void
RegisterBlockingQueue
(
BlockingQueue
&
);
static
inline
bool
IsProgramExit
()
{
return
program_exit_flag_
;
}
private:
private:
static
bool
program_exit_flag_
;
static
BlockingQueueSet
blocking_queue_set_
;
DISABLE_COPY_AND_ASSIGN
(
SignalHandler
);
DISABLE_COPY_AND_ASSIGN
(
SignalHandler
);
};
};
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
4fb7cc7f
...
@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
...
@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
// need to wait before sending send_barrier message
// need to wait before sending send_barrier message
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
if
(
sync_mode
)
{
if
(
sync_mode
)
{
...
...
paddle/fluid/operators/test_send_nccl_id.cc
浏览文件 @
4fb7cc7f
...
@@ -21,6 +21,8 @@ limitations under the License. */
...
@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
...
@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
...
@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
namespace
detail
=
paddle
::
operators
::
detail
;
namespace
detail
=
paddle
::
operators
::
detail
;
namespace
string
=
paddle
::
string
;
namespace
string
=
paddle
::
string
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RequestHandler
>
g_req_handler
;
void
StartServer
(
std
::
atomic
<
bool
>*
initialized
)
{
void
StartServer
()
{
f
::
Scope
scope
;
f
::
Scope
scope
;
p
::
CPUPlace
place
;
p
::
CPUPlace
place
;
scope
.
Var
(
NCCL_ID_VARNAME
);
scope
.
Var
(
NCCL_ID_VARNAME
);
p
::
DeviceContextPool
&
pool
=
p
::
DeviceContextPool
::
Instance
();
p
::
DeviceContextPool
&
pool
=
p
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
true
));
f
::
ProgramDesc
empty_program
;
f
::
ProgramDesc
empty_program
;
f
::
Executor
executor
(
dev_ctx
.
GetPlace
());
f
::
Executor
executor
(
dev_ctx
.
GetPlace
());
rpc_service
->
SetScope
(
&
scope
);
g_req_handler
->
SetScope
(
&
scope
);
rpc_service
->
SetDevCtx
(
&
dev_ctx
);
g_req_handler
->
SetDevCtx
(
&
dev_ctx
);
rpc_service
->
SetProgram
(
&
empty_program
);
g_req_handler
->
SetProgram
(
&
empty_program
);
rpc_service
->
SetExecutor
(
&
executor
);
g_req_handler
->
SetExecutor
(
&
executor
);
g_rpc_service
->
RegisterRPC
(
detail
::
kRequestSend
,
g_req_handler
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
RunSyncUpdate
,
rpc_service
.
get
()));
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
*
initialized
=
true
;
rpc_service
->
SetCond
(
0
);
g_rpc_service
->
SetCond
(
detail
::
kRequestSend
);
auto
recv
=
rpc_service
->
Get
();
std
::
cout
<<
"before WaitFanInOfSend"
<<
std
::
endl
;
g_rpc_service
->
WaitBarrier
(
detail
::
kRequestSend
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
rpc_service
->
ShutDown
();
g_
rpc_service
->
ShutDown
();
server_thread
.
join
();
server_thread
.
join
();
}
}
TEST
(
SendNcclId
,
DISABLED_Normal
)
{
TEST
(
SendNcclId
,
GrpcServer
)
{
std
::
atomic
<
bool
>
initialized
{
false
};
g_req_handler
.
reset
(
new
detail
::
RequestSendHandler
(
true
));
std
::
thread
server_thread
(
StartServer
,
&
initialized
);
g_rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
1
));
while
(
!
initialized
)
{
}
std
::
thread
server_thread
(
StartServer
);
// wait server to start
g_rpc_service
->
WaitServerReady
();
// sleep(2);
rpc_service
->
WaitServerReady
();
f
::
Scope
scope
;
f
::
Scope
scope
;
p
::
CPUPlace
place
;
p
::
CPUPlace
place
;
...
@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
...
@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
auto
var
=
scope
.
Var
(
NCCL_ID_VARNAME
);
auto
var
=
scope
.
Var
(
NCCL_ID_VARNAME
);
// var->SetType(f::proto::VarType_Type_RAW);
auto
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
auto
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
p
::
dynload
::
ncclGetUniqueId
(
id
);
p
::
dynload
::
ncclGetUniqueId
(
id
);
int
port
=
rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
detail
::
RPCClient
client
;
detail
::
RPCClient
client
;
LOG
(
INFO
)
<<
"connect to server"
<<
ep
;
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
.
Wait
();
client
.
Wait
();
client
.
AsyncSendBatchBarrier
(
ep
);
client
.
Wait
();
server_thread
.
join
();
server_thread
.
join
();
auto
*
ptr
=
rpc_service
.
release
(
);
g_rpc_service
.
reset
(
nullptr
);
delete
ptr
;
g_req_handler
.
reset
(
nullptr
)
;
}
}
paddle/fluid/platform/nccl_helper.h
浏览文件 @
4fb7cc7f
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <stdio.h>
#include <stdio.h>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <typeindex>
#include <typeindex>
#include <vector>
#include <vector>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录