Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
5ab3daf7
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
8
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5ab3daf7
编写于
3月 26, 2020
作者:
L
liuyibing01
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'io' into 'master'
change interface for io.py See merge request !49
上级
23095bf9
c845fbd5
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
63 addition
and
45 deletion
+63
-45
parakeet/utils/io.py
parakeet/utils/io.py
+63
-45
未找到文件。
parakeet/utils/io.py
浏览文件 @
5ab3daf7
...
@@ -20,6 +20,11 @@ import numpy as np
...
@@ -20,6 +20,11 @@ import numpy as np
import
paddle.fluid.dygraph
as
dg
import
paddle.fluid.dygraph
as
dg
def
is_main_process
():
local_rank
=
dg
.
parallel
.
Env
().
local_rank
return
local_rank
==
0
def
add_yaml_config_to_args
(
config
):
def
add_yaml_config_to_args
(
config
):
""" Add args in yaml config to the args parsed by argparse. The argument in
""" Add args in yaml config to the args parsed by argparse. The argument in
yaml config will be overwritten by the same argument in argparse if they
yaml config will be overwritten by the same argument in argparse if they
...
@@ -41,7 +46,7 @@ def add_yaml_config_to_args(config):
...
@@ -41,7 +46,7 @@ def add_yaml_config_to_args(config):
return
config
return
config
def
load_latest_checkpoint
(
checkpoint_dir
,
rank
=
0
):
def
_load_latest_checkpoint
(
checkpoint_dir
):
"""Get the iteration number corresponding to the latest saved checkpoint
"""Get the iteration number corresponding to the latest saved checkpoint
Args:
Args:
...
@@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
...
@@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
Returns:
Returns:
int: the latest iteration number.
int: the latest iteration number.
"""
"""
checkpoint_
path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
checkpoint_
record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Create checkpoint index file if not exist.
# Create checkpoint index file if not exist.
if
(
not
os
.
path
.
isfile
(
checkpoint_path
))
and
rank
==
0
:
if
(
not
os
.
path
.
isfile
(
checkpoint_record
)):
with
open
(
checkpoint_path
,
"w"
)
as
handle
:
return
0
handle
.
write
(
"model_checkpoint_path: step-0"
)
# Make sure that other process waits until checkpoint file is created
# by process 0.
while
not
os
.
path
.
isfile
(
checkpoint_path
):
time
.
sleep
(
1
)
# Fetch the latest checkpoint index.
# Fetch the latest checkpoint index.
with
open
(
checkpoint_
path
,
"r"
)
as
handle
:
with
open
(
checkpoint_
record
,
"r"
)
as
handle
:
latest_checkpoint
=
handle
.
readline
().
split
()[
-
1
]
latest_checkpoint
=
handle
.
readline
().
split
()[
-
1
]
iteration
=
int
(
latest_checkpoint
.
split
(
"-"
)[
-
1
])
iteration
=
int
(
latest_checkpoint
.
split
(
"-"
)[
-
1
])
return
iteration
return
iteration
def
save_latest
_checkpoint
(
checkpoint_dir
,
iteration
):
def
_save
_checkpoint
(
checkpoint_dir
,
iteration
):
"""Save the iteration number of the latest model to be checkpointed.
"""Save the iteration number of the latest model to be checkpointed.
Args:
Args:
...
@@ -81,60 +80,76 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
...
@@ -81,60 +80,76 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
Returns:
Returns:
None
None
"""
"""
checkpoint_
path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
checkpoint_
record
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Update the latest checkpoint index.
# Update the latest checkpoint index.
with
open
(
checkpoint_
path
,
"w"
)
as
handle
:
with
open
(
checkpoint_
record
,
"w"
)
as
handle
:
handle
.
write
(
"model_checkpoint_path: step-{}"
.
format
(
iteration
))
handle
.
write
(
"model_checkpoint_path: step-{}"
.
format
(
iteration
))
def
load_parameters
(
checkpoint_dir
,
def
load_parameters
(
model
,
rank
,
model
,
optimizer
=
None
,
optimizer
=
None
,
checkpoint_dir
=
None
,
iteration
=
None
,
iteration
=
None
,
file
_path
=
None
,
checkpoint
_path
=
None
,
dtype
=
"float32"
):
dtype
=
"float32"
):
"""Load a specific model checkpoint from disk.
"""Load a specific model checkpoint from disk.
Args:
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters.
model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed.
optimizer (obj, optional): optimizer to load states if needed.
Defaults to None.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
iteration (int, optional): if specified, load the specific checkpoint,
iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None.
if not specified, load the latest one. Defaults to None.
file
_path (str, optional): if specified, load the checkpoint
checkpoint
_path (str, optional): if specified, load the checkpoint
stored in the
file_path. Defaults to None.
stored in the
checkpoint_path. Defaults to None.
dtype (str, optional): precision of the model parameters.
dtype (str, optional): precision of the model parameters.
Defaults to float32.
Defaults to float32.
Returns:
Returns:
None
iteration (int): number of iterations that the loaded checkpoint has
been trained.
"""
"""
if
file_path
is
None
:
if
checkpoint_dir
is
not
None
and
checkpoint_path
is
not
None
:
raise
ValueError
(
"Load from either from (checkpoint_dir and iteration)
\n
"
"or checkpoint_path. Do not pass both."
)
if
iteration
is
not
None
and
checkpoint_dir
is
None
:
raise
ValueError
(
"When iteration is specified, checkpoint_dir should not be None"
)
if
checkpoint_dir
is
not
None
:
if
iteration
is
None
:
if
iteration
is
None
:
iteration
=
load_latest_checkpoint
(
checkpoint_dir
,
rank
)
iteration
=
_load_latest_checkpoint
(
checkpoint_dir
)
if
iteration
==
0
:
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
return
"step-{}"
.
format
(
iteration
))
file_path
=
"{}/step-{}"
.
format
(
checkpoint_dir
,
iteration
)
if
iteration
==
0
and
not
os
.
path
.
exists
(
checkpoint_path
):
# if step-0 exist, it is also loaded
model_dict
,
optimizer_dict
=
dg
.
load_dygraph
(
file_path
)
return
iteration
if
dtype
==
"float16"
:
else
:
for
k
,
v
in
model_dict
.
items
():
# checkpoint is not None
if
"conv2d_transpose"
in
k
:
iteration
=
int
(
os
.
path
.
basename
(
checkpoint_path
).
split
(
"-"
)[
-
1
])
model_dict
[
k
]
=
v
.
astype
(
"float32"
)
else
:
local_rank
=
dg
.
parallel
.
Env
().
local_rank
model_dict
[
k
]
=
v
.
astype
(
dtype
)
model_dict
,
optimizer_dict
=
dg
.
load_dygraph
(
checkpoint_path
)
# cast to desired data type
for
k
,
v
in
model_dict
.
items
():
model_dict
[
k
]
=
v
.
astype
(
dtype
)
model
.
set_dict
(
model_dict
)
model
.
set_dict
(
model_dict
)
print
(
"[checkpoint] Rank {}: loaded model from {}"
.
format
(
rank
,
file_path
))
print
(
"[checkpoint] Rank {}: loaded model from {}.pdparams"
.
format
(
local_rank
,
checkpoint_path
))
if
optimizer
and
optimizer_dict
:
if
optimizer
and
optimizer_dict
:
optimizer
.
set_dict
(
optimizer_dict
)
optimizer
.
set_dict
(
optimizer_dict
)
print
(
"[checkpoint] Rank {}: loaded optimizer state from {}
"
.
format
(
print
(
"[checkpoint] Rank {}: loaded optimizer state from {}
.pdopt"
.
rank
,
file
_path
))
format
(
local_rank
,
checkpoint
_path
))
return
iteration
def
save_latest_parameters
(
checkpoint_dir
,
iteration
,
model
,
optimizer
=
None
):
def
save_parameters
(
checkpoint_dir
,
iteration
,
model
,
optimizer
=
None
):
"""Checkpoint the latest trained model parameters.
"""Checkpoint the latest trained model parameters.
Args:
Args:
...
@@ -147,12 +162,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
...
@@ -147,12 +162,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
Returns:
Returns:
None
None
"""
"""
file_path
=
"{}/step-{}"
.
format
(
checkpoint_dir
,
iteration
)
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"step-{}"
.
format
(
iteration
)
)
model_dict
=
model
.
state_dict
()
model_dict
=
model
.
state_dict
()
dg
.
save_dygraph
(
model_dict
,
file
_path
)
dg
.
save_dygraph
(
model_dict
,
checkpoint
_path
)
print
(
"[checkpoint] Saved model to {}
"
.
format
(
file
_path
))
print
(
"[checkpoint] Saved model to {}
.pdparams"
.
format
(
checkpoint
_path
))
if
optimizer
:
if
optimizer
:
opt_dict
=
optimizer
.
state_dict
()
opt_dict
=
optimizer
.
state_dict
()
dg
.
save_dygraph
(
opt_dict
,
file_path
)
dg
.
save_dygraph
(
opt_dict
,
checkpoint_path
)
print
(
"[checkpoint] Saved optimzier state to {}"
.
format
(
file_path
))
print
(
"[checkpoint] Saved optimzier state to {}.pdopt"
.
format
(
checkpoint_path
))
_save_checkpoint
(
checkpoint_dir
,
iteration
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录