Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ae12281d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae12281d
编写于
6月 19, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
checkpoint notify
上级
30880844
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
41 addition
and
11 deletion
+41
-11
paddle/fluid/operators/checkpoint_notify_op.cc
paddle/fluid/operators/checkpoint_notify_op.cc
+7
-2
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+4
-1
paddle/fluid/operators/detail/request_handler_impl.cc
paddle/fluid/operators/detail/request_handler_impl.cc
+7
-0
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+10
-2
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+10
-5
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+3
-1
未找到文件。
paddle/fluid/operators/checkpoint_notify_op.cc
浏览文件 @
ae12281d
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -36,12 +37,14 @@ class CheckpointNotifyOp : public framework::OperatorBase {
...
@@ -36,12 +37,14 @@ class CheckpointNotifyOp : public framework::OperatorBase {
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
string
dir
=
Attr
<
std
::
string
>
(
"dir"
);
std
::
string
dir
=
Attr
<
std
::
string
>
(
"dir"
);
std
::
string
lookup_table_name
=
Attr
<
std
::
string
>
(
"lookup_table"
);
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
*
rpc_client
=
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
detail
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
for
(
size_t
i
=
0
;
i
<
epmap
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
epmap
.
size
();
i
++
)
{
VLOG
(
3
)
<<
"sending to "
<<
epmap
[
i
]
<<
" to checkpoint notify ... "
;
VLOG
(
3
)
<<
"sending "
<<
dir
<<
" to "
<<
epmap
[
i
]
<<
" to checkpoint notify ... "
;
rpc_client
->
AsyncCheckpointNotify
(
epmap
[
i
],
dir
);
auto
serial_looku_table
=
string
::
Sprintf
(
"%s/%s.%d"
,
dir
,
lookup_table_name
,
i
);
rpc_client
->
AsyncCheckpointNotify
(
epmap
[
i
],
serial_looku_table
);
}
}
rpc_client
->
Wait
();
rpc_client
->
Wait
();
}
}
...
@@ -57,6 +60,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -57,6 +60,8 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
({
"127.0.0.1:6164"
});
.
SetDefault
({
"127.0.0.1:6164"
});
AddAttr
<
std
::
string
>
(
AddAttr
<
std
::
string
>
(
"dir"
,
"(string, default '') indicate the folder checkpoint will use"
);
"dir"
,
"(string, default '') indicate the folder checkpoint will use"
);
AddAttr
<
std
::
string
>
(
"lookup_table"
,
"(string, default '') the lookup table name"
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Prefetch operator
Prefetch operator
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
ae12281d
...
@@ -208,11 +208,14 @@ class RequestCheckpointNotify final : public RequestBase {
...
@@ -208,11 +208,14 @@ class RequestCheckpointNotify final : public RequestBase {
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
scope
=
request_
->
GetMutableLocalScope
();
std
::
string
checkpoint_notify
=
request_
->
Varname
();
std
::
string
checkpoint_notify
=
request_
->
Varname
();
std
::
string
checkpoint_dir
=
request_
->
Varname
();
std
::
string
checkpoint_dir
=
request_
->
Out
Varname
();
framework
::
Variable
*
invar
=
nullptr
;
framework
::
Variable
*
invar
=
nullptr
;
framework
::
Variable
*
outvar
=
nullptr
;
framework
::
Variable
*
outvar
=
nullptr
;
VLOG
(
4
)
<<
"RequestCheckpointNotify notify: "
<<
checkpoint_notify
<<
", dir: "
<<
checkpoint_dir
;
request_handler_
->
Handle
(
checkpoint_notify
,
scope
,
invar
,
&
outvar
,
request_handler_
->
Handle
(
checkpoint_notify
,
scope
,
invar
,
&
outvar
,
checkpoint_dir
);
checkpoint_dir
);
Finish
(
reply_
,
&
responder_
);
Finish
(
reply_
,
&
responder_
);
...
...
paddle/fluid/operators/detail/request_handler_impl.cc
浏览文件 @
ae12281d
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -124,6 +125,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
...
@@ -124,6 +125,12 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework
::
Variable
*
invar
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
framework
::
Variable
**
outvar
,
const
std
::
string
&
out_var_name
)
{
const
std
::
string
&
out_var_name
)
{
auto
lt_varname
=
string
::
Sprintf
(
"%s.path"
,
varname
);
auto
*
lt_var
=
scope
->
FindVar
(
lt_varname
)
->
GetMutable
<
std
::
string
>
();
lt_var
->
clear
();
lt_var
->
append
(
out_var_name
);
VLOG
(
4
)
<<
"RequestCheckpointHandler update "
<<
lt_varname
<<
" to: "
<<
out_var_name
;
executor_
->
RunPreparedContext
(
checkpoint_prepared_ctx_
.
get
(),
scope
);
executor_
->
RunPreparedContext
(
checkpoint_prepared_ctx_
.
get
(),
scope
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
ae12281d
...
@@ -87,7 +87,7 @@ class SaveOp : public framework::OperatorBase {
...
@@ -87,7 +87,7 @@ class SaveOp : public framework::OperatorBase {
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
SaveLodTensor
(
filename
,
place
,
var
);
SaveLodTensor
(
filename
,
place
,
var
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
SaveSelectedRows
(
filenam
e
,
place
,
var
);
SaveSelectedRows
(
scop
e
,
place
,
var
);
}
else
{
}
else
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
false
,
false
,
...
@@ -128,9 +128,17 @@ class SaveOp : public framework::OperatorBase {
...
@@ -128,9 +128,17 @@ class SaveOp : public framework::OperatorBase {
fout
.
close
();
fout
.
close
();
}
}
void
SaveSelectedRows
(
const
std
::
string
&
filenam
e
,
void
SaveSelectedRows
(
const
framework
::
Scope
&
scop
e
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
framework
::
Variable
*
var
)
const
{
framework
::
Variable
*
var
)
const
{
auto
lt_varname
=
string
::
Sprintf
(
"%s.path"
,
Input
(
"X"
));
auto
*
lt_var
=
scope
.
FindVar
(
lt_varname
)
->
GetMutable
<
std
::
string
>
();
PADDLE_ENFORCE
(
lt_var
!=
nullptr
,
"Cannot find variable %s for SaveSelectedRows"
,
lt_varname
);
std
::
string
filename
=
lt_var
->
data
();
VLOG
(
4
)
<<
"SaveSelectedRows get File name: "
<<
filename
;
auto
&
selectedRows
=
var
->
Get
<
framework
::
SelectedRows
>
();
auto
&
selectedRows
=
var
->
Get
<
framework
::
SelectedRows
>
();
// get device context from pool
// get device context from pool
...
...
python/paddle/fluid/io.py
浏览文件 @
ae12281d
...
@@ -471,7 +471,10 @@ def save_checkpoint(executor,
...
@@ -471,7 +471,10 @@ def save_checkpoint(executor,
trainer_id
,
trainer_id
,
trainer_args
=
None
,
trainer_args
=
None
,
main_program
=
None
,
main_program
=
None
,
max_num_checkpoints
=
3
):
max_num_checkpoints
=
3
,
lookup_table
=
None
,
ps_endpoint_list
=
None
):
"""
"""
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
...
@@ -500,7 +503,7 @@ def save_checkpoint(executor,
...
@@ -500,7 +503,7 @@ def save_checkpoint(executor,
if
trainer_id
==
0
:
if
trainer_id
==
0
:
save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
save_pserver_vars_by_notify
(
executor
,
cur_dir
,
""
)
save_pserver_vars_by_notify
(
executor
,
cur_dir
,
ps_endpoint_list
,
lookup_table
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
...
@@ -600,7 +603,7 @@ def save_persist_vars_without_grad(executor, dirname, program):
...
@@ -600,7 +603,7 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success
(
cur_dir
)
_write_success
(
cur_dir
)
def
save_pserver_vars_by_notify
(
executor
,
dirname
,
epmap
):
def
save_pserver_vars_by_notify
(
executor
,
dirname
,
lookup_table
,
ps_endpoint_list
):
"""
"""
"""
"""
cur_dir
=
_get_lookuptable_dir
(
dirname
)
cur_dir
=
_get_lookuptable_dir
(
dirname
)
...
@@ -609,11 +612,12 @@ def save_pserver_vars_by_notify(executor, dirname, epmap):
...
@@ -609,11 +612,12 @@ def save_pserver_vars_by_notify(executor, dirname, epmap):
checkpoint_notify_block
=
checkpoint_notify_program
.
global_block
()
checkpoint_notify_block
=
checkpoint_notify_program
.
global_block
()
attrs
=
{}
attrs
=
{}
attrs
[
'epmap'
]
=
None
attrs
[
'epmap'
]
=
ps_endpoint_list
attrs
[
'dir'
]
=
cur_dir
attrs
[
'dir'
]
=
cur_dir
attrs
[
'lookup_table'
]
=
lookup_table
checkpoint_notify_block
.
append_op
(
checkpoint_notify_block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
output
=
{},
attrs
=
attrs
)
type
=
'checkpoint_notify'
,
inputs
=
{},
output
s
=
{},
attrs
=
attrs
)
executor
.
run
(
checkpoint_notify_program
)
executor
.
run
(
checkpoint_notify_program
)
...
@@ -783,3 +787,4 @@ def get_latest_checkpoint_serial(checkpoint_dir):
...
@@ -783,3 +787,4 @@ def get_latest_checkpoint_serial(checkpoint_dir):
if
success_num
>
current_dir
:
if
success_num
>
current_dir
:
current_dir
=
success_num
current_dir
=
success_num
return
current_dir
return
current_dir
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
ae12281d
...
@@ -838,13 +838,15 @@ class DistributeTranspiler:
...
@@ -838,13 +838,15 @@ class DistributeTranspiler:
"""
"""
import
os
import
os
pserver_program
.
global_block
().
create_var
(
name
=
"%s.path"
%
self
.
table_name
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
checkpoint_save_block
=
pserver_program
.
create_block
(
pre_block_idx
)
checkpoint_save_block
=
pserver_program
.
create_block
(
pre_block_idx
)
checkpoint_save_block
.
append_op
(
checkpoint_save_block
.
append_op
(
type
=
'save'
,
type
=
'save'
,
inputs
=
{
'X'
:
[
self
.
table_name
]},
inputs
=
{
'X'
:
[
self
.
table_name
]},
outputs
=
{},
outputs
=
{},
attrs
=
{
attrs
=
{
'file_path'
:
os
.
path
.
join
(
"/tmp/pserver_ckpt/"
,
self
.
table_name
)
'file_path'
:
self
.
table_name
)
})
})
return
checkpoint_save_block
.
idx
return
checkpoint_save_block
.
idx
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录