Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
64790853
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
14
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
64790853
编写于
3月 22, 2020
作者:
L
liuyibing01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Unify save & load interfaces
上级
be70b41f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
173 addition
and
145 deletion
+173
-145
examples/waveflow/benchmark.py
examples/waveflow/benchmark.py
+2
-1
examples/waveflow/synthesis.py
examples/waveflow/synthesis.py
+3
-2
examples/waveflow/train.py
examples/waveflow/train.py
+3
-2
examples/waveflow/utils.py
examples/waveflow/utils.py
+0
-135
parakeet/models/waveflow/waveflow.py
parakeet/models/waveflow/waveflow.py
+7
-5
parakeet/utils/io.py
parakeet/utils/io.py
+158
-0
未找到文件。
examples/waveflow/benchmark.py
浏览文件 @
64790853
...
@@ -22,6 +22,7 @@ import paddle.fluid.dygraph as dg
...
@@ -22,6 +22,7 @@ import paddle.fluid.dygraph as dg
from
paddle
import
fluid
from
paddle
import
fluid
import
utils
import
utils
from
parakeet.utils
import
io
from
parakeet.models.waveflow
import
WaveFlow
from
parakeet.models.waveflow
import
WaveFlow
...
@@ -98,5 +99,5 @@ if __name__ == "__main__":
...
@@ -98,5 +99,5 @@ if __name__ == "__main__":
# For conflicting updates to the same field,
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
# the preceding update will be overwritten by the following one.
config
=
parser
.
parse_args
()
config
=
parser
.
parse_args
()
config
=
utils
.
add_yaml_config
(
config
)
config
=
io
.
add_yaml_config_to_args
(
config
)
benchmark
(
config
)
benchmark
(
config
)
examples/waveflow/synthesis.py
浏览文件 @
64790853
...
@@ -23,6 +23,7 @@ from paddle import fluid
...
@@ -23,6 +23,7 @@ from paddle import fluid
import
utils
import
utils
from
parakeet.models.waveflow
import
WaveFlow
from
parakeet.models.waveflow
import
WaveFlow
from
parakeet.utils
import
io
def
add_options_to_parser
(
parser
):
def
add_options_to_parser
(
parser
):
...
@@ -96,7 +97,7 @@ def synthesize(config):
...
@@ -96,7 +97,7 @@ def synthesize(config):
# Obtain the current iteration.
# Obtain the current iteration.
if
config
.
checkpoint
is
None
:
if
config
.
checkpoint
is
None
:
if
config
.
iteration
is
None
:
if
config
.
iteration
is
None
:
iteration
=
utils
.
load_latest_checkpoint
(
checkpoint_dir
)
iteration
=
io
.
load_latest_checkpoint
(
checkpoint_dir
)
else
:
else
:
iteration
=
config
.
iteration
iteration
=
config
.
iteration
else
:
else
:
...
@@ -117,5 +118,5 @@ if __name__ == "__main__":
...
@@ -117,5 +118,5 @@ if __name__ == "__main__":
# For conflicting updates to the same field,
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
# the preceding update will be overwritten by the following one.
config
=
parser
.
parse_args
()
config
=
parser
.
parse_args
()
config
=
utils
.
add_yaml_config
(
config
)
config
=
io
.
add_yaml_config_to_args
(
config
)
synthesize
(
config
)
synthesize
(
config
)
examples/waveflow/train.py
浏览文件 @
64790853
...
@@ -25,6 +25,7 @@ from paddle import fluid
...
@@ -25,6 +25,7 @@ from paddle import fluid
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
import
utils
import
utils
from
parakeet.utils
import
io
from
parakeet.models.waveflow
import
WaveFlow
from
parakeet.models.waveflow
import
WaveFlow
...
@@ -104,7 +105,7 @@ def train(config):
...
@@ -104,7 +105,7 @@ def train(config):
# Obtain the current iteration.
# Obtain the current iteration.
if
config
.
checkpoint
is
None
:
if
config
.
checkpoint
is
None
:
if
config
.
iteration
is
None
:
if
config
.
iteration
is
None
:
iteration
=
utils
.
load_latest_checkpoint
(
checkpoint_dir
,
rank
)
iteration
=
io
.
load_latest_checkpoint
(
checkpoint_dir
,
rank
)
else
:
else
:
iteration
=
config
.
iteration
iteration
=
config
.
iteration
else
:
else
:
...
@@ -140,7 +141,7 @@ if __name__ == "__main__":
...
@@ -140,7 +141,7 @@ if __name__ == "__main__":
# For conflicting updates to the same field,
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
# the preceding update will be overwritten by the following one.
config
=
parser
.
parse_args
()
config
=
parser
.
parse_args
()
config
=
utils
.
add_yaml_config
(
config
)
config
=
io
.
add_yaml_config_to_args
(
config
)
# Force to use fp32 in model training
# Force to use fp32 in model training
vars
(
config
)[
"use_fp16"
]
=
False
vars
(
config
)[
"use_fp16"
]
=
False
train
(
config
)
train
(
config
)
examples/waveflow/utils.py
浏览文件 @
64790853
...
@@ -12,14 +12,7 @@
...
@@ -12,14 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
itertools
import
os
import
time
import
argparse
import
argparse
import
ruamel.yaml
import
numpy
as
np
import
paddle.fluid.dygraph
as
dg
def
str2bool
(
v
):
def
str2bool
(
v
):
...
@@ -95,131 +88,3 @@ def add_config_options_to_parser(parser):
...
@@ -95,131 +88,3 @@ def add_config_options_to_parser(parser):
'--kernel_w'
,
type
=
int
,
help
=
"width of the kernel in the conv2d layer"
)
'--kernel_w'
,
type
=
int
,
help
=
"width of the kernel in the conv2d layer"
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
"Path to the config file."
)
parser
.
add_argument
(
'--config'
,
type
=
str
,
help
=
"Path to the config file."
)
def
add_yaml_config
(
config
):
with
open
(
config
.
config
,
'rt'
)
as
f
:
yaml_cfg
=
ruamel
.
yaml
.
safe_load
(
f
)
cfg_vars
=
vars
(
config
)
for
k
,
v
in
yaml_cfg
.
items
():
if
k
in
cfg_vars
and
cfg_vars
[
k
]
is
not
None
:
continue
cfg_vars
[
k
]
=
v
return
config
def
load_latest_checkpoint
(
checkpoint_dir
,
rank
=
0
):
"""Get the iteration number corresponding to the latest saved checkpoint
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int, optional): the rank of the process in multi-process setting.
Defaults to 0.
Returns:
int: the latest iteration number.
"""
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Create checkpoint index file if not exist.
if
(
not
os
.
path
.
isfile
(
checkpoint_path
))
and
rank
==
0
:
with
open
(
checkpoint_path
,
"w"
)
as
handle
:
handle
.
write
(
"model_checkpoint_path: step-0"
)
# Make sure that other process waits until checkpoint file is created
# by process 0.
while
not
os
.
path
.
isfile
(
checkpoint_path
):
time
.
sleep
(
1
)
# Fetch the latest checkpoint index.
with
open
(
checkpoint_path
,
"r"
)
as
handle
:
latest_checkpoint
=
handle
.
readline
().
split
()[
-
1
]
iteration
=
int
(
latest_checkpoint
.
split
(
"-"
)[
-
1
])
return
iteration
def
save_latest_checkpoint
(
checkpoint_dir
,
iteration
):
"""Save the iteration number of the latest model to be checkpointed.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Update the latest checkpoint index.
with
open
(
checkpoint_path
,
"w"
)
as
handle
:
handle
.
write
(
"model_checkpoint_path: step-{}"
.
format
(
iteration
))
def
load_parameters
(
checkpoint_dir
,
rank
,
model
,
optimizer
=
None
,
iteration
=
None
,
file_path
=
None
,
dtype
=
"float32"
):
"""Load a specific model checkpoint from disk.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed.
Defaults to None.
iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None.
file_path (str, optional): if specified, load the checkpoint
stored in the file_path. Defaults to None.
dtype (str, optional): precision of the model parameters.
Defaults to float32.
Returns:
None
"""
if
file_path
is
None
:
if
iteration
is
None
:
iteration
=
load_latest_checkpoint
(
checkpoint_dir
,
rank
)
if
iteration
==
0
:
return
file_path
=
"{}/step-{}"
.
format
(
checkpoint_dir
,
iteration
)
model_dict
,
optimizer_dict
=
dg
.
load_dygraph
(
file_path
)
if
dtype
==
"float16"
:
for
k
,
v
in
model_dict
.
items
():
if
"conv2d_transpose"
in
k
:
model_dict
[
k
]
=
v
.
astype
(
"float32"
)
else
:
model_dict
[
k
]
=
v
.
astype
(
dtype
)
model
.
set_dict
(
model_dict
)
print
(
"[checkpoint] Rank {}: loaded model from {}"
.
format
(
rank
,
file_path
))
if
optimizer
and
optimizer_dict
:
optimizer
.
set_dict
(
optimizer_dict
)
print
(
"[checkpoint] Rank {}: loaded optimizer state from {}"
.
format
(
rank
,
file_path
))
def
save_latest_parameters
(
checkpoint_dir
,
iteration
,
model
,
optimizer
=
None
):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
model (obj): model to be checkpointed.
optimizer (obj, optional): optimizer to be checkpointed.
Defaults to None.
Returns:
None
"""
file_path
=
"{}/step-{}"
.
format
(
checkpoint_dir
,
iteration
)
model_dict
=
model
.
state_dict
()
dg
.
save_dygraph
(
model_dict
,
file_path
)
print
(
"[checkpoint] Saved model to {}"
.
format
(
file_path
))
if
optimizer
:
opt_dict
=
optimizer
.
state_dict
()
dg
.
save_dygraph
(
opt_dict
,
file_path
)
print
(
"[checkpoint] Saved optimzier state to {}"
.
format
(
file_path
))
parakeet/models/waveflow/waveflow.py
浏览文件 @
64790853
...
@@ -22,6 +22,7 @@ from paddle import fluid
...
@@ -22,6 +22,7 @@ from paddle import fluid
from
scipy.io.wavfile
import
write
from
scipy.io.wavfile
import
write
import
utils
import
utils
from
parakeet.utils
import
io
from
parakeet.modules
import
weight_norm
from
parakeet.modules
import
weight_norm
from
.data
import
LJSpeech
from
.data
import
LJSpeech
from
.waveflow_modules
import
WaveFlowLoss
,
WaveFlowModule
from
.waveflow_modules
import
WaveFlowLoss
,
WaveFlowModule
...
@@ -47,6 +48,7 @@ class WaveFlow():
...
@@ -47,6 +48,7 @@ class WaveFlow():
Returns:
Returns:
WaveFlow
WaveFlow
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
config
,
config
,
checkpoint_dir
,
checkpoint_dir
,
...
@@ -91,7 +93,7 @@ class WaveFlow():
...
@@ -91,7 +93,7 @@ class WaveFlow():
parameter_list
=
waveflow
.
parameters
())
parameter_list
=
waveflow
.
parameters
())
# Load parameters.
# Load parameters.
utils
.
load_parameters
(
io
.
load_parameters
(
self
.
checkpoint_dir
,
self
.
checkpoint_dir
,
self
.
rank
,
self
.
rank
,
waveflow
,
waveflow
,
...
@@ -111,7 +113,7 @@ class WaveFlow():
...
@@ -111,7 +113,7 @@ class WaveFlow():
else
:
else
:
# Load parameters.
# Load parameters.
utils
.
load_parameters
(
io
.
load_parameters
(
self
.
checkpoint_dir
,
self
.
checkpoint_dir
,
self
.
rank
,
self
.
rank
,
waveflow
,
waveflow
,
...
@@ -291,6 +293,6 @@ class WaveFlow():
...
@@ -291,6 +293,6 @@ class WaveFlow():
Returns:
Returns:
None
None
"""
"""
utils
.
save_latest_parameters
(
self
.
checkpoint_dir
,
iteration
,
io
.
save_latest_parameters
(
self
.
checkpoint_dir
,
iteration
,
self
.
waveflow
,
self
.
optimizer
)
self
.
waveflow
,
self
.
optimizer
)
utils
.
save_latest_checkpoint
(
self
.
checkpoint_dir
,
iteration
)
io
.
save_latest_checkpoint
(
self
.
checkpoint_dir
,
iteration
)
parakeet/utils/io.py
0 → 100644
浏览文件 @
64790853
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
time
import
ruamel.yaml
import
numpy
as
np
import
paddle.fluid.dygraph
as
dg
def
add_yaml_config_to_args
(
config
):
""" Add args in yaml config to the args parsed by argparse. The argument in
yaml config will be overwritten by the same argument in argparse if they
are both valid.
Args:
config (args): the args returned by `argparse.ArgumentParser().parse_args()`
Returns:
config: the args added yaml config.
"""
with
open
(
config
.
config
,
'rt'
)
as
f
:
yaml_cfg
=
ruamel
.
yaml
.
safe_load
(
f
)
cfg_vars
=
vars
(
config
)
for
k
,
v
in
yaml_cfg
.
items
():
if
k
in
cfg_vars
and
cfg_vars
[
k
]
is
not
None
:
continue
cfg_vars
[
k
]
=
v
return
config
def
load_latest_checkpoint
(
checkpoint_dir
,
rank
=
0
):
"""Get the iteration number corresponding to the latest saved checkpoint
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int, optional): the rank of the process in multi-process setting.
Defaults to 0.
Returns:
int: the latest iteration number.
"""
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Create checkpoint index file if not exist.
if
(
not
os
.
path
.
isfile
(
checkpoint_path
))
and
rank
==
0
:
with
open
(
checkpoint_path
,
"w"
)
as
handle
:
handle
.
write
(
"model_checkpoint_path: step-0"
)
# Make sure that other process waits until checkpoint file is created
# by process 0.
while
not
os
.
path
.
isfile
(
checkpoint_path
):
time
.
sleep
(
1
)
# Fetch the latest checkpoint index.
with
open
(
checkpoint_path
,
"r"
)
as
handle
:
latest_checkpoint
=
handle
.
readline
().
split
()[
-
1
]
iteration
=
int
(
latest_checkpoint
.
split
(
"-"
)[
-
1
])
return
iteration
def
save_latest_checkpoint
(
checkpoint_dir
,
iteration
):
"""Save the iteration number of the latest model to be checkpointed.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Update the latest checkpoint index.
with
open
(
checkpoint_path
,
"w"
)
as
handle
:
handle
.
write
(
"model_checkpoint_path: step-{}"
.
format
(
iteration
))
def
load_parameters
(
checkpoint_dir
,
rank
,
model
,
optimizer
=
None
,
iteration
=
None
,
file_path
=
None
,
dtype
=
"float32"
):
"""Load a specific model checkpoint from disk.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed.
Defaults to None.
iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None.
file_path (str, optional): if specified, load the checkpoint
stored in the file_path. Defaults to None.
dtype (str, optional): precision of the model parameters.
Defaults to float32.
Returns:
None
"""
if
file_path
is
None
:
if
iteration
is
None
:
iteration
=
load_latest_checkpoint
(
checkpoint_dir
,
rank
)
if
iteration
==
0
:
return
file_path
=
"{}/step-{}"
.
format
(
checkpoint_dir
,
iteration
)
model_dict
,
optimizer_dict
=
dg
.
load_dygraph
(
file_path
)
if
dtype
==
"float16"
:
for
k
,
v
in
model_dict
.
items
():
if
"conv2d_transpose"
in
k
:
model_dict
[
k
]
=
v
.
astype
(
"float32"
)
else
:
model_dict
[
k
]
=
v
.
astype
(
dtype
)
model
.
set_dict
(
model_dict
)
print
(
"[checkpoint] Rank {}: loaded model from {}"
.
format
(
rank
,
file_path
))
if
optimizer
and
optimizer_dict
:
optimizer
.
set_dict
(
optimizer_dict
)
print
(
"[checkpoint] Rank {}: loaded optimizer state from {}"
.
format
(
rank
,
file_path
))
def
save_latest_parameters
(
checkpoint_dir
,
iteration
,
model
,
optimizer
=
None
):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
model (obj): model to be checkpointed.
optimizer (obj, optional): optimizer to be checkpointed.
Defaults to None.
Returns:
None
"""
file_path
=
"{}/step-{}"
.
format
(
checkpoint_dir
,
iteration
)
model_dict
=
model
.
state_dict
()
dg
.
save_dygraph
(
model_dict
,
file_path
)
print
(
"[checkpoint] Saved model to {}"
.
format
(
file_path
))
if
optimizer
:
opt_dict
=
optimizer
.
state_dict
()
dg
.
save_dygraph
(
opt_dict
,
file_path
)
print
(
"[checkpoint] Saved optimzier state to {}"
.
format
(
file_path
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录