Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
b88f8845
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
11
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b88f8845
编写于
3月 08, 2020
作者:
L
liuyibing01
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'add_doc_str' into 'master'
Add docstring for WaveFlow See merge request !34
上级
8083da21
4f7ded3c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
156 addition
and
0 deletion
+156
-0
examples/waveflow/utils.py
examples/waveflow/utils.py
+49
-0
parakeet/models/waveflow/data.py
parakeet/models/waveflow/data.py
+1
-0
parakeet/models/waveflow/waveflow.py
parakeet/models/waveflow/waveflow.py
+68
-0
parakeet/models/waveflow/waveflow_modules.py
parakeet/models/waveflow/waveflow_modules.py
+38
-0
未找到文件。
examples/waveflow/utils.py
浏览文件 @
b88f8845
...
...
@@ -109,6 +109,16 @@ def add_yaml_config(config):
def
load_latest_checkpoint
(
checkpoint_dir
,
rank
=
0
):
"""Get the iteration number corresponding to the latest saved checkpoint
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int, optional): the rank of the process in multi-process setting.
Defaults to 0.
Returns:
int: the latest iteration number.
"""
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Create checkpoint index file if not exist.
if
(
not
os
.
path
.
isfile
(
checkpoint_path
))
and
rank
==
0
:
...
...
@@ -129,6 +139,15 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
def
save_latest_checkpoint
(
checkpoint_dir
,
iteration
):
"""Save the iteration number of the latest model to be checkpointed.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
"checkpoint"
)
# Update the latest checkpoint index.
with
open
(
checkpoint_path
,
"w"
)
as
handle
:
...
...
@@ -142,6 +161,24 @@ def load_parameters(checkpoint_dir,
iteration
=
None
,
file_path
=
None
,
dtype
=
"float32"
):
"""Load a specific model checkpoint from disk.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed.
Defaults to None.
iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None.
file_path (str, optional): if specified, load the checkpoint
stored in the file_path. Defaults to None.
dtype (str, optional): precision of the model parameters.
Defaults to float32.
Returns:
None
"""
if
file_path
is
None
:
if
iteration
is
None
:
iteration
=
load_latest_checkpoint
(
checkpoint_dir
,
rank
)
...
...
@@ -165,6 +202,18 @@ def load_parameters(checkpoint_dir,
def
save_latest_parameters
(
checkpoint_dir
,
iteration
,
model
,
optimizer
=
None
):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
model (obj): model to be checkpointed.
optimizer (obj, optional): optimizer to be checkpointed.
Defaults to None.
Returns:
None
"""
file_path
=
"{}/step-{}"
.
format
(
checkpoint_dir
,
iteration
)
model_dict
=
model
.
state_dict
()
dg
.
save_dygraph
(
model_dict
,
file_path
)
...
...
parakeet/models/waveflow/data.py
浏览文件 @
b88f8845
...
...
@@ -80,6 +80,7 @@ class Subset(DatasetMixin):
# whole audio for valid set
pass
else
:
# Randomly crop segment_length from audios in the training set.
# audio shape: [len]
if
audio
.
shape
[
0
]
>=
segment_length
:
max_audio_start
=
audio
.
shape
[
0
]
-
segment_length
...
...
parakeet/models/waveflow/waveflow.py
浏览文件 @
b88f8845
...
...
@@ -28,6 +28,25 @@ from .waveflow_modules import WaveFlowLoss, WaveFlowModule
class
WaveFlow
():
"""Wrapper class of WaveFlow model that supports multiple APIs.
This module provides APIs for model building, training, validation,
inference, benchmarking, and saving.
Args:
config (obj): config info.
checkpoint_dir (str): path for checkpointing.
parallel (bool, optional): whether use multiple GPUs for training.
Defaults to False.
rank (int, optional): the rank of the process in a multi-process
scenario. Defaults to 0.
nranks (int, optional): the total number of processes. Defaults to 1.
tb_logger (obj, optional): logger to visualize metrics.
Defaults to None.
Returns:
WaveFlow
"""
def
__init__
(
self
,
config
,
checkpoint_dir
,
...
...
@@ -44,6 +63,15 @@ class WaveFlow():
self
.
dtype
=
"float16"
if
config
.
use_fp16
else
"float32"
def
build
(
self
,
training
=
True
):
"""Initialize the model.
Args:
training (bool, optional): Whether the model is built for training or inference.
Defaults to True.
Returns:
None
"""
config
=
self
.
config
dataset
=
LJSpeech
(
config
,
self
.
nranks
,
self
.
rank
)
self
.
trainloader
=
dataset
.
trainloader
...
...
@@ -99,6 +127,14 @@ class WaveFlow():
self
.
waveflow
=
waveflow
def
train_step
(
self
,
iteration
):
"""Train the model for one step.
Args:
iteration (int): current iteration number.
Returns:
None
"""
self
.
waveflow
.
train
()
start_time
=
time
.
time
()
...
...
@@ -135,6 +171,14 @@ class WaveFlow():
@
dg
.
no_grad
def
valid_step
(
self
,
iteration
):
"""Run the model on the validation dataset.
Args:
iteration (int): current iteration number.
Returns:
None
"""
self
.
waveflow
.
eval
()
tb
=
self
.
tb_logger
...
...
@@ -167,6 +211,14 @@ class WaveFlow():
@
dg
.
no_grad
def
infer
(
self
,
iteration
):
"""Run the model to synthesize audios.
Args:
iteration (int): iteration number of the loaded checkpoint.
Returns:
None
"""
self
.
waveflow
.
eval
()
config
=
self
.
config
...
...
@@ -203,6 +255,14 @@ class WaveFlow():
@
dg
.
no_grad
def
benchmark
(
self
):
"""Run the model to benchmark synthesis speed.
Args:
None
Returns:
None
"""
self
.
waveflow
.
eval
()
mels_list
=
[
mels
for
_
,
mels
in
self
.
validloader
()]
...
...
@@ -223,6 +283,14 @@ class WaveFlow():
print
(
"{} X real-time"
.
format
(
audio_time
/
syn_time
))
def
save
(
self
,
iteration
):
"""Save model checkpoint.
Args:
iteration (int): iteration number of the model to be saved.
Returns:
None
"""
utils
.
save_latest_parameters
(
self
.
checkpoint_dir
,
iteration
,
self
.
waveflow
,
self
.
optimizer
)
utils
.
save_latest_checkpoint
(
self
.
checkpoint_dir
,
iteration
)
parakeet/models/waveflow/waveflow_modules.py
浏览文件 @
b88f8845
...
...
@@ -293,6 +293,14 @@ class Flow(dg.Layer):
class
WaveFlowModule
(
dg
.
Layer
):
"""WaveFlow model implementation.
Args:
config (obj): model configuration parameters.
Returns:
WaveFlowModule
"""
def
__init__
(
self
,
config
):
super
(
WaveFlowModule
,
self
).
__init__
()
self
.
n_flows
=
config
.
n_flows
...
...
@@ -321,6 +329,22 @@ class WaveFlowModule(dg.Layer):
self
.
perms
.
append
(
perm
)
def
forward
(
self
,
audio
,
mel
):
"""Training forward pass.
Use a conditioner to upsample mel spectrograms into hidden states.
These hidden states along with the audio are passed to a stack of Flow
modules to obtain the final latent variable z and a list of log scaling
variables, which are then passed to the WaveFlowLoss module to calculate
the negative log likelihood.
Args:
audio (obj): audio samples.
mel (obj): mel spectrograms.
Returns:
z (obj): latent variable.
log_s_list(list): list of log scaling variables.
"""
mel
=
self
.
conditioner
(
mel
)
assert
mel
.
shape
[
2
]
>=
audio
.
shape
[
1
]
# Prune out the tail of audio/mel so that time/n_group == 0.
...
...
@@ -361,6 +385,20 @@ class WaveFlowModule(dg.Layer):
return
z
,
log_s_list
def
synthesize
(
self
,
mel
,
sigma
=
1.0
):
"""Use model to synthesize waveform.
Use a conditioner to upsample mel spectrograms into hidden states.
These hidden states along with initial random gaussian latent variable
are passed to a stack of Flow modules to obtain the audio output.
Args:
mel (obj): mel spectrograms.
sigma (float, optional): standard deviation of the guassian latent
variable. Defaults to 1.0.
Returns:
audio (obj): synthesized audio.
"""
if
self
.
dtype
==
"float16"
:
mel
=
fluid
.
layers
.
cast
(
mel
,
self
.
dtype
)
mel
=
self
.
conditioner
.
infer
(
mel
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录