BaiXuePrincess / Paddle (fork of PaddlePaddle / Paddle)
Commit 05bd9db8
Authored June 20, 2018 by tangwei12

add comments in io.py

Parent: c073bb3b
Showing 1 changed file with 92 additions and 2 deletions

python/paddle/fluid/io.py (+92 -2)
@@ -840,6 +840,12 @@ def save_checkpoint(executor,
         max_num_checkpoints(int): The max number of total number of existing
             checkpoints.
             Default: 3
+        lookup_table(string|None): the lookup table name. When using a
+            distributed lookup table, the name can be obtained from
+            DistributeTranspiler.table_name.
+        ps_endpoint_list(list|None): the parameter server ip:port list.
+            When using a distributed lookup table, this list can be obtained
+            from the distribute arguments.
     Returns:
         None
@@ -856,15 +862,21 @@ def save_checkpoint(executor,
             prog = fluid.default_main_program()
             trainer_args = {"epoch_id": 200,
                             "step_id": 20} # just an example
+            table_name = "share_w"
+            ps_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]
             fluid.io.save_checkpoint(executor=exe,
                                      checkpoint_dir=path,
                                      trainer_id=0,
                                      trainer_args=trainer_args,
                                      main_program=prog,
-                                     max_num_checkpoints=3)
+                                     max_num_checkpoints=3,
+                                     lookup_table=table_name,
+                                     ps_endpoint_list=ps_endpoints)
     """
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")
+    assert checkpoint_dir
     if trainer_args:
         assert isinstance(trainer_args, dict)
@@ -881,6 +893,7 @@ def save_checkpoint(executor,
     if is_chief:
         save_persist_vars_without_grad(executor, cur_dir, main_program)
+    if is_chief and lookup_table and ps_endpoint_list:
+        save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
+                                    ps_endpoint_list)
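
The two new save_checkpoint arguments are meant to come straight out of the distribution setup described in the docstring. A minimal sketch, assuming a DistributeTranspiler-based configuration; the endpoints, trainer counts, and paths below are placeholders, not values from this commit:

.. code-block:: python

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())

    # Placeholder cluster settings; real jobs take these from launch arguments.
    pserver_endpoints = "127.0.0.1:6000,127.0.0.1:6001"

    t = fluid.DistributeTranspiler()
    t.transpile(trainer_id=0, pservers=pserver_endpoints, trainers=2)

    fluid.io.save_checkpoint(
        executor=exe,
        checkpoint_dir="./checkpoints",
        trainer_id=0,
        trainer_args={"epoch_id": 200, "step_id": 20},
        main_program=fluid.default_main_program(),
        max_num_checkpoints=3,
        # per the docstring, the transpiler exposes the table name
        lookup_table=t.table_name,
        ps_endpoint_list=pserver_endpoints.split(","))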
@@ -1020,6 +1033,31 @@ def load_persist_vars_without_grad(executor,
 def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
+    """
+    The parameter server will load the lookup table's local file into
+    a SelectedRows variable.
+
+    Args:
+        executor(Executor): The executor to run for loading persistable variables
+        dirname(str): The directory path
+        program(Program): Find the variable named table_name in this program
+        pserver_id(int): the serial number in the pserver_endpoints list
+        table_name(str): lookup table name
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            dirname = "./checkpoints/checkpoint_9/__model__"
+            prog = fluid.default_main_program()
+            pserver_id = 1
+            table_name = "share_w"
+            fluid.io.load_lookup_table_vars(executor=exe,
+                dirname=dirname, program=prog, pserver_id=pserver_id,
+                table_name=table_name)
+    """
     for var in program.list_vars():
         if var.name == table_name:
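
The docstring defines pserver_id as the server's position in the endpoint list. A hedged sketch of deriving it; the environment variable is hypothetical and stands in for whatever the launcher actually passes to each server process:

.. code-block:: python

    import os

    ps_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]

    # Hypothetical: the launcher tells each pserver which endpoint it serves.
    current_endpoint = os.getenv("CURRENT_ENDPOINT", "127.0.0.1:6001")

    # "the serial number in the pserver_endpoints list" is simply the index.
    pserver_id = ps_endpoints.index(current_endpoint)  # 1 for the second server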
@@ -1092,6 +1130,35 @@ def save_persist_vars_without_grad(executor, dirname, program):
 def save_pserver_vars_by_notify(executor, dirname, lookup_table,
                                 ps_endpoint_list):
     """
+    This function sends a checkpoint notify message from Trainer 0
+    to all the pservers.
+    The checkpoint notify message contains the lookup table name and
+    the absolute path on the pserver where the lookup_table is saved.
+
+    Args:
+        executor(Executor): The executor to run for sending the checkpoint notify.
+        dirname(str): The folder where to save checkpoints.
+        lookup_table(string): the lookup table name. When using a distributed
+            lookup table, the name can be obtained from
+            DistributeTranspiler.table_name.
+        ps_endpoint_list(list): the parameter server ip:port list.
+            When using a distributed lookup table, this list can be obtained
+            from the distribute arguments.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            param_path = "./my_paddle_model"
+            prog = fluid.default_main_program()
+            table_name = "share_w"
+            ps_endpoints = ["127.0.0.1:6000", "127.0.0.1:6001"]
+            fluid.io.save_pserver_vars_by_notify(executor=exe,
+                dirname=param_path, lookup_table=table_name,
+                ps_endpoint_list=ps_endpoints)
     """
     cur_dir = _get_lookuptable_dir(dirname)
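
Per the docstring, the notify message carries exactly two pieces of information: the lookup table name and the absolute pserver-side save path. The sketch below is illustrative only; the real message travels through the executor's checkpoint-notify mechanism, not a Python dict:

.. code-block:: python

    import os

    def build_notify_payload(dirname, lookup_table):
        # Illustrative stand-in for the message contents described above.
        return {"lookup_table_name": lookup_table,
                "save_path": os.path.abspath(dirname)}

    payload = build_notify_payload("./my_paddle_model", "share_w")
    # {'lookup_table_name': 'share_w', 'save_path': '/.../my_paddle_model'}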
@@ -1121,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
 def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
+    """
+    The trainer will load some args from its independent directory,
+    such as epoch_id and step_id.
+
+    Args:
+        checkpoint_dir(str): The folder where all checkpoints are.
+        serial(int): The serial of the checkpoint you would like to load.
+        trainer_id(int): current trainer id.
+        trainer_args(list): list of the trainer args to load
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            param_path = "./checkpoint/"
+            serial = 7
+            trainer_id = 2
+            trainer_args = ["epoch_id", "step_id"]
+            fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
+                trainer_id=trainer_id, trainer_args=trainer_args)
+    """
     assert isinstance(trainer_args, list)
     cur_dir = _get_serial_dir(checkpoint_dir, serial)
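
The serial-to-directory mapping is delegated to _get_serial_dir, whose body is not part of this diff. A hypothetical mirror, inferred only from the "./checkpoints/checkpoint_9/__model__" path in the load_lookup_table_vars example above; the real helper may differ:

.. code-block:: python

    import os

    def serial_dir(checkpoint_dir, serial):
        # ("./checkpoint/", 7) -> "./checkpoint/checkpoint_7"
        return os.path.join(checkpoint_dir, "checkpoint_%d" % serial)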
@@ -1141,7 +1231,7 @@ def _is_checkpoint_var(var):
     the checkpoint will not save or load all the variables.
     var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
-    : param var
+    : param var(Variable)
     """
     if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
             var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
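
The filter this docstring describes is easy to state on its own. A self-contained sketch that uses plain strings in place of the core.VarDesc.VarType enum values so it runs standalone:

.. code-block:: python

    SKIPPED_TYPES = {"FEED_MINIBATCH", "FETCH_LIST", "RAW"}

    def is_checkpoint_var(var_type, var_name):
        # Feed/fetch/raw variables and gradients (names ending in
        # "@GRAD") are never saved into or loaded from a checkpoint.
        if var_type in SKIPPED_TYPES:
            return False
        return not var_name.endswith("@GRAD")

    assert not is_checkpoint_var("FEED_MINIBATCH", "feed")
    assert not is_checkpoint_var("LOD_TENSOR", "fc_0.w_0@GRAD")
    assert is_checkpoint_var("LOD_TENSOR", "fc_0.w_0")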