Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
c845fbd5
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
8
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c845fbd5
编写于
3月 26, 2020
作者:
C
chenfeiyu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change interface for io.py
上级
64790853
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
63 addition
and
45 deletion
+63
-45
parakeet/utils/io.py
parakeet/utils/io.py
+63
-45
未找到文件。
parakeet/utils/io.py
浏览文件 @
c845fbd5
...
...
@@ -20,6 +20,11 @@ import numpy as np
import
paddle.fluid.dygraph
as
dg
def
is_main_process
():
local_rank
=
dg
.
parallel
.
Env
().
local_rank
return
local_rank
==
0
def
add_yaml_config_to_args
(
config
):
""" Add args in yaml config to the args parsed by argparse. The argument in
yaml config will be overwritten by the same argument in argparse if they
...
...
@@ -41,7 +46,7 @@ def add_yaml_config_to_args(config):
return
config
def
load_latest_checkpoint
(
checkpoint_dir
,
rank
=
0
):
def
_load_latest_checkpoint
(
checkpoint_dir
):
"""Get the iteration number corresponding to the latest saved checkpoint
Args:
...
...
@@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
Returns:
int: the latest iteration number.
"""
checkpoint_
path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
checkpoint_
record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Create checkpoint index file if not exist.
if
(
not
os
.
path
.
isfile
(
checkpoint_path
))
and
rank
==
0
:
with
open
(
checkpoint_path
,
"w"
)
as
handle
:
handle
.
write
(
"model_checkpoint_path: step-0"
)
# Make sure that other process waits until checkpoint file is created
# by process 0.
while
not
os
.
path
.
isfile
(
checkpoint_path
):
time
.
sleep
(
1
)
if
(
not
os
.
path
.
isfile
(
checkpoint_record
)):
return
0
# Fetch the latest checkpoint index.
with
open
(
checkpoint_
path
,
"r"
)
as
handle
:
with
open
(
checkpoint_
record
,
"r"
)
as
handle
:
latest_checkpoint
=
handle
.
readline
().
split
()[
-
1
]
iteration
=
int
(
latest_checkpoint
.
split
(
"-"
)[
-
1
])
return
iteration
def
save_latest
_checkpoint
(
checkpoint_dir
,
iteration
):
def
_save
_checkpoint
(
checkpoint_dir
,
iteration
):
"""Save the iteration number of the latest model to be checkpointed.
Args:
...
...
@@ -81,60 +80,76 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
Returns:
None
"""
checkpoint_
path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
checkpoint_
record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Update the latest checkpoint index.
with
open
(
checkpoint_
path
,
"w"
)
as
handle
:
with
open
(
checkpoint_
record
,
"w"
)
as
handle
:
handle
.
write
(
"model_checkpoint_path: step-{}"
.
format
(
iteration
))
def
load_parameters
(
checkpoint_dir
,
rank
,
model
,
def
load_parameters
(
model
,
optimizer
=
None
,
checkpoint_dir
=
None
,
iteration
=
None
,
file
_path
=
None
,
checkpoint
_path
=
None
,
dtype
=
"float32"
):
"""Load a specific model checkpoint from disk.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None.
file
_path (str, optional): if specified, load the checkpoint
stored in the
file_path. Defaults to None.
checkpoint
_path (str, optional): if specified, load the checkpoint
stored in the
checkpoint_path. Defaults to None.
dtype (str, optional): precision of the model parameters.
Defaults to float32.
Returns:
None
iteration (int): number of iterations that the loaded checkpoint has
been trained.
"""
if
file_path
is
None
:
if
checkpoint_dir
is
not
None
and
checkpoint_path
is
not
None
:
raise
ValueError
(
"Load from either from (checkpoint_dir and iteration)
\n
"
"or checkpoint_path. Do not pass both."
)
if
iteration
is
not
None
and
checkpoint_dir
is
None
:
raise
ValueError
(
"When iteration is specified, checkpoint_dir should not be None"
)
if
checkpoint_dir
is
not
None
:
if
iteration
is
None
:
iteration
=
load_latest_checkpoint
(
checkpoint_dir
,
rank
)
if
iteration
==
0
:
return
file_path
=
"{}/step-{}"
.
format
(
checkpoint_dir
,
iteration
)
model_dict
,
optimizer_dict
=
dg
.
load_dygraph
(
file_path
)
if
dtype
==
"float16"
:
for
k
,
v
in
model_dict
.
items
():
if
"conv2d_transpose"
in
k
:
model_dict
[
k
]
=
v
.
astype
(
"float32"
)
else
:
model_dict
[
k
]
=
v
.
astype
(
dtype
)
iteration
=
_load_latest_checkpoint
(
checkpoint_dir
)
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"step-{}"
.
format
(
iteration
))
if
iteration
==
0
and
not
os
.
path
.
exists
(
checkpoint_path
):
# if step-0 exist, it is also loaded
return
iteration
else
:
# checkpoint is not None
iteration
=
int
(
os
.
path
.
basename
(
checkpoint_path
).
split
(
"-"
)[
-
1
])
local_rank
=
dg
.
parallel
.
Env
().
local_rank
model_dict
,
optimizer_dict
=
dg
.
load_dygraph
(
checkpoint_path
)
# cast to desired data type
for
k
,
v
in
model_dict
.
items
():
model_dict
[
k
]
=
v
.
astype
(
dtype
)
model
.
set_dict
(
model_dict
)
print
(
"[checkpoint] Rank {}: loaded model from {}"
.
format
(
rank
,
file_path
))
print
(
"[checkpoint] Rank {}: loaded model from {}.pdparams"
.
format
(
local_rank
,
checkpoint_path
))
if
optimizer
and
optimizer_dict
:
optimizer
.
set_dict
(
optimizer_dict
)
print
(
"[checkpoint] Rank {}: loaded optimizer state from {}
"
.
format
(
rank
,
file
_path
))
print
(
"[checkpoint] Rank {}: loaded optimizer state from {}
.pdopt"
.
format
(
local_rank
,
checkpoint
_path
))
return
iteration
def
save_latest_parameters
(
checkpoint_dir
,
iteration
,
model
,
optimizer
=
None
):
def
save_parameters
(
checkpoint_dir
,
iteration
,
model
,
optimizer
=
None
):
"""Checkpoint the latest trained model parameters.
Args:
...
...
@@ -147,12 +162,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
Returns:
None
"""
file_path
=
"{}/step-{}"
.
format
(
checkpoint_dir
,
iteration
)
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"step-{}"
.
format
(
iteration
)
)
model_dict
=
model
.
state_dict
()
dg
.
save_dygraph
(
model_dict
,
file
_path
)
print
(
"[checkpoint] Saved model to {}
"
.
format
(
file
_path
))
dg
.
save_dygraph
(
model_dict
,
checkpoint
_path
)
print
(
"[checkpoint] Saved model to {}
.pdparams"
.
format
(
checkpoint
_path
))
if
optimizer
:
opt_dict
=
optimizer
.
state_dict
()
dg
.
save_dygraph
(
opt_dict
,
file_path
)
print
(
"[checkpoint] Saved optimzier state to {}"
.
format
(
file_path
))
dg
.
save_dygraph
(
opt_dict
,
checkpoint_path
)
print
(
"[checkpoint] Saved optimzier state to {}.pdopt"
.
format
(
checkpoint_path
))
_save_checkpoint
(
checkpoint_dir
,
iteration
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录