Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4fb7cc7f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4fb7cc7f
编写于
5月 31, 2018
作者:
G
gongweibao
提交者:
GitHub
5月 31, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move sync_mode device ctx from grpc server (#10881)
上级
5870a6b4
变更
28
隐藏空白更改
内联
并排
Showing
28 changed file
with
886 addition
and
553 deletion
+886
-553
benchmark/fluid/kube_gen_job.py
benchmark/fluid/kube_gen_job.py
+1
-1
paddle/fluid/inference/analysis/data_flow_graph.h
paddle/fluid/inference/analysis/data_flow_graph.h
+3
-0
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc
...nference/analysis/data_flow_graph_to_fluid_pass_tester.cc
+3
-3
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
...fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
+3
-1
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
.../fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
+2
-0
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
...nference/analysis/fluid_to_data_flow_graph_pass_tester.cc
+3
-3
paddle/fluid/inference/analysis/helper.h
paddle/fluid/inference/analysis/helper.h
+4
-2
paddle/fluid/inference/analysis/pass.h
paddle/fluid/inference/analysis/pass.h
+1
-0
paddle/fluid/inference/analysis/subgraph_splitter.h
paddle/fluid/inference/analysis/subgraph_splitter.h
+2
-0
paddle/fluid/inference/analysis/ut_helper.h
paddle/fluid/inference/analysis/ut_helper.h
+1
-0
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+4
-1
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+2
-1
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+2
-0
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+147
-225
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+22
-76
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+53
-34
paddle/fluid/operators/detail/request_handler.h
paddle/fluid/operators/detail/request_handler.h
+127
-0
paddle/fluid/operators/detail/request_handler_impl.cc
paddle/fluid/operators/detail/request_handler_impl.cc
+115
-0
paddle/fluid/operators/detail/request_handler_impl.h
paddle/fluid/operators/detail/request_handler_impl.h
+64
-0
paddle/fluid/operators/detail/rpc_server.cc
paddle/fluid/operators/detail/rpc_server.cc
+113
-0
paddle/fluid/operators/detail/rpc_server.h
paddle/fluid/operators/detail/rpc_server.h
+91
-0
paddle/fluid/operators/detail/variable_response.h
paddle/fluid/operators/detail/variable_response.h
+2
-2
paddle/fluid/operators/gen_nccl_id_op.cc
paddle/fluid/operators/gen_nccl_id_op.cc
+13
-8
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+63
-148
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+9
-22
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+2
-0
paddle/fluid/operators/test_send_nccl_id.cc
paddle/fluid/operators/test_send_nccl_id.cc
+33
-26
paddle/fluid/platform/nccl_helper.h
paddle/fluid/platform/nccl_helper.h
+1
-0
未找到文件。
benchmark/fluid/kube_gen_job.py
浏览文件 @
4fb7cc7f
...
...
@@ -49,7 +49,7 @@ def parse_args():
parser
.
add_argument
(
'--fluid'
,
default
=
1
,
type
=
int
,
help
=
'whether is fluid job'
)
parser
.
add_argument
(
'--rdma'
,
action
=
'store_t
ur
e'
,
help
=
'whether mount rdma libs'
)
'--rdma'
,
action
=
'store_t
ru
e'
,
help
=
'whether mount rdma libs'
)
parser
.
add_argument
(
'--disttype'
,
default
=
"pserver"
,
...
...
paddle/fluid/inference/analysis/data_flow_graph.h
浏览文件 @
4fb7cc7f
...
...
@@ -21,7 +21,10 @@ limitations under the License. */
#include <deque>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h"
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc
浏览文件 @
4fb7cc7f
...
...
@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
LOG
(
INFO
)
<<
graph
.
nodes
.
size
();
}
}
//
analysis
}
//
inference
}
//
paddle
}
;
// namespace
analysis
}
;
// namespace
inference
}
;
// namespace
paddle
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
浏览文件 @
4fb7cc7f
...
...
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include
"paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include
<string>
#include <vector>
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
浏览文件 @
4fb7cc7f
...
...
@@ -19,6 +19,8 @@
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
浏览文件 @
4fb7cc7f
...
...
@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
LOG
(
INFO
)
<<
'\n'
<<
graph
.
DotString
();
}
}
// analysis
}
// inference
}
// paddle
}
//
namespace
analysis
}
//
namespace
inference
}
//
namespace
paddle
paddle/fluid/inference/analysis/helper.h
浏览文件 @
4fb7cc7f
...
...
@@ -50,7 +50,7 @@ struct DataTypeNamer {
return
dic_
.
at
(
x
);
}
const
std
::
string
&
repr
(
size_t
&
hash
)
const
{
const
std
::
string
&
repr
(
size_t
&
hash
)
const
{
// NOLINT
PADDLE_ENFORCE
(
dic_
.
count
(
hash
),
"unknown type for representation"
);
return
dic_
.
at
(
hash
);
}
...
...
@@ -62,7 +62,9 @@ struct DataTypeNamer {
SET_TYPE
(
float
);
}
std
::
unordered_map
<
decltype
(
typeid
(
int
).
hash_code
()),
std
::
string
>
dic_
;
std
::
unordered_map
<
decltype
(
typeid
(
int
).
hash_code
()),
// NOLINT
std
::
string
>
dic_
;
};
#undef SET_TYPE
...
...
paddle/fluid/inference/analysis/pass.h
浏览文件 @
4fb7cc7f
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <iosfwd>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
...
...
paddle/fluid/inference/analysis/subgraph_splitter.h
浏览文件 @
4fb7cc7f
...
...
@@ -18,6 +18,8 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/node.h"
...
...
paddle/fluid/inference/analysis/ut_helper.h
浏览文件 @
4fb7cc7f
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
...
...
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
4fb7cc7f
...
...
@@ -19,6 +19,9 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/analysis/helper.h"
...
...
@@ -58,7 +61,7 @@ class TRTConvertValidation {
public:
TRTConvertValidation
()
=
delete
;
TRTConvertValidation
(
int
batch_size
,
int
workspace_size
=
1
<<
10
)
{
explicit
TRTConvertValidation
(
int
batch_size
,
int
workspace_size
=
1024
)
{
// create engine.
engine_
.
reset
(
new
TensorRTEngine
(
10
,
1
<<
10
,
&
stream_
));
engine_
->
InitNetwork
();
...
...
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
4fb7cc7f
if
(
WITH_DISTRIBUTE
)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
4fb7cc7f
...
...
@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
}
bool
RPCClient
::
Wait
()
{
VLOG
(
3
)
<<
"RPCClient begin Wait()"
<<
" req_count_:"
<<
req_count_
;
if
(
req_count_
<=
0
)
{
return
true
;
}
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
4fb7cc7f
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits>
#include <string>
using
::
grpc
::
ServerAsyncResponseWriter
;
#include "paddle/fluid/operators/detail/grpc_server.h"
DEFINE_int32
(
rpc_server_handle_send_threads
,
20
,
"Number of threads used to handle send at rpc server."
);
DEFINE_int32
(
rpc_server_handle_get_threads
,
20
,
"Number of threads used to handle get at rpc server."
);
DEFINE_int32
(
rpc_server_handle_prefetch_threads
,
1
,
"Number of threads used to handle prefetch at rpc server."
);
using
::
grpc
::
ServerAsyncResponseWriter
;
namespace
paddle
{
namespace
operators
{
...
...
@@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH };
class
RequestBase
{
public:
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
const
platform
::
DeviceContext
*
dev_ctx
)
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
:
service_
(
service
),
cq_
(
cq
),
sync_mode_
(
sync_mode
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
request_handler_
(
request_handler
),
req_id_
(
req_id
)
{
PADDLE_ENFORCE
(
cq_
);
}
virtual
~
RequestBase
()
{}
virtual
void
Process
()
{
assert
(
false
);
}
virtual
void
Process
()
=
0
;
CallStatus
Status
()
{
return
status_
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
virtual
std
::
string
GetReqName
()
{
assert
(
false
);
return
""
;
}
virtual
std
::
string
GetReqName
()
=
0
;
protected:
::
grpc
::
ServerContext
ctx_
;
GrpcService
::
AsyncService
*
service_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
const
bool
sync_mode_
;
CallStatus
status_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
RequestHandler
*
request_handler_
;
int
req_id_
;
};
class
RequestSend
final
:
public
RequestBase
{
public:
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
const
platform
::
DeviceContext
*
dev_ctx
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
),
req_id_
(
req_id
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
request_handler
->
dev_ctx
(),
!
request_handler
->
sync_mode
()));
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
...
...
@@ -87,12 +71,17 @@ class RequestSend final : public RequestBase {
virtual
~
RequestSend
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
std
::
string
varname
=
GetReqName
();
VLOG
(
3
)
<<
"RequestSend var_name:"
<<
varname
;
virtual
void
Process
()
{
std
::
string
var_name
=
GetReqName
();
VLOG
(
3
)
<<
"RequestSend "
<<
var_name
;
queue_
->
Push
(
std
::
make_pair
(
var_name
,
request_
));
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
invar
=
request_
->
GetVar
();
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
...
...
@@ -102,105 +91,85 @@ class RequestSend final : public RequestBase {
protected:
sendrecv
::
VoidMessage
reply_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
ReceivedQueue
*
queue_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
int
req_id_
;
};
class
RequestGet
final
:
public
RequestBase
{
public:
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
BlockingQueue
<
MessageWithName
>*
queue
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
queue_
(
queue
),
req_id_
(
req_id
)
{
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
)
{
auto
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
virtual
~
RequestGet
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
.
varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
.
varname
();
}
v
irtual
void
Process
()
{
v
oid
Process
()
override
{
// proc request.
std
::
string
var_name
=
request_
.
varname
();
VLOG
(
3
)
<<
"RequestGet "
<<
var_name
;
auto
*
var
=
scope_
->
FindVar
(
var_name
);
std
::
string
varname
=
request_
.
varname
();
VLOG
(
3
)
<<
"RequestGet "
<<
varname
;
auto
scope
=
request_handler_
->
scope
();
auto
invar
=
scope
->
FindVar
(
varname
);
framework
::
Variable
*
outvar
=
nullptr
;
if
(
var_name
!=
FETCH_BARRIER_MESSAGE
)
{
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
if
(
outvar
)
{
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
}
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
if
(
var_name
==
FETCH_BARRIER_MESSAGE
)
{
sendrecv
::
VariableMessage
msg
;
MessageWithName
msg_with_name
=
std
::
make_pair
(
var_name
,
msg
);
queue_
->
Push
(
msg_with_name
);
}
}
protected:
sendrecv
::
VariableMessage
request_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
BlockingQueue
<
MessageWithName
>*
queue_
;
int
req_id_
;
};
class
RequestPrefetch
final
:
public
RequestBase
{
public:
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
::
grpc
::
ServerCompletionQueue
*
cq
,
RequestHandler
*
request_handler
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
request_handler
,
req_id
),
responder_
(
&
ctx_
),
scope_
(
scope
),
executor_
(
executor
),
program_
(
program
),
prefetch_ctx_
(
prefetch_ctx
),
req_id_
(
req_id
)
{
// prefetch always create a new sub scope
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
local_scope_
(
nullptr
)
{
request_
.
reset
(
new
VariableResponse
(
request_handler
->
scope
(),
request_handler
->
dev_ctx
(),
true
));
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
virtual
~
RequestPrefetch
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
->
Varname
();
}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
v
irtual
void
Process
()
{
v
oid
Process
()
override
{
// prefetch process...
std
::
string
varname
=
request_
->
OutVarname
();
VLOG
(
3
)
<<
"RequestPrefetch "
<<
varname
;
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
invar
=
scope
->
FindVar
(
varname
);
framework
::
Variable
*
outvar
=
nullptr
;
std
::
string
var_name
=
request_
->
OutVarname
();
VLOG
(
3
)
<<
"RequestPrefetch "
<<
var_name
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
var_name
);
framework
::
Scope
*
local_scope
=
request_
->
GetMutableLocalScope
();
auto
*
var
=
local_scope
->
FindVar
(
var_name
);
InitializeVariable
(
var
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
,
local_scope
);
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
...
...
@@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase {
std
::
shared_ptr
<
VariableResponse
>
request_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
int
req_id_
;
framework
::
Scope
*
local_scope_
;
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
int
fetch_barriers
=
0
;
while
(
fetch_barriers
<
count
)
{
auto
msg
=
var_get_queue_
.
Pop
();
if
(
msg
.
first
==
FETCH_BARRIER_MESSAGE
)
{
fetch_barriers
++
;
}
}
}
void
AsyncGRPCServer
::
WaitServerReady
()
{
VLOG
(
3
)
<<
"AsyncGRPCServer is wait server ready"
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
condition_ready_
.
wait
(
lock
,
[
=
]
{
return
this
->
ready_
==
1
;
});
VLOG
(
3
)
<<
"AsyncGRPCServer WaitSeverReady"
;
}
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
void
AsyncGRPCServer
::
StartServer
()
{
::
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
::
grpc
::
InsecureServerCredentials
(),
builder
.
AddListeningPort
(
bind_
address_
,
::
grpc
::
InsecureServerCredentials
(),
&
selected_port_
);
builder
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
RegisterService
(
&
service_
);
cq_send_
=
builder
.
AddCompletionQueue
();
cq_get_
=
builder
.
AddCompletionQueue
(
);
cq_prefetch_
=
builder
.
AddCompletionQueue
();
for
(
auto
t
:
rpc_call_map_
)
{
rpc_cq_
[
t
.
first
].
reset
(
builder
.
AddCompletionQueue
().
release
()
);
}
server_
=
builder
.
BuildAndStart
();
LOG
(
INFO
)
<<
"Server listening on "
<<
address_
LOG
(
INFO
)
<<
"Server listening on "
<<
bind_
address_
<<
" selected port: "
<<
selected_port_
;
std
::
function
<
void
(
int
)
>
send_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewSendOne
,
this
,
std
::
placeholders
::
_1
);
std
::
function
<
void
(
int
)
>
get_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewGetOne
,
this
,
std
::
placeholders
::
_1
);
std
::
function
<
void
(
int
)
>
prefetch_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
,
this
,
std
::
placeholders
::
_1
);
std
::
function
<
void
(
const
std
::
string
&
,
int
)
>
f
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewOne
,
this
,
std
::
placeholders
::
_1
,
std
::
placeholders
::
_2
);
for
(
int
i
=
0
;
i
<
kSendReqsBufSize
;
++
i
)
{
TryToRegisterNewSendOne
(
i
);
}
for
(
int
i
=
0
;
i
<
kGetReqsBufSize
;
++
i
)
{
TryToRegisterNewGetOne
(
i
);
}
for
(
int
i
=
0
;
i
<
kPrefetchReqsBufSize
;
++
i
)
{
TryToRegisterNewPrefetchOne
(
i
);
}
for
(
auto
&
t
:
rpc_call_map_
)
{
auto
&
rpc_name
=
t
.
first
;
auto
&
cq
=
rpc_cq_
[
rpc_name
];
auto
threadnum
=
rpc_thread_num_
[
rpc_name
];
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_send_threads
;
++
i
)
{
t_sends_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_send_
.
get
(),
"cq_send"
,
send_register
)));
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_get_threads
;
++
i
)
{
t_gets_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_get_
.
get
(),
"cq_get"
,
get_register
)));
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_prefetch_threads
;
++
i
)
{
t_prefetchs_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_prefetch_
.
get
(),
"cq_prefetch"
,
prefetch_register
)));
reqs
.
reserve
(
kRequestBufSize
);
for
(
int
i
=
0
;
i
<
kRequestBufSize
;
i
++
)
{
TryToRegisterNewOne
(
rpc_name
,
i
);
}
for
(
int
i
=
0
;
i
<
threadnum
;
i
++
)
{
rpc_threads_
[
rpc_name
].
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq
.
get
(),
rpc_name
,
f
)));
VLOG
(
3
)
<<
t
.
first
<<
" creates threads!"
;
}
}
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
ready_
=
1
;
}
condition_ready_
.
notify_all
();
// wait server
server_
->
Wait
();
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_send_threads
;
++
i
)
{
t_sends_
[
i
]
->
join
();
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_get_threads
;
++
i
)
{
t_gets_
[
i
]
->
join
();
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_prefetch_threads
;
++
i
)
{
t_prefetchs_
[
i
]
->
join
();
for
(
auto
&
t
:
rpc_threads_
)
{
auto
&
threads
=
t
.
second
;
for
(
size_t
i
=
0
;
i
<
threads
.
size
();
++
i
)
{
threads
[
i
]
->
join
();
VLOG
(
3
)
<<
t
.
first
<<
" threads ends!"
;
}
}
}
void
AsyncGRPCServer
::
ShutdownQueue
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
cq_send_
->
Shutdown
();
cq_get_
->
Shutdown
()
;
cq_prefetch_
->
Shutdown
();
for
(
auto
&
t
:
rpc_cq_
)
{
t
.
second
->
Shutdown
();
VLOG
(
3
)
<<
t
.
first
<<
" shutdown!"
;
}
}
// This URL explains why shutdown is complicate:
void
AsyncGRPCServer
::
ShutDown
()
{
void
AsyncGRPCServer
::
ShutDownImpl
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
is_shut_down_
=
true
;
ShutdownQueue
();
VLOG
(
3
)
<<
"server_ shutdown!"
;
server_
->
Shutdown
();
}
void
AsyncGRPCServer
::
TryToRegisterNewSendOne
(
int
i
)
{
void
AsyncGRPCServer
::
TryToRegisterNewOne
(
const
std
::
string
&
rpc_name
,
int
req_id
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
return
;
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
sync_mode_
,
scope_
,
&
var_recv_queue_
,
dev_ctx_
,
i
);
send_reqs_
[
i
]
=
static_cast
<
RequestBase
*>
(
send
);
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewGetOne
(
int
req_id
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
return
;
VLOG
(
4
)
<<
"register send rpc_name:"
<<
rpc_name
<<
", handler:"
<<
rpc_call_map_
[
kRequestSend
];
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
auto
&
handler
=
rpc_call_map_
[
rpc_name
];
auto
&
cq
=
rpc_cq_
[
rpc_name
];
RequestBase
*
b
=
nullptr
;
if
(
rpc_name
==
kRequestSend
)
{
b
=
new
RequestSend
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestGet
)
{
b
=
new
RequestGet
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
if
(
rpc_name
==
kRequestPrefetch
)
{
b
=
new
RequestPrefetch
(
&
service_
,
cq
.
get
(),
handler
,
req_id
);
}
else
{
PADDLE_ENFORCE
(
false
,
"not surpported rpc"
);
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
&
var_get_queue_
,
req_id
);
get_reqs_
[
req_id
]
=
static_cast
<
RequestBase
*>
(
get
);
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
(
int
req_id
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewPrefetchOne"
;
return
;
}
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
.
get
(),
req_id
);
prefetch_reqs_
[
req_id
]
=
static_cast
<
RequestBase
*>
(
prefetch
);
reqs
[
req_id
]
=
b
;
VLOG
(
4
)
<<
"Create Request
Prefetch status:"
<<
prefetch
->
Status
();
VLOG
(
4
)
<<
"Create Request
Send status:"
<<
b
->
Status
();
}
// FIXME(typhoonzero): change cq_name to enum.
void
AsyncGRPCServer
::
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq
_name
,
std
::
function
<
void
(
int
)
>
TryToRegisterNewOne
)
{
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
rpc
_name
,
std
::
function
<
void
(
const
std
::
string
&
,
int
)
>
TryToRegisterNewOne
)
{
void
*
tag
=
NULL
;
bool
ok
=
false
;
while
(
true
)
{
VLOG
(
3
)
<<
"HandleRequest
for "
<<
cq_name
<<
" wait N
ext"
;
VLOG
(
3
)
<<
"HandleRequest
"
<<
rpc_name
<<
" wait n
ext"
;
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
LOG
(
INFO
)
<<
cq_name
<<
" CompletionQueue
shutdown!"
;
LOG
(
INFO
)
<<
"CompletionQueue "
<<
rpc_name
<<
"
shutdown!"
;
break
;
}
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" get Next"
;
int
req_id
=
static_cast
<
int
>
(
reinterpret_cast
<
intptr_t
>
(
tag
));
if
(
sync_mode_
)
{
// FIXME(typhoonzero): de-couple the barriers with recv_op
if
(
!
is_shut_down_
&&
cq_name
==
"cq_get"
)
WaitCond
(
1
);
if
(
!
is_shut_down_
&&
cq_name
==
"cq_send"
)
WaitCond
(
0
);
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" after WaitCond"
;
}
int
req_id
=
static_cast
<
int
>
(
reinterpret_cast
<
intptr_t
>
(
tag
));
VLOG
(
3
)
<<
"HandleRequest "
<<
rpc_name
<<
", req_id:"
<<
req_id
<<
" get next"
;
auto
&
reqs
=
rpc_reqs_
[
rpc_name
];
RequestBase
*
base
=
nullptr
;
{
std
::
lock_guard
<
std
::
mutex
>
l
(
cq_mutex_
);
if
(
cq_name
==
"cq_get"
)
{
base
=
get_reqs_
[
req_id
];
}
else
if
(
cq_name
==
"cq_send"
)
{
base
=
send_reqs_
[
req_id
];
}
else
if
(
cq_name
==
"cq_prefetch"
)
{
base
=
prefetch_reqs_
[
req_id
];
}
PADDLE_ENFORCE
(
req_id
>=
0
&&
req_id
<
kRequestBufSize
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
base
=
reqs
[
req_id
];
}
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
cq_name
<<
" recv no regular event:argument name["
LOG
(
WARNING
)
<<
"completion queue:"
<<
rpc_name
<<
" recv no regular event:argument name["
<<
base
->
GetReqName
()
<<
"]"
;
TryToRegisterNewOne
(
req_id
);
TryToRegisterNewOne
(
r
pc_name
,
r
eq_id
);
delete
base
;
continue
;
}
VLOG
(
3
)
<<
"queue id:"
<<
rpc_name
<<
", req_id:"
<<
req_id
<<
", status:"
<<
base
->
Status
();
switch
(
base
->
Status
())
{
case
PROCESS
:
{
base
->
Process
();
VLOG
(
4
)
<<
cq_name
<<
" PROCESS status:"
<<
base
->
Status
();
break
;
}
case
FINISH
:
{
TryToRegisterNewOne
(
req_id
);
VLOG
(
4
)
<<
cq_name
<<
" FINISH status:"
<<
base
->
Status
();
TryToRegisterNewOne
(
rpc_name
,
req_id
);
delete
base
;
break
;
}
...
...
@@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest(
}
}
void
AsyncGRPCServer
::
WaitCond
(
int
cond
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
barrier_mutex_
);
barrier_condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
barrier_cond_step_
==
cond
;
});
}
void
AsyncGRPCServer
::
SetCond
(
int
cond
)
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
barrier_mutex_
);
barrier_cond_step_
=
cond
;
}
barrier_condition_
.
notify_all
();
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
4fb7cc7f
...
...
@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <map>
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
...
...
@@ -28,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
...
...
@@ -37,106 +41,48 @@ namespace paddle {
namespace
operators
{
namespace
detail
{
typedef
std
::
pair
<
std
::
string
,
std
::
shared_ptr
<
VariableResponse
>>
ReceivedMessage
;
typedef
framework
::
BlockingQueue
<
ReceivedMessage
>
ReceivedQueue
;
typedef
std
::
pair
<
std
::
string
,
sendrecv
::
VariableMessage
>
MessageWithName
;
class
RequestBase
;
class
AsyncGRPCServer
final
{
class
AsyncGRPCServer
final
:
public
RPCServer
{
public:
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
bool
sync_mode
)
:
address_
(
address
),
sync_mode_
(
sync_mode
),
ready_
(
0
)
{}
~
AsyncGRPCServer
()
{}
void
WaitServerReady
();
void
RunSyncUpdate
();
// functions to sync server barrier status.
void
WaitCond
(
int
cond
);
void
SetCond
(
int
cond
);
void
WaitClientGet
(
int
count
);
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
void
SetProgram
(
framework
::
ProgramDesc
*
program
)
{
program_
=
program
;
}
void
SetExecutor
(
framework
::
Executor
*
executor
)
{
executor_
=
executor
;
}
void
SetPrefetchPreparedCtx
(
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepared
)
{
prefetch_ctx_
.
reset
(
prepared
.
release
());
}
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
const
ReceivedMessage
Get
()
{
return
this
->
var_recv_queue_
.
Pop
();
}
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
int
client_num
)
:
RPCServer
(
address
,
client_num
),
ready_
(
0
)
{}
v
oid
Push
(
const
std
::
string
&
msg_name
)
{
this
->
var_recv_queue_
.
Push
(
std
::
make_pair
(
msg_name
,
nullptr
))
;
}
v
irtual
~
AsyncGRPCServer
()
{}
void
WaitServerReady
()
override
;
void
StartServer
()
override
;
void
ShutDown
();
private:
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
rpc_name
,
std
::
function
<
void
(
const
std
::
string
&
,
int
)
>
TryToRegisterNewOne
);
protected:
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
(
int
)
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
(
int
req_id
);
void
TryToRegisterNewGetOne
(
int
req_id
);
void
TryToRegisterNewPrefetchOne
(
int
req_id
);
void
TryToRegisterNewOne
(
const
std
::
string
&
rpc_name
,
int
req_id
);
void
ShutdownQueue
();
void
ShutDownImpl
()
override
;
private:
static
const
int
kSendReqsBufSize
=
100
;
static
const
int
kGetReqsBufSize
=
100
;
static
const
int
kPrefetchReqsBufSize
=
10
;
static
const
int
kRequestBufSize
=
100
;
std
::
mutex
cq_mutex_
;
volatile
bool
is_shut_down_
=
false
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_send_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_get_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_prefetch_
;
RequestBase
*
send_reqs_
[
kSendReqsBufSize
];
RequestBase
*
get_reqs_
[
kGetReqsBufSize
];
RequestBase
*
prefetch_reqs_
[
kPrefetchReqsBufSize
];
GrpcService
::
AsyncService
service_
;
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
string
address_
;
const
bool
sync_mode_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
// received variable from RPC, operators fetch variable from this queue.
framework
::
BlockingQueue
<
MessageWithName
>
var_get_queue_
;
// client send variable to this queue.
ReceivedQueue
var_recv_queue_
;
// condition of the sub program
std
::
mutex
barrier_mutex_
;
mutable
int
barrier_cond_step_
;
std
::
condition_variable
barrier_condition_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_sends_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_gets_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_prefetchs_
;
std
::
unique_ptr
<
std
::
thread
>
t_prefetch_
;
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prefetch_ctx_
;
framework
::
ProgramDesc
*
program_
;
framework
::
Executor
*
executor_
;
int
selected_port_
;
std
::
mutex
mutex_ready_
;
std
::
condition_variable
condition_ready_
;
int
ready_
;
std
::
map
<
std
::
string
,
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>>
rpc_cq_
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>>
rpc_threads_
;
std
::
map
<
std
::
string
,
std
::
vector
<
RequestBase
*>>
rpc_reqs_
;
};
};
// namespace detail
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
4fb7cc7f
...
...
@@ -24,13 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
namespace
framework
=
paddle
::
framework
;
namespace
platform
=
paddle
::
platform
;
namespace
detail
=
paddle
::
operators
::
detail
;
USE_OP
(
lookup_table
);
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RequestHandler
>
g_req_handler
;
framework
::
BlockDesc
*
AppendPrefetchBlcok
(
framework
::
ProgramDesc
*
program
)
{
auto
root_block
=
program
->
MutableBlock
(
0
);
...
...
@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
}
void
StartServer
(
const
std
::
string
&
endpoint
)
{
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
true
));
void
StartServer
()
{
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
...
...
@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
auto
prepared
=
exe
.
Prepare
(
program
,
block
->
ID
());
InitTensorsOnServer
(
&
scope
,
&
place
,
10
);
rpc_service_
->
SetProgram
(
&
program
);
rpc_service_
->
SetPrefetchPreparedCtx
(
std
::
move
(
prepared
));
rpc_service_
->
SetDevCtx
(
&
ctx
);
rpc_service_
->
SetScope
(
&
scope
);
rpc_service_
->
SetExecutor
(
&
exe
);
g_req_handler
->
SetProgram
(
&
program
);
g_req_handler
->
SetPrefetchPreparedCtx
(
std
::
move
(
prepared
));
g_req_handler
->
SetDevCtx
(
&
ctx
);
g_req_handler
->
SetScope
(
&
scope
);
g_req_handler
->
SetExecutor
(
&
exe
);
g_rpc_service
->
RegisterRPC
(
detail
::
kRequestPrefetch
,
g_req_handler
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
rpc_service_
->
RunSyncUpdate
();
// FIXME(gongwb): don't use hard time.
sleep
(
10
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
}
TEST
(
PREFETCH
,
DISABLED_CPU
)
{
// start up a server instance backend
std
::
thread
server_thread
(
StartServer
,
"127.0.0.1:8889"
);
sleep
(
2
);
TEST
(
PREFETCH
,
CPU
)
{
g_req_handler
.
reset
(
new
detail
::
RequestPrefetchHandler
(
true
));
g_rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
1
));
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
detail
::
RPCClient
client
;
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
// create var on local scope
int64_t
rows_numel
=
5
;
InitTensorsOnClient
(
&
scope
,
&
place
,
rows_numel
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
auto
client
=
detail
::
RPCClient
::
GetInstance
();
client
->
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
rpc_service_
->
ShutDown
();
server_thread
.
join
();
rpc_service_
.
reset
(
nullptr
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
EXPECT_EQ
(
ptr
[
0
+
i
*
value
.
dims
()[
1
]],
static_cast
<
float
>
(
i
*
2
));
{
// create var on local scope
int64_t
rows_numel
=
5
;
InitTensorsOnClient
(
&
scope
,
&
place
,
rows_numel
);
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
client
.
AsyncPrefetchVariable
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
.
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
EXPECT_EQ
(
ptr
[
0
+
i
*
value
.
dims
()[
1
]],
static_cast
<
float
>
(
i
*
2
));
}
}
server_thread
.
join
();
LOG
(
INFO
)
<<
"begin reset"
;
g_rpc_service
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
);
}
paddle/fluid/operators/detail/request_handler.h
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
constexpr
char
kRequestSend
[]
=
"RequestSend"
;
constexpr
char
kRequestGet
[]
=
"RequestGet"
;
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
class
RPCServer
;
class
RequestHandler
{
public:
explicit
RequestHandler
(
bool
sync_mode
)
:
sync_mode_
(
sync_mode
),
dev_ctx_
(
nullptr
),
executor_
(
nullptr
),
scope_
(
nullptr
),
program_
(
nullptr
),
rpc_server_
(
nullptr
)
{}
virtual
~
RequestHandler
()
{}
// Set attributes.
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
void
SetDevCtx
(
const
platform
::
DeviceContext
*
dev_ctx
)
{
dev_ctx_
=
dev_ctx
;
}
void
SetProgram
(
framework
::
ProgramDesc
*
program
)
{
program_
=
program
;
}
void
SetExecutor
(
framework
::
Executor
*
executor
)
{
executor_
=
executor
;
}
void
SetPrefetchPreparedCtx
(
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepared
)
{
prefetch_ctx_
.
reset
(
prepared
.
release
());
}
// Used for async.
void
SetGradToPreparedCtx
(
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>*
g
)
{
grad_to_prepared_ctx_
=
g
;
}
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
// Get attributes.
bool
sync_mode
()
{
return
sync_mode_
;
}
framework
::
Scope
*
scope
()
{
return
scope_
;
}
const
platform
::
DeviceContext
*
dev_ctx
()
{
return
dev_ctx_
;
}
framework
::
ExecutorPrepareContext
*
prefetch_ctx
()
{
return
prefetch_ctx_
.
get
();
}
framework
::
ProgramDesc
*
program
()
{
return
program_
;
}
framework
::
Executor
*
executor
()
{
return
executor_
;
}
std
::
vector
<
framework
::
Variable
*>&
sparse_vars
()
{
return
sparse_vars_
;
}
// This function processes user's rpc request.
// The implemention is in request_handler_impl.
// example:
// std::string varname = request_.varname();
//
// auto scope = request_handler_->scope();
// auto invar = scope->FindVar(varname);
// framework::Variable* outvar = nullptr;
//
// request_handler_->Handle(varname, scope, invar, &outvar);
// if (outvar) {
// SerializeToByteBuffer(varname, outvar,
// *request_handler_->dev_ctx(), &reply_);
// }
virtual
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
=
0
;
protected:
const
bool
sync_mode_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
framework
::
Executor
*
executor_
;
framework
::
Scope
*
scope_
;
framework
::
ProgramDesc
*
program_
;
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prefetch_ctx_
;
// Used for async.
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>*
grad_to_prepared_ctx_
;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std
::
vector
<
framework
::
Variable
*>
sparse_vars_
;
RPCServer
*
rpc_server_
;
std
::
mutex
sparse_var_mutex_
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/request_handler_impl.cc
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
bool
RequestSendHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
)
{
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
// Async
if
(
!
sync_mode_
)
{
try
{
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"async: run sub program error "
<<
e
.
what
();
return
false
;
}
return
true
;
}
// Sync
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv batch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
{
VLOG
(
3
)
<<
"sync: received var_name: "
<<
varname
;
if
(
sync_mode_
)
{
rpc_server_
->
WaitCond
(
kRequestSend
);
}
if
(
invar
==
nullptr
)
{
LOG
(
ERROR
)
<<
"sync: Can not find server side var: "
<<
varname
;
PADDLE_THROW
(
"sync: Can not find server side var"
);
return
false
;
}
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
sparse_var_mutex_
);
sparse_vars_
.
push_back
(
invar
);
}
}
return
true
;
}
bool
RequestGetHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
)
{
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
;
if
(
varname
!=
FETCH_BARRIER_MESSAGE
)
{
if
(
sync_mode_
)
{
rpc_server_
->
WaitCond
(
kRequestGet
);
}
*
outvar
=
scope_
->
FindVar
(
varname
);
return
true
;
}
// FETCH_BARRIER_MESSAGE
if
(
sync_mode_
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
}
return
true
;
}
bool
RequestPrefetchHandler
::
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
)
{
VLOG
(
4
)
<<
"RequestPrefetchHandler "
<<
varname
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
varname
);
*
outvar
=
scope
->
FindVar
(
varname
);
InitializeVariable
(
*
outvar
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
.
get
(),
scope
);
return
true
;
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/request_handler_impl.h
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
RequestSendHandler
final
:
public
RequestHandler
{
public:
explicit
RequestSendHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestSendHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
override
;
};
class
RequestGetHandler
final
:
public
RequestHandler
{
public:
explicit
RequestGetHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestGetHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
override
;
};
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
public:
explicit
RequestPrefetchHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestPrefetchHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
)
override
;
};
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/rpc_server.cc
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
void
RPCServer
::
ShutDown
()
{
LOG
(
INFO
)
<<
"RPCServer ShutDown "
;
ShutDownImpl
();
exit_flag_
=
true
;
barrier_cond_
.
notify_all
();
rpc_cond_
.
notify_all
();
}
void
RPCServer
::
SavePort
()
const
{
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
std
::
ofstream
port_file
;
port_file
.
open
(
file_path
);
port_file
<<
selected_port_
;
port_file
.
close
();
VLOG
(
4
)
<<
"selected port written to "
<<
file_path
;
}
void
RPCServer
::
WaitBarrier
(
const
std
::
string
&
rpc_name
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
barrier_cond_
.
wait
(
lock
,
[
=
]
{
return
(
barrier_counter_
[
rpc_name
]
>=
client_num_
||
exit_flag_
.
load
());
});
VLOG
(
3
)
<<
"batch_barrier_:"
<<
barrier_counter_
[
rpc_name
];
}
void
RPCServer
::
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
)
{
VLOG
(
3
)
<<
"RPCServer begin IncreaseBatchBarrier "
<<
rpc_name
;
int
b
=
0
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
b
=
++
barrier_counter_
[
rpc_name
];
}
VLOG
(
3
)
<<
"RPCServer IncreaseBatchBarrier "
<<
rpc_name
<<
", barrier_count:"
<<
b
<<
", fan_in"
<<
client_num_
;
if
(
b
>=
client_num_
)
{
barrier_cond_
.
notify_all
();
}
}
void
RPCServer
::
ResetBarrierCounter
()
{
VLOG
(
3
)
<<
"RPCServer ResetBarrierCounter "
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
for
(
auto
&
t
:
barrier_counter_
)
{
t
.
second
=
0
;
}
}
void
RPCServer
::
RegisterRPC
(
const
std
::
string
&
rpc_name
,
RequestHandler
*
handler
,
int
thread_num
)
{
rpc_call_map_
[
rpc_name
]
=
handler
;
rpc_thread_num_
[
rpc_name
]
=
thread_num
;
static
int
cond
=
-
1
;
rpc_cond_map_
[
rpc_name
]
=
++
cond
;
VLOG
(
4
)
<<
"RegisterRPC rpc_name:"
<<
rpc_name
<<
", handler:"
<<
handler
<<
", cond:"
<<
rpc_cond_map_
[
rpc_name
];
}
void
RPCServer
::
SetCond
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
3
)
<<
"RPCServer SetCond "
<<
rpc_name
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cur_cond_
=
rpc_cond_map_
[
rpc_name
];
}
rpc_cond_
.
notify_all
();
}
void
RPCServer
::
WaitCond
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
3
)
<<
"RPCServer WaitCond "
<<
rpc_name
;
int
cond
=
0
;
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
cond
=
rpc_cond_map_
[
rpc_name
];
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
rpc_cond_
.
wait
(
lock
,
[
=
]
{
return
(
cur_cond_
.
load
()
==
cond
||
exit_flag_
.
load
());
});
}
}
// namespace detail
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detail/rpc_server.h
0 → 100644
浏览文件 @
4fb7cc7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/request_handler.h"
namespace
paddle
{
namespace
operators
{
namespace
detail
{
class
RPCServer
{
public:
explicit
RPCServer
(
const
std
::
string
&
address
,
int
client_num
)
:
cur_cond_
(
0
),
bind_address_
(
address
),
exit_flag_
(
false
),
selected_port_
(
0
),
client_num_
(
client_num
)
{}
virtual
~
RPCServer
()
{}
virtual
void
StartServer
()
=
0
;
virtual
void
WaitServerReady
()
=
0
;
void
ShutDown
();
bool
IsExit
()
{
return
exit_flag_
.
load
();
}
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
void
SavePort
()
const
;
// RegisterRPC, register the rpc method name to a handler
// class, and auto generate a condition id for this call
// to be used for the barrier.
void
RegisterRPC
(
const
std
::
string
&
rpc_name
,
RequestHandler
*
handler
,
int
thread_num
=
5
);
// Wait util all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
// synchronous mode.
void
WaitBarrier
(
const
std
::
string
&
rpc_name
);
void
SetCond
(
const
std
::
string
&
rpc_name
);
void
WaitCond
(
const
std
::
string
&
rpc_name
);
void
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
);
void
ResetBarrierCounter
();
protected:
virtual
void
ShutDownImpl
()
=
0
;
private:
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
string
,
int
>
barrier_counter_
;
std
::
condition_variable
barrier_cond_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_cond_map_
;
std
::
atomic
<
int
>
cur_cond_
;
std
::
condition_variable
rpc_cond_
;
protected:
std
::
string
bind_address_
;
std
::
atomic
<
int
>
exit_flag_
;
int
selected_port_
;
const
int
client_num_
;
std
::
unordered_map
<
std
::
string
,
RequestHandler
*>
rpc_call_map_
;
std
::
unordered_map
<
std
::
string
,
int
>
rpc_thread_num_
;
friend
class
RequestHandler
;
};
};
// namespace detail
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/detail/variable_response.h
浏览文件 @
4fb7cc7f
...
...
@@ -67,8 +67,8 @@ class VariableResponse {
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
inline
std
::
string
Varname
()
{
return
meta_
.
varname
();
}
inline
std
::
string
OutVarname
()
{
return
meta_
.
out_varname
();
}
inline
std
::
string
Varname
()
const
{
return
meta_
.
varname
();
}
inline
std
::
string
OutVarname
()
const
{
return
meta_
.
out_varname
();
}
// should call parse first.
framework
::
Variable
*
GetVar
()
{
...
...
paddle/fluid/operators/gen_nccl_id_op.cc
浏览文件 @
4fb7cc7f
...
...
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace
paddle
{
...
...
@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
detail
::
AsyncGRPCServer
rpc_service
(
endpoint
,
true
);
detail
::
RequestSendHandler
rpc_h
(
true
);
detail
::
AsyncGRPCServer
rpc_service
(
endpoint
,
1
);
rpc_service
.
RegisterRPC
(
detail
::
kRequestSend
,
&
rpc_h
);
rpc_h
.
SetRPCServer
(
&
rpc_service
);
framework
::
ProgramDesc
empty_program
;
framework
::
Executor
executor
(
dev_ctx
.
GetPlace
());
rpc_
service
.
SetScope
(
scope
);
rpc_
service
.
SetDevCtx
(
&
dev_ctx
);
rpc_
service
.
SetProgram
(
&
empty_program
);
rpc_
service
.
SetExecutor
(
&
executor
);
rpc_
h
.
SetScope
(
scope
);
rpc_
h
.
SetDevCtx
(
&
dev_ctx
);
rpc_
h
.
SetProgram
(
&
empty_program
);
rpc_
h
.
SetExecutor
(
&
executor
);
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
RunSyncUpdate
,
&
rpc_service
));
rpc_service
.
SetCond
(
0
);
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
&
rpc_service
));
rpc_service
.
SetCond
(
detail
::
kRequestSend
);
VLOG
(
3
)
<<
"start getting nccl id from trainer 0..."
;
auto
recv
=
rpc_service
.
Get
(
);
rpc_service
.
WaitBarrier
(
detail
::
kRequestSend
);
VLOG
(
3
)
<<
"got nccl id and stop server..."
;
rpc_service
.
ShutDown
();
VLOG
(
3
)
<<
"rpc server stopped"
;
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
4fb7cc7f
...
...
@@ -19,14 +19,16 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncG
RPCServer
>
service
)
{
service
->
RunSyncUpdate
();
void
RunServer
(
std
::
shared_ptr
<
detail
::
RPCServer
>
service
)
{
service
->
StartServer
();
VLOG
(
4
)
<<
"RunServer thread end"
;
}
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
...
...
@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
for
(
size_t
i
=
0
;
i
<
fs
.
size
();
++
i
)
fs
[
i
].
wait
();
}
std
::
atomic_int
ListenAndServOp
::
selected_port_
{
0
};
ListenAndServOp
::
ListenAndServOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
...
...
@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
ListenAndServOp
::~
ListenAndServOp
()
{
Stop
();
}
void
ListenAndServOp
::
Stop
()
{
rpc_service_
->
Push
(
LISTEN_TERMINATE_MESSAGE
);
rpc_service_
->
ShutDown
();
server_thread_
->
join
();
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
...
...
@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
void
ListenAndServOp
::
SavePort
()
const
{
// NOTE: default write file to /tmp/paddle.selected_port
selected_port_
=
rpc_service_
->
GetSelectedPort
();
auto
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.port"
,
::
getpid
());
std
::
ofstream
port_file
;
port_file
.
open
(
file_path
);
port_file
<<
selected_port_
.
load
();
port_file
.
close
();
VLOG
(
4
)
<<
"selected port written to "
<<
file_path
;
}
void
ListenAndServOp
::
WaitServerReady
()
{
while
(
selected_port_
.
load
()
==
0
)
{
}
rpc_service_
->
SavePort
();
}
void
ListenAndServOp
::
RunSyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
{
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
size_t
num_blocks
=
program
->
Size
();
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
"server program should have at least 2 blocks"
);
...
...
@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
optimize_prepared
.
begin
(),
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
(
nullptr
));
bool
exit_flag
=
false
;
rpc_service_
->
ResetBarrierCounter
()
;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std
::
vector
<
framework
::
Variable
*>
sparse_vars
;
while
(
!
exit_flag
&&
!
SignalHandler
::
IsProgramExit
()
)
{
while
(
true
)
{
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_
->
SetCond
(
0
);
size_t
recv_var_cnt
=
0
;
int
batch_barrier
=
0
;
while
(
batch_barrier
!=
fan_in
)
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
auto
recv_var_name
=
v
.
first
;
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
}
else
if
(
recv_var_name
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"recv batch barrier message"
;
batch_barrier
++
;
continue
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
recv_var_cnt
++
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
sparse_vars
.
push_back
(
var
);
}
}
}
if
(
exit_flag
)
{
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
ShutDown
();
rpc_service_
->
SetCond
(
detail
::
kRequestSend
);
rpc_service_
->
WaitBarrier
(
detail
::
kRequestSend
);
if
(
rpc_service_
->
IsExit
())
{
LOG
(
WARNING
)
<<
"get exit!rpc_processor break!"
;
rpc_service_
->
SetCond
(
detail
::
kRequestGet
);
break
;
}
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.
// The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future
int32_t
last_parent_blkid
=
program
->
Block
(
1
).
Parent
();
...
...
@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
rpc_service_
->
SetCond
(
1
);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_
->
WaitClientGet
(
fan_in
);
sparse_vars
.
clear
();
rpc_service_
->
SetCond
(
detail
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
detail
::
kRequestGet
);
rpc_service_
->
ResetBarrierCounter
();
}
// while(true)
}
static
void
AsyncUpdateThread
(
const
std
::
string
&
var_name
,
const
bool
&
exit_flag
,
const
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
&
queue
,
framework
::
Executor
*
executor
,
framework
::
ExecutorPrepareContext
*
prepared
)
{
VLOG
(
3
)
<<
"update thread for "
<<
var_name
<<
" started"
;
while
(
!
exit_flag
&&
!
SignalHandler
::
IsProgramExit
())
{
const
detail
::
ReceivedMessage
v
=
queue
->
Pop
();
if
(
SignalHandler
::
IsProgramExit
())
{
VLOG
(
3
)
<<
"update thread for "
<<
var_name
<<
" exit"
;
break
;
}
auto
recv_var_name
=
v
.
first
;
VLOG
(
4
)
<<
"async update "
<<
recv_var_name
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
auto
fs
=
framework
::
Async
([
var_name
,
&
executor
,
&
v
,
prepared
]
{
try
{
executor
->
RunPreparedContext
(
prepared
,
v
.
second
->
GetMutableLocalScope
());
}
catch
(
const
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
});
fs
.
wait
();
}
}
void
ListenAndServOp
::
RunAsyncLoop
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
)
const
{
VLOG
(
3
)
<<
"RunAsyncLoop in"
;
// grad name to block id
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_block_id
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
detail
::
ReceivedQueue
>>
grad_to_queue
;
auto
grad_to_block_id_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
);
...
...
@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
VLOG
(
3
)
<<
"after split, grad = "
<<
pieces
[
0
]
<<
", id="
<<
pieces
[
1
];
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
grad_to_block_id
.
count
(
pieces
[
0
]),
0
);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
grad_to_block_id
[
pieces
[
0
]]
=
block_id
;
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
queue
=
std
::
make_shared
<
detail
::
ReceivedQueue
>
();
grad_to_queue
[
pieces
[
0
]]
=
queue
;
// record blocking queue in SignalHandler
SignalHandler
::
RegisterBlockingQueue
(
queue
);
id_to_grad
[
block_id
]
=
pieces
[
0
];
}
size_t
num_blocks
=
program
->
Size
();
...
...
@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
}
bool
exit_flag
=
false
;
request_send_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_get_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
request_prefetch_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
VLOG
(
3
)
<<
"start async optimize threads"
;
std
::
vector
<
std
::
future
<
void
>>
fs
;
for
(
auto
iter
=
grad_to_queue
.
begin
();
iter
!=
grad_to_queue
.
end
();
iter
++
)
{
std
::
string
grad_name
=
iter
->
first
;
VLOG
(
3
)
<<
"create async update thread for "
<<
grad_name
;
fs
.
push_back
(
framework
::
AsyncIO
([
grad_name
,
&
exit_flag
,
&
executor
,
&
grad_to_queue
,
&
grad_to_prepared_ctx
]()
{
AsyncUpdateThread
(
grad_name
,
exit_flag
,
grad_to_queue
[
grad_name
],
executor
,
grad_to_prepared_ctx
[
grad_name
].
get
());
}));
}
VLOG
(
3
)
<<
"RunAsyncLoop into while"
;
while
(
!
exit_flag
&&
!
SignalHandler
::
IsProgramExit
())
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
auto
recv_var_name
=
v
.
first
;
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
while
(
true
)
{
if
(
rpc_service_
->
IsExit
())
{
LOG
(
INFO
)
<<
"get exit!rpc_processor break!"
;
break
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
grad_to_queue
[
recv_var_name
]
->
Push
(
v
);
}
if
(
exit_flag
)
{
rpc_service_
->
ShutDown
();
break
;
}
sleep
(
1
);
}
// while(true)
}
static
void
FillRequestCtx
(
detail
::
RequestHandler
*
h
,
framework
::
Scope
*
scope
,
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
,
detail
::
RPCServer
*
rpc_server
)
{
h
->
SetScope
(
scope
);
h
->
SetDevCtx
(
dev_ctx
);
h
->
SetExecutor
(
executor
);
h
->
SetProgram
(
program
);
h
->
SetPrefetchPreparedCtx
(
std
::
move
(
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
(
prefetch_ctx
)));
h
->
SetRPCServer
(
rpc_server
);
}
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
// Mark this as PS that it should decide profiling by listening from trainer.
...
...
@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
sync_mode
));
LOG
(
INFO
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
fan_in
));
request_send_handler_
.
reset
(
new
detail
::
RequestSendHandler
(
sync_mode
));
request_get_handler_
.
reset
(
new
detail
::
RequestGetHandler
(
sync_mode
));
request_prefetch_handler_
.
reset
(
new
detail
::
RequestPrefetchHandler
(
sync_mode
));
rpc_service_
->
RegisterRPC
(
detail
::
kRequestSend
,
request_send_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestGet
,
request_get_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestPrefetch
,
request_prefetch_handler_
.
get
());
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
auto
*
program
=
optimize_block
->
Program
();
framework
::
Executor
executor
(
dev_place
);
// prepare rpc_service
rpc_service_
->
SetScope
(
&
recv_scope
);
rpc_service_
->
SetDevCtx
(
&
dev_ctx
);
rpc_service_
->
SetProgram
(
program
);
rpc_service_
->
SetExecutor
(
&
executor
);
// prepare for prefetch
VLOG
(
3
)
<<
"prefetch block id is "
<<
prefetch_block
->
ID
();
auto
prefetch_prepared
=
executor
.
Prepare
(
*
program
,
prefetch_block
->
ID
());
rpc_service_
->
SetPrefetchPreparedCtx
(
std
::
move
(
prefetch_prepared
));
auto
f
=
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
&
executor
,
program
,
prefetch_prepared
.
release
(),
rpc_service_
.
get
());
f
(
request_send_handler_
.
get
());
f
(
request_get_handler_
.
get
());
f
(
request_prefetch_handler_
.
get
());
// start the server listening after all member initialized.
server_thread_
.
reset
(
new
std
::
thread
(
RunServer
,
rpc_service_
));
...
...
@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal
(
SIGTERM
,
SignalHandler
::
StopAndExit
);
// Write to a file of server selected port for python use.
std
::
string
file_path
=
string
::
Sprintf
(
"/tmp/paddle.%d.selected_port"
,
static_cast
<
int
>
(
::
getpid
()));
SavePort
();
if
(
sync_mode
)
{
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
...
...
@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
bool
SignalHandler
::
program_exit_flag_
=
false
;
SignalHandler
::
BlockingQueueSet
SignalHandler
::
blocking_queue_set_
{};
void
SignalHandler
::
StopAndExit
(
int
signal_num
)
{
VLOG
(
3
)
<<
"Catch interrupt signal: "
<<
signal_num
<<
", program will exit"
;
program_exit_flag_
=
true
;
// awake all blocking queues
for
(
BlockingQueueSet
::
iterator
iter
=
blocking_queue_set_
.
begin
();
iter
!=
blocking_queue_set_
.
end
();
iter
++
)
{
iter
->
get
()
->
Push
(
std
::
make_pair
(
std
::
string
(
LISTEN_TERMINATE_MESSAGE
),
nullptr
));
}
exit
(
EXIT_SUCCESS
);
}
void
SignalHandler
::
RegisterBlockingQueue
(
BlockingQueue
&
queue
)
{
blocking_queue_set_
.
insert
(
queue
);
exit
(
0
);
}
}
// namespace operators
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
4fb7cc7f
...
...
@@ -23,7 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -31,7 +32,7 @@ namespace operators {
constexpr
char
kOptimizeBlock
[]
=
"OptimizeBlock"
;
constexpr
char
kPrefetchBlock
[]
=
"PrefetchBlock"
;
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncG
RPCServer
>
service
);
void
RunServer
(
std
::
shared_ptr
<
detail
::
RPCServer
>
service
);
class
ListenAndServOp
:
public
framework
::
OperatorBase
{
public:
...
...
@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
void
SavePort
()
const
;
void
WaitServerReady
();
int
GetSelectedPort
()
{
return
selected_port_
;
}
int
GetSelectedPort
()
{
return
rpc_service_
->
GetSelectedPort
();
}
void
Stop
()
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
;
static
void
ResetPort
()
{
selected_port_
=
0
;
}
protected:
mutable
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
detail
::
RPCServer
>
rpc_service_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_send_handler_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_get_handler_
;
mutable
std
::
shared_ptr
<
detail
::
RequestHandler
>
request_prefetch_handler_
;
mutable
std
::
shared_ptr
<
std
::
thread
>
server_thread_
;
// FIXME(wuyi): it's static so that the operator can be cloned.
static
std
::
atomic_int
selected_port_
;
};
class
SignalHandler
{
public:
typedef
std
::
shared_ptr
<
detail
::
ReceivedQueue
>
BlockingQueue
;
typedef
std
::
unordered_set
<
BlockingQueue
>
BlockingQueueSet
;
public:
static
void
StopAndExit
(
int
signal_num
);
static
void
RegisterBlockingQueue
(
BlockingQueue
&
);
static
inline
bool
IsProgramExit
()
{
return
program_exit_flag_
;
}
private:
static
bool
program_exit_flag_
;
static
BlockingQueueSet
blocking_queue_set_
;
DISABLE_COPY_AND_ASSIGN
(
SignalHandler
);
};
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
4fb7cc7f
...
...
@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
// need to wait before sending send_barrier message
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
if
(
sync_mode
)
{
...
...
paddle/fluid/operators/test_send_nccl_id.cc
浏览文件 @
4fb7cc7f
...
...
@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
...
...
@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
namespace
detail
=
paddle
::
operators
::
detail
;
namespace
string
=
paddle
::
string
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
rpc_service
;
std
::
unique_ptr
<
detail
::
AsyncGRPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
detail
::
RequestHandler
>
g_req_handler
;
void
StartServer
(
std
::
atomic
<
bool
>*
initialized
)
{
void
StartServer
()
{
f
::
Scope
scope
;
p
::
CPUPlace
place
;
scope
.
Var
(
NCCL_ID_VARNAME
);
p
::
DeviceContextPool
&
pool
=
p
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
true
));
f
::
ProgramDesc
empty_program
;
f
::
Executor
executor
(
dev_ctx
.
GetPlace
());
rpc_service
->
SetScope
(
&
scope
);
rpc_service
->
SetDevCtx
(
&
dev_ctx
);
rpc_service
->
SetProgram
(
&
empty_program
);
rpc_service
->
SetExecutor
(
&
executor
);
g_req_handler
->
SetScope
(
&
scope
);
g_req_handler
->
SetDevCtx
(
&
dev_ctx
);
g_req_handler
->
SetProgram
(
&
empty_program
);
g_req_handler
->
SetExecutor
(
&
executor
);
g_rpc_service
->
RegisterRPC
(
detail
::
kRequestSend
,
g_req_handler
.
get
());
g_req_handler
->
SetRPCServer
(
g_rpc_service
.
get
());
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
RunSyncUpdate
,
rpc_service
.
get
()));
*
initialized
=
true
;
rpc_service
->
SetCond
(
0
);
auto
recv
=
rpc_service
->
Get
();
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
g_rpc_service
->
SetCond
(
detail
::
kRequestSend
);
std
::
cout
<<
"before WaitFanInOfSend"
<<
std
::
endl
;
g_rpc_service
->
WaitBarrier
(
detail
::
kRequestSend
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
rpc_service
->
ShutDown
();
g_
rpc_service
->
ShutDown
();
server_thread
.
join
();
}
TEST
(
SendNcclId
,
DISABLED_Normal
)
{
std
::
atomic
<
bool
>
initialized
{
false
};
std
::
thread
server_thread
(
StartServer
,
&
initialized
);
while
(
!
initialized
)
{
}
// wait server to start
// sleep(2);
rpc_service
->
WaitServerReady
();
TEST
(
SendNcclId
,
GrpcServer
)
{
g_req_handler
.
reset
(
new
detail
::
RequestSendHandler
(
true
));
g_rpc_service
.
reset
(
new
detail
::
AsyncGRPCServer
(
"127.0.0.1:0"
,
1
));
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
f
::
Scope
scope
;
p
::
CPUPlace
place
;
...
...
@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
auto
&
dev_ctx
=
*
pool
.
Get
(
p
::
CPUPlace
());
auto
var
=
scope
.
Var
(
NCCL_ID_VARNAME
);
// var->SetType(f::proto::VarType_Type_RAW);
auto
id
=
var
->
GetMutable
<
ncclUniqueId
>
();
p
::
dynload
::
ncclGetUniqueId
(
id
);
int
port
=
rpc_service
->
GetSelectedPort
();
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
detail
::
RPCClient
client
;
LOG
(
INFO
)
<<
"connect to server"
<<
ep
;
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
.
Wait
();
client
.
AsyncSendBatchBarrier
(
ep
);
client
.
Wait
();
server_thread
.
join
();
auto
*
ptr
=
rpc_service
.
release
(
);
delete
ptr
;
g_rpc_service
.
reset
(
nullptr
);
g_req_handler
.
reset
(
nullptr
)
;
}
paddle/fluid/platform/nccl_helper.h
浏览文件 @
4fb7cc7f
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <stdio.h>
#include <string>
#include <thread> // NOLINT
#include <typeindex>
#include <vector>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录