Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d93dc81c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d93dc81c
编写于
6月 19, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add handle when checkpoint_notify_id = -1
上级
1571c25a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
23 addition
and
12 deletion
+23
-12
paddle/fluid/operators/detail/request_handler_impl.cc
paddle/fluid/operators/detail/request_handler_impl.cc
+6
-2
paddle/fluid/operators/detail/request_handler_impl.h
paddle/fluid/operators/detail/request_handler_impl.h
+7
-2
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+10
-8
未找到文件。
paddle/fluid/operators/detail/request_handler_impl.cc
浏览文件 @
d93dc81c
...
@@ -125,11 +125,15 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
...
@@ -125,11 +125,15 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework
::
Variable
*
invar
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
framework
::
Variable
**
outvar
,
const
std
::
string
&
out_var_name
)
{
const
std
::
string
&
out_var_name
)
{
PADDLE_ENFORCE
(
checkpoint_notify_id
!=
-
1
,
"when checkpoint_notify_id = -1, there should be no RPC invoke."
);
auto
*
lt_var
=
scope
->
FindVar
(
"loopup_table_path"
)
->
GetMutable
<
std
::
string
>
();
auto
*
lt_var
=
scope
->
FindVar
(
"loopup_table_path"
)
->
GetMutable
<
std
::
string
>
();
lt_var
->
clear
();
lt_var
->
clear
();
lt_var
->
append
(
out_var_name
);
lt_var
->
append
(
out_var_name
);
VLOG
(
4
)
<<
"RequestCheckpointHandler update loopup_table_path to: "
<<
out_var_name
;
VLOG
(
4
)
<<
"RequestCheckpointHandler update loopup_table_path to: "
<<
out_var_name
;
executor_
->
RunPreparedContext
(
checkpoint_prepared_ctx_
.
get
(),
scope
);
executor_
->
RunPreparedContext
(
checkpoint_prepared_ctx_
.
get
(),
scope
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/operators/detail/request_handler_impl.h
浏览文件 @
d93dc81c
...
@@ -68,12 +68,17 @@ class RequestPrefetchHandler final : public RequestHandler {
...
@@ -68,12 +68,17 @@ class RequestPrefetchHandler final : public RequestHandler {
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestCheckpointHandler
(
bool
sync_mode
)
explicit
RequestCheckpointHandler
(
bool
sync_mode
,
int
checkpoint_notify_id
)
:
RequestHandler
(
sync_mode
)
{}
:
RequestHandler
(
sync_mode
)
{
this
.
checkpoint_notify_id
=
checkpoint_notify_id
;
}
virtual
~
RequestCheckpointHandler
()
{}
virtual
~
RequestCheckpointHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
std
::
string
&
out_var_name
=
""
)
override
;
const
std
::
string
&
out_var_name
=
""
)
override
;
private:
int
checkpoint_notify_id
;
};
};
}
// namespace detail
}
// namespace detail
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
d93dc81c
...
@@ -247,9 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -247,9 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE
(
!
rpc_service_
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
int
checkpoint_point_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
LOG
(
INFO
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
LOG
(
INFO
)
<<
"sync_mode:"
<<
sync_mode
<<
", fan_in:"
<<
fan_in
<<
", end_point:"
<<
endpoint
;
<<
", end_point:"
<<
endpoint
<<
", CheckpointNotify Id: "
<<
checkpoint_notify_id
;
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
...
@@ -258,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -258,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_
.
reset
(
request_prefetch_handler_
.
reset
(
new
detail
::
RequestPrefetchHandler
(
sync_mode
));
new
detail
::
RequestPrefetchHandler
(
sync_mode
));
request_checkpoint_handler_
.
reset
(
request_checkpoint_handler_
.
reset
(
new
detail
::
RequestCheckpointHandler
(
sync_mode
));
new
detail
::
RequestCheckpointHandler
(
sync_mode
,
checkpoint_notify_id
));
rpc_service_
->
RegisterRPC
(
detail
::
kRequestSend
,
request_send_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestSend
,
request_send_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestGet
,
request_get_handler_
.
get
());
rpc_service_
->
RegisterRPC
(
detail
::
kRequestGet
,
request_get_handler_
.
get
());
...
@@ -267,6 +269,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -267,6 +269,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_
->
RegisterRPC
(
detail
::
kRequestCheckpoint
,
rpc_service_
->
RegisterRPC
(
detail
::
kRequestCheckpoint
,
request_checkpoint_handler_
.
get
());
request_checkpoint_handler_
.
get
());
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
ckpt_pre_context
=
nullptr
;
if
(
checkpoint_notify_id
!=
-
1
)
{
auto
ctx
=
executor
.
Prepare
(
*
program
,
checkpoint_point_block_id
);
ckpt_pre_context
=
std
::
move
(
ctx
);
}
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
program
=
optimize_block
->
Program
();
auto
*
program
=
optimize_block
->
Program
();
framework
::
Executor
executor
(
dev_place
);
framework
::
Executor
executor
(
dev_place
);
...
@@ -301,12 +309,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -301,12 +309,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
prefetch_var_name_to_prepared_ctx
[
prefetch_var_name
]
=
prefetch_prepared
[
i
];
prefetch_var_name_to_prepared_ctx
[
prefetch_var_name
]
=
prefetch_prepared
[
i
];
}
}
int
checkpoint_point_block_id
=
Attr
<
int
>
(
kCheckpointBlockId
);
auto
ctx
=
executor
.
Prepare
(
*
program
,
checkpoint_point_block_id
);
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
ckpt_pre_context
=
std
::
move
(
ctx
);
auto
f
=
auto
f
=
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
&
executor
,
program
,
&
prefetch_var_name_to_prepared_ctx
,
&
executor
,
program
,
&
prefetch_var_name_to_prepared_ctx
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录