Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d13ce358
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d13ce358
编写于
3月 14, 2018
作者:
武
武毅
提交者:
gongweibao
3月 14, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Feature/send recv can now retry (#9027)
上级
14fe40aa
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
83 addition
and
25 deletion
+83
-25
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+15
-3
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+26
-10
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+15
-6
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+1
-1
paddle/fluid/operators/detail/sendrecvop_utils.h
paddle/fluid/operators/detail/sendrecvop_utils.h
+1
-0
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+2
-2
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+6
-0
python/paddle/fluid/distribute_transpiler.py
python/paddle/fluid/distribute_transpiler.py
+17
-3
未找到文件。
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
d13ce358
...
...
@@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return
true
;
}
bool
RPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
RPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
...
...
@@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
req_count_
++
;
}
return
true
;
void
RPCClient
::
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
FETCH_BARRIER_MESSAGE
);
auto
rpc
=
s
->
stub_
->
AsyncGetVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
req_count_
++
;
}
bool
RPCClient
::
Wait
()
{
...
...
@@ -154,7 +164,7 @@ bool RPCClient::Proceed() {
PADDLE_ENFORCE
(
tag
);
// TODO(gongwb): add more retries.
ClientBase
*
c
=
static_cast
<
ClientBase
*>
(
tag
);
BaseProcessor
*
c
=
static_cast
<
BaseProcessor
*>
(
tag
);
if
(
!
c
->
status_
.
ok
())
{
LOG
(
ERROR
)
<<
"proc param error:"
<<
c
->
var_h_
.
String
()
<<
" grpc error:"
<<
c
->
status_
.
error_message
();
...
...
@@ -174,6 +184,8 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
}
grpc
::
ChannelArguments
args
;
args
.
SetInt
(
"grpc.testing.fixed_reconnect_backoff_ms"
,
5000
);
args
.
SetCompressionAlgorithm
(
GRPC_COMPRESS_NONE
);
args
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
args
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
d13ce358
...
...
@@ -52,14 +52,14 @@ struct VarHandle {
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
sendrecv
::
VariableMessage
&
msg
);
class
ClientBase
{
class
BaseProcessor
{
public:
explicit
ClientBase
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
{
explicit
BaseProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
{
stub_
=
sendrecv
::
SendRecvService
::
NewStub
(
ch
);
context_
=
NULL
;
}
virtual
~
ClientBase
()
{}
virtual
~
BaseProcessor
()
{}
virtual
void
Prepare
(
const
VarHandle
&
var_info
,
int64_t
time_out
)
{
context_
.
reset
(
new
grpc
::
ClientContext
());
...
...
@@ -91,9 +91,10 @@ class ClientBase {
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
sendrecv
::
VoidMessage
&
)
>
RequestSendCallBack
;
class
SendProcessor
:
public
ClientBase
{
class
SendProcessor
:
public
BaseProcessor
{
public:
explicit
SendProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
ClientBase
(
ch
)
{}
explicit
SendProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
(
ch
)
{}
virtual
~
SendProcessor
()
{}
...
...
@@ -110,9 +111,10 @@ class SendProcessor : public ClientBase {
typedef
std
::
function
<
void
(
const
VarHandle
&
,
const
sendrecv
::
VariableMessage
&
)
>
RequestGetCallBack
;
class
GetProcessor
:
public
ClientBase
{
class
GetProcessor
:
public
BaseProcessor
{
public:
explicit
GetProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
ClientBase
(
ch
)
{}
explicit
GetProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
(
ch
)
{}
virtual
~
GetProcessor
()
{}
...
...
@@ -126,10 +128,10 @@ class GetProcessor : public ClientBase {
RequestGetCallBack
response_call_back_
=
ProcGetResponse
;
};
class
BatchBarrierProcessor
:
public
ClientBase
{
class
BatchBarrierProcessor
:
public
BaseProcessor
{
public:
explicit
BatchBarrierProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
ClientBase
(
ch
)
{}
:
BaseProcessor
(
ch
)
{}
virtual
~
BatchBarrierProcessor
()
{}
...
...
@@ -137,6 +139,17 @@ class BatchBarrierProcessor : public ClientBase {
sendrecv
::
VoidMessage
reply_
;
};
class
FetchBarrierProcessor
:
public
BaseProcessor
{
public:
explicit
FetchBarrierProcessor
(
std
::
shared_ptr
<
grpc
::
Channel
>
ch
)
:
BaseProcessor
(
ch
)
{}
virtual
~
FetchBarrierProcessor
()
{}
virtual
void
Process
()
{}
sendrecv
::
VariableMessage
reply_
;
};
class
RPCClient
{
public:
bool
AsyncSendVariable
(
const
std
::
string
&
ep
,
...
...
@@ -151,7 +164,10 @@ class RPCClient {
const
std
::
string
&
var_name
,
int64_t
time_out
=
600
*
1000
);
bool
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
void
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
600
*
1000
);
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
600
*
1000
);
bool
Wait
();
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
d13ce358
...
...
@@ -84,7 +84,7 @@ class RequestGet final : public RequestBase {
explicit
RequestGet
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
SimpleBlockQueue
<
char
>*
queue
)
SimpleBlockQueue
<
MessageWithName
>*
queue
)
:
RequestBase
(
service
,
cq
),
responder_
(
&
ctx_
),
scope_
(
scope
),
...
...
@@ -101,11 +101,16 @@ class RequestGet final : public RequestBase {
// proc request.
std
::
string
var_name
=
request_
.
varname
();
auto
*
var
=
scope_
->
FindVar
(
var_name
);
SerializeToMessage
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
if
(
var_name
!=
FETCH_BARRIER_MESSAGE
)
{
SerializeToMessage
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
}
// TODO(gongwb): check var's info.
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
queue_
->
Push
(
'c'
);
MessageWithName
msg_with_name
=
// request name reply
std
::
make_pair
(
var_name
,
std
::
move
(
reply_
));
queue_
->
Push
(
msg_with_name
);
}
protected:
...
...
@@ -114,12 +119,16 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter
<
sendrecv
::
VariableMessage
>
responder_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
SimpleBlockQueue
<
char
>*
queue_
;
SimpleBlockQueue
<
MessageWithName
>*
queue_
;
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
var_get_queue_
.
Pop
();
int
fetch_barriers
=
0
;
while
(
fetch_barriers
<
count
)
{
auto
msg
=
var_get_queue_
.
Pop
();
if
(
msg
.
first
==
FETCH_BARRIER_MESSAGE
)
{
fetch_barriers
++
;
}
}
}
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
d13ce358
...
...
@@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const
platform
::
DeviceContext
*
dev_ctx_
;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
MessageWithName
>
var_recv_queue_
;
SimpleBlockQueue
<
char
>
var_get_queue_
;
SimpleBlockQueue
<
MessageWithName
>
var_get_queue_
;
// condition of the sub program
std
::
mutex
barrier_mutex_
;
...
...
paddle/fluid/operators/detail/sendrecvop_utils.h
浏览文件 @
d13ce358
...
...
@@ -32,6 +32,7 @@ namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
typedef
void
(
*
DestroyCallback
)(
void
*
);
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
d13ce358
...
...
@@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase {
}
}
if
(
exit_flag
)
{
rpc_service_
->
ShutDown
();
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
ShutDown
();
break
;
}
try
{
...
...
@@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase {
}
rpc_service_
->
SetCond
(
1
);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_
->
WaitClientGet
(
ins
.
size
()
);
rpc_service_
->
WaitClientGet
(
fan_in
);
sparse_vars
.
clear
();
}
// while(true)
}
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
d13ce358
...
...
@@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase {
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
// tell pservers that current trainer have called fetch
for
(
auto
&
ep
:
endpoints
)
{
VLOG
(
3
)
<<
"send fetch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendFetchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
());
}
}
};
...
...
python/paddle/fluid/distribute_transpiler.py
浏览文件 @
d13ce358
...
...
@@ -250,6 +250,8 @@ class DistributeTranspiler:
def
get_trainer_program
(
self
):
# remove optimize ops and add a send op to main_program
self
.
program
.
global_block
().
delete_ops
(
self
.
optimize_ops
)
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self
.
program
.
__str__
()
return
self
.
program
def
get_pserver_program
(
self
,
endpoint
):
...
...
@@ -309,7 +311,8 @@ class DistributeTranspiler:
for
_
,
opt_op
in
enumerate
(
opt_op_on_pserver
):
if
ufind
.
is_connected
(
op
,
opt_op
):
if
self
.
_is_opt_op
(
op
):
self
.
_append_pserver_ops
(
optimize_block
,
op
,
endpoint
)
self
.
_append_pserver_ops
(
optimize_block
,
op
,
endpoint
,
default_main_program
())
else
:
self
.
_append_pserver_non_opt_ops
(
optimize_block
,
op
)
break
...
...
@@ -520,7 +523,8 @@ class DistributeTranspiler:
orig_var_name
=
varname
[:
suff_idx
]
return
orig_var_name
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
):
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
,
origin_program
):
program
=
optimize_block
.
program
pserver_block
=
program
.
global_block
()
new_inputs
=
dict
()
...
...
@@ -576,7 +580,17 @@ class DistributeTranspiler:
elif
key
==
"LearningRate"
:
# leraning rate variable has already be created by non-optimize op,
# don't create it once again.
new_inputs
[
key
]
=
pserver_block
.
vars
[
opt_op
.
input
(
key
)[
0
]]
lr_varname
=
opt_op
.
input
(
key
)[
0
]
if
pserver_block
.
vars
.
has_key
(
lr_varname
):
new_inputs
[
key
]
=
pserver_block
.
vars
[
opt_op
.
input
(
key
)[
0
]]
else
:
origin_var
=
origin_program
.
global_block
().
vars
[
lr_varname
]
tmpvar
=
pserver_block
.
create_var
(
name
=
origin_var
.
name
,
persistable
=
origin_var
.
persistable
,
dtype
=
origin_var
.
dtype
,
shape
=
origin_var
.
shape
)
new_inputs
[
key
]
=
tmpvar
for
key
in
opt_op
.
input_names
:
new_shape
=
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录