Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0deb6f90
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0deb6f90
编写于
5月 30, 2018
作者:
T
tangwei12
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
annotation optimized and code style optimized
上级
0211c5df
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
27 addition
and
7 deletion
+27
-7
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+21
-1
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+6
-6
未找到文件。
python/paddle/fluid/io.py
浏览文件 @
0deb6f90
...
@@ -478,9 +478,10 @@ def save_checkpoint(executor,
...
@@ -478,9 +478,10 @@ def save_checkpoint(executor,
:param executor
:param executor
:param checkpoint_dir
:param checkpoint_dir
:param trainer_id
:param is_chief
:param main_program
:param main_program
:param max_num_checkpoints
:param max_num_checkpoints
:param is_chief
"""
"""
if
checkpoint_dir
is
None
:
if
checkpoint_dir
is
None
:
raise
ValueError
(
"The values of 'checkpoint_dir' should not be None"
)
raise
ValueError
(
"The values of 'checkpoint_dir' should not be None"
)
...
@@ -502,6 +503,11 @@ def save_checkpoint(executor,
...
@@ -502,6 +503,11 @@ def save_checkpoint(executor,
def
need_load_checkpoint
(
checkpoint_dir
):
def
need_load_checkpoint
(
checkpoint_dir
):
"""
If the directory have checkpoint files, it will return lastest checkpoint directory serial number
:param checkpoint_dir
"""
serial
=
_get_lastest_checkpoint_dir
(
checkpoint_dir
)
serial
=
_get_lastest_checkpoint_dir
(
checkpoint_dir
)
if
serial
<
0
:
if
serial
<
0
:
return
None
return
None
...
@@ -515,6 +521,7 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
...
@@ -515,6 +521,7 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
:param executor
:param executor
:param checkpoint_dir
:param checkpoint_dir
:param serial
:param main_program
:param main_program
"""
"""
...
@@ -536,7 +543,11 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
...
@@ -536,7 +543,11 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
"""
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
:param checkpoint_dir
:param delete_dir
"""
"""
if
checkpoint_dir
is
None
:
if
checkpoint_dir
is
None
:
raise
ValueError
(
"The values of 'checkpoint_dir' should not be None"
)
raise
ValueError
(
"The values of 'checkpoint_dir' should not be None"
)
_lru_delete
(
checkpoint_dir
,
max_num_checkpoints
=
0
)
_lru_delete
(
checkpoint_dir
,
max_num_checkpoints
=
0
)
...
@@ -549,6 +560,11 @@ def load_persist_vars_without_grad(executor, dirname, program, nest=True):
...
@@ -549,6 +560,11 @@ def load_persist_vars_without_grad(executor, dirname, program, nest=True):
"""
"""
load_persist_vars_without_grad will load variables from a directory by an executor,
load_persist_vars_without_grad will load variables from a directory by an executor,
the variable named end with "@GRAD" will not be loaded.
the variable named end with "@GRAD" will not be loaded.
:param executor
:param dirname
:param program
:param nest
"""
"""
if
nest
:
if
nest
:
...
@@ -566,6 +582,10 @@ def save_persist_vars_without_grad(executor, dirname, program):
...
@@ -566,6 +582,10 @@ def save_persist_vars_without_grad(executor, dirname, program):
"""
"""
save_persist_vars_without_grad will save variables to a directory by an executor,
save_persist_vars_without_grad will save variables to a directory by an executor,
the variable named end with "@GRAD" will not be saved.
the variable named end with "@GRAD" will not be saved.
:param executor
:param dirname
:param program
"""
"""
cur_dir
=
_get_model_dir
(
dirname
)
cur_dir
=
_get_model_dir
(
dirname
)
save_vars
(
save_vars
(
...
...
python/paddle/fluid/trainer.py
浏览文件 @
0deb6f90
...
@@ -79,8 +79,8 @@ class CheckpointConfig(object):
...
@@ -79,8 +79,8 @@ class CheckpointConfig(object):
else
:
else
:
self
.
step_interval
=
step_interval
self
.
step_interval
=
step_interval
self
.
_
epoch_id
=
0
self
.
epoch_id
=
0
self
.
_
step_id
=
0
self
.
step_id
=
0
self
.
_load_serial
=
None
self
.
_load_serial
=
None
...
@@ -185,8 +185,8 @@ class Trainer(object):
...
@@ -185,8 +185,8 @@ class Trainer(object):
epoch_id
,
step_id
=
io
.
load_trainer_args
(
epoch_id
,
step_id
=
io
.
load_trainer_args
(
self
.
checkpoint
.
checkpoint_dir
,
self
.
checkpoint
.
_load_serial
,
self
.
checkpoint
.
checkpoint_dir
,
self
.
checkpoint
.
_load_serial
,
self
.
trainer_id
,
[
"epoch_id"
,
"step_id"
])
self
.
trainer_id
,
[
"epoch_id"
,
"step_id"
])
self
.
checkpoint
.
_
epoch_id
=
int
(
epoch_id
)
self
.
checkpoint
.
epoch_id
=
int
(
epoch_id
)
self
.
checkpoint
.
_
step_id
=
int
(
step_id
)
self
.
checkpoint
.
step_id
=
int
(
step_id
)
if
param_path
and
os
.
path
.
isdir
(
param_path
):
if
param_path
and
os
.
path
.
isdir
(
param_path
):
# load params from param_path into scope
# load params from param_path into scope
...
@@ -353,7 +353,7 @@ class Trainer(object):
...
@@ -353,7 +353,7 @@ class Trainer(object):
def
_train_by_any_executor
(
self
,
event_handler
,
exe
,
num_epochs
,
reader
):
def
_train_by_any_executor
(
self
,
event_handler
,
exe
,
num_epochs
,
reader
):
epochs
=
[
epochs
=
[
epoch_id
for
epoch_id
in
range
(
num_epochs
)
epoch_id
for
epoch_id
in
range
(
num_epochs
)
if
epoch_id
>=
self
.
checkpoint
.
_
epoch_id
if
epoch_id
>=
self
.
checkpoint
.
epoch_id
]
]
for
epoch_id
in
epochs
:
for
epoch_id
in
epochs
:
event_handler
(
BeginEpochEvent
(
epoch_id
))
event_handler
(
BeginEpochEvent
(
epoch_id
))
...
@@ -363,7 +363,7 @@ class Trainer(object):
...
@@ -363,7 +363,7 @@ class Trainer(object):
return
return
if
self
.
checkpoint
and
self
.
checkpoint
.
_load_serial
\
if
self
.
checkpoint
and
self
.
checkpoint
.
_load_serial
\
and
self
.
checkpoint
.
_step_id
>=
step_id
and
self
.
checkpoint
.
_
epoch_id
==
epoch_id
:
and
self
.
checkpoint
.
step_id
>=
step_id
and
self
.
checkpoint
.
epoch_id
==
epoch_id
:
continue
continue
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录