PaddlePaddle / Parakeet
Commit 424c16a6, authored Feb 27, 2020 by chenfeiyu
staged clarinet
Parent: a0128254
Showing 9 changed files with 287 additions and 5 deletions
examples/clarinet/configs/clarinet_ljspeech.yaml             +52   -0
examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml   +0    -1
examples/wavenet/configs/wavenet_single_gaussian.yaml        +0    -1
examples/wavenet/configs/wavenet_softmax.yaml                +0    -1
examples/wavenet/utils.py                                    +1    -1
parakeet/data/dataset.py                                     +1    -1
parakeet/models/clarinet/__init__.py                         +16   -0
parakeet/models/clarinet/net.py                              +169  -0
parakeet/models/clarinet/utils.py                            +48   -0
examples/clarinet/configs/clarinet_ljspeech.yaml  (new file, mode 100644)
data:
  batch_size: 4
  train_clip_seconds: 0.5
  sample_rate: 22050
  hop_length: 256
  win_length: 1024
  n_fft: 2048
  n_mels: 80
  valid_size: 16

conditioner:
  upsampling_factors: [16, 16]

teacher:
  n_loop: 10
  n_layer: 3
  filter_size: 2
  residual_channels: 128
  loss_type: "mog"
  output_dim: 3
  log_scale_min: -9

student:
  n_loops: [10, 10, 10, 10, 10, 10]
  n_layers: [1, 1, 1, 1, 1, 1]
  filter_size: 3
  residual_channels: 64
  log_scale_min: -7

stft:
  n_fft: 2048
  win_length: 1024
  hop_length: 256

loss:
  lmd: 4

train:
  learning_rate: 0.0005
  anneal_rate: 0.5
  anneal_interval: 200000
  gradient_max_norm: 100.0
  checkpoint_interval: 10
  eval_interval: 10
  max_iterations: 2000000
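As a rough illustration only (not part of this commit, and the training script's actual config handling is not shown in this diff), a YAML config like the one above could be read with PyYAML and the per-component sections pulled out by key:

import yaml  # PyYAML; assumed available, not imported anywhere in this diff

with open("examples/clarinet/configs/clarinet_ljspeech.yaml") as f:
    config = yaml.safe_load(f)

teacher_cfg = config["teacher"]   # hyperparameters for the WaveNet teacher
student_cfg = config["student"]   # hyperparameters for the parallel-WaveNet student
print(teacher_cfg["loss_type"], student_cfg["n_loops"])   # mog [10, 10, 10, 10, 10, 10]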
examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml
 data:
-  root: "/workspace/datasets/LJSpeech-1.1/"
   batch_size: 4
   train_clip_seconds: 0.5
   sample_rate: 22050
...
examples/wavenet/configs/wavenet_single_gaussian.yaml
 data:
-  root: "/workspace/datasets/LJSpeech-1.1/"
   batch_size: 4
   train_clip_seconds: 0.5
   sample_rate: 22050
...
examples/wavenet/configs/wavenet_softmax.yaml
 data:
-  root: "/workspace/datasets/LJSpeech-1.1/"
   batch_size: 4
   train_clip_seconds: 0.5
   sample_rate: 22050
...
examples/wavenet/utils.py
...
@@ -56,7 +56,7 @@ def eval_model(model, valid_loader, output_dir, sample_rate):
         audio_clips, mel_specs, audio_starts = batch
         wav_var = model.synthesis(mel_specs)
         wav_np = wav_var.numpy()[0]
-        sf.write(wav_np, path, samplerate=sample_rate)
+        sf.write(path, wav_np, samplerate=sample_rate)
         print("generated {}".format(path))
...
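The one-line change above fixes the argument order: `soundfile.write` takes the output path first and the sample array second. A minimal standalone sketch of the corrected call (file name and silent array are illustrative only):

import numpy as np
import soundfile as sf

wav_np = np.zeros(22050, dtype="float32")            # one second of silence at 22050 Hz
sf.write("generated.wav", wav_np, samplerate=22050)  # write(file, data, samplerate)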
parakeet/data/dataset.py
...
@@ -134,7 +134,7 @@ class SliceDataset(DatasetMixin):
                 format(len(order), len(dataset)))
         self._order = order
 
-    def len(self):
+    def __len__(self):
         return self._size
 
     def get_example(self, i):
...
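Renaming `len` to `__len__` makes `SliceDataset` implement Python's length protocol, so the built-in `len()` works on dataset instances. A tiny standalone illustration (not Parakeet code):

class Sliced(object):
    def __init__(self, size):
        self._size = size

    def __len__(self):   # this dunder name is what the built-in len() looks up
        return self._size

print(len(Sliced(10)))   # 10; with a plain `def len(self)` this call would raise TypeError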
parakeet/models/clarinet/__init__.py  (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .net import *
from .parallel_wavenet import *
\ No newline at end of file
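Because `net.py` defines no `__all__` in this commit, the wildcard re-exports above should expose its public names at the package level, so downstream code can presumably import the new classes directly:

from parakeet.models.clarinet import Clarinet, STFT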
parakeet/models/clarinet/net.py  (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import numpy as np
from scipy import signal
from tqdm import trange

import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
import paddle.fluid.initializer as I
import paddle.fluid.layers.distributions as D

from parakeet.modules.weight_norm import Conv2DTranspose
from parakeet.models.wavenet import crop, WaveNet, UpsampleNet
from parakeet.models.clarinet.parallel_wavenet import ParallelWaveNet
from parakeet.models.clarinet.utils import conv2d


# Gaussian IAF model
class Clarinet(dg.Layer):
    def __init__(self,
                 encoder,
                 teacher,
                 student,
                 stft,
                 min_log_scale=-6.0,
                 lmd=4.0):
        super(Clarinet, self).__init__()
        self.lmd = lmd
        self.encoder = encoder
        self.teacher = teacher
        self.student = student
        self.min_log_scale = min_log_scale
        self.stft = stft

    def forward(self, audio, mel, audio_start, clip_kl=True):
        """Compute the loss of the distilled (student) model.

        Arguments:
            audio {Variable} -- shape(batch_size, time_steps), target waveform.
            mel {Variable} -- shape(batch_size, condition_dim, time_steps // hop_length), original mel spectrogram, not upsampled yet.
            audio_start {Variable} -- shape(batch_size, ), the index of the start sample.
            clip_kl (bool) -- whether to clip the KL divergence if it is greater than 10.0.

        Returns:
            Dict[str, Variable] -- loss terms ("loss", "kl_divergence", "regularization", "stft_loss"), each of shape(1, ).
        """
        batch_size, audio_length = audio.shape  # audio clip's length

        z = F.gaussian_random(audio.shape)
        condition = self.encoder(mel)  # (B, C, T)
        condition_slice = crop(condition, audio_start, audio_length)

        x, s_means, s_scales = self.student(z, condition_slice)  # all [0: T]
        s_means = s_means[:, 1:]  # (B, T-1), time steps [1: T]
        s_scales = s_scales[:, 1:]  # (B, T-1), time steps [1: T]
        s_clipped_scales = F.clip(s_scales, self.min_log_scale, 100.)

        # teacher outputs single gaussian
        y = self.teacher(x[:, :-1], condition_slice[:, :, 1:])
        _, t_means, t_scales = F.split(y, 3, -1)  # time steps [1: T]
        t_means = F.squeeze(t_means, [-1])  # (B, T-1), time steps [1: T]
        t_scales = F.squeeze(t_scales, [-1])  # (B, T-1), time steps [1: T]
        t_clipped_scales = F.clip(t_scales, self.min_log_scale, 100.)

        s_distribution = D.Normal(s_means, F.exp(s_clipped_scales))
        t_distribution = D.Normal(t_means, F.exp(t_clipped_scales))

        # kl divergence loss, so we only need to sample once? no MC
        kl = s_distribution.kl_divergence(t_distribution)
        if clip_kl:
            kl = F.clip(kl, -100., 10.)
        # context size dropped
        kl = F.reduce_mean(kl[:, self.teacher.context_size:])
        # major diff here
        regularization = F.mse_loss(t_scales[:, self.teacher.context_size:],
                                    s_scales[:, self.teacher.context_size:])

        # introduce information from real target
        spectrogram_frame_loss = F.mse_loss(
            self.stft.magnitude(audio), self.stft.magnitude(x))
        loss = kl + self.lmd * regularization + spectrogram_frame_loss
        loss_dict = {
            "loss": loss,
            "kl_divergence": kl,
            "regularization": regularization,
            "stft_loss": spectrogram_frame_loss
        }
        return loss_dict

    @dg.no_grad
    def synthesis(self, mel):
        """Synthesize waveform conditioned on the mel spectrogram.

        Arguments:
            mel {Variable} -- shape(batch_size, frequency_bands, frames)

        Returns:
            Variable -- shape(batch_size, frames * upsample_factor)
        """
        condition = self.encoder(mel)
        samples_shape = (condition.shape[0], condition.shape[-1])
        z = F.gaussian_random(samples_shape)
        x, s_means, s_scales = self.student(z, condition)
        return x


class STFT(dg.Layer):
    def __init__(self, n_fft, hop_length, win_length, window="hanning"):
        super(STFT, self).__init__()
        self.hop_length = hop_length
        self.n_bin = 1 + n_fft // 2
        self.n_fft = n_fft

        # calculate window
        window = signal.get_window(window, win_length)
        if n_fft != win_length:
            pad = (n_fft - win_length) // 2
            window = np.pad(window, ((pad, pad), ), 'constant')

        # calculate weights
        r = np.arange(0, n_fft)
        M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
        w_real = np.reshape(window *
                            np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
                            (self.n_bin, 1, 1, self.n_fft)).astype("float32")
        w_imag = np.reshape(window *
                            np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
                            (self.n_bin, 1, 1, self.n_fft)).astype("float32")

        w = np.concatenate([w_real, w_imag], axis=0)
        self.weight = dg.to_variable(w)

    def forward(self, x):
        # x(batch_size, time_steps)
        # pad it first with reflect mode
        pad_start = F.reverse(x[:, 1:1 + self.n_fft // 2], axis=1)
        pad_stop = F.reverse(x[:, -(1 + self.n_fft // 2):-1], axis=1)
        x = F.concat([pad_start, x, pad_stop], axis=-1)

        # to BC1T, C=1
        x = F.unsqueeze(x, axes=[1, 2])
        out = conv2d(x, self.weight, stride=(1, self.hop_length))
        real, imag = F.split(out, 2, dim=1)  # BC1T
        return real, imag

    def power(self, x):
        real, imag = self(x)
        power = real**2 + imag**2
        return power

    def magnitude(self, x):
        power = self.power(x)
        magnitude = F.sqrt(power)
        return magnitude
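The `STFT` layer above computes the short-time Fourier transform as a strided convolution: each output channel correlates a windowed frame of the signal with one cosine (real part) or negative-sine (imaginary part) basis row, keeping only the `1 + n_fft // 2` non-redundant bins. A small NumPy sketch of the same idea, independent of Paddle (frame and hop sizes are illustrative, and the reflect padding done in `forward` is omitted):

import numpy as np
from scipy import signal

n_fft, hop = 8, 4
n_bin = 1 + n_fft // 2
window = signal.get_window("hanning", n_fft)

# windowed DFT basis restricted to the non-redundant bins, one row per bin
M = np.arange(n_fft)[:, None] * np.arange(n_fft)[None, :]
w_real = window * np.cos(2 * np.pi * M / n_fft)[:n_bin]    # (n_bin, n_fft)
w_imag = window * np.sin(-2 * np.pi * M / n_fft)[:n_bin]   # (n_bin, n_fft)

x = np.random.randn(32).astype("float32")
frames = np.stack([x[i:i + n_fft] for i in range(0, len(x) - n_fft + 1, hop)])  # (T, n_fft)
real = frames @ w_real.T   # (T, n_bin), same result the strided conv2d produces per frame
imag = frames @ w_imag.T
magnitude = np.sqrt(real ** 2 + imag ** 2)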
parakeet/models/clarinet/utils.py  (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import fluid
from paddle.fluid.core import ops


@fluid.framework.dygraph_only
def conv2d(input,
           weight,
           stride=(1, 1),
           padding=((0, 0), (0, 0)),
           dilation=(1, 1),
           groups=1,
           use_cudnn=True,
           data_format="NCHW"):
    padding = tuple(pad for pad_dim in padding for pad in pad_dim)

    inputs = {
        'Input': [input],
        'Filter': [weight],
    }
    attrs = {
        'strides': stride,
        'paddings': padding,
        'dilations': dilation,
        'groups': groups,
        'use_cudnn': use_cudnn,
        'use_mkldnn': False,
        'fuse_relu_before_depthwise_conv': False,
        "padding_algorithm": "EXPLICIT",
        "data_format": data_format,
    }

    outputs = ops.conv2d(inputs, attrs)
    out = outputs["Output"][0]
    return out
\ No newline at end of file
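This `conv2d` is a thin dygraph-only wrapper over the raw fluid `conv2d` op, so a precomputed filter bank (like the STFT weights above) can be applied without creating a learnable layer. A rough usage sketch against the old `paddle.fluid` dygraph API this commit targets (shapes are illustrative; Paddle 2.x has since replaced this API):

import numpy as np
import paddle.fluid.dygraph as dg
from parakeet.models.clarinet.utils import conv2d

with dg.guard():
    x = dg.to_variable(np.random.randn(1, 1, 1, 64).astype("float32"))  # NCHW input, "time" on the last axis
    w = dg.to_variable(np.random.randn(8, 1, 1, 16).astype("float32"))  # 8 fixed filters of width 16
    y = conv2d(x, w, stride=(1, 4))   # stride only along time, as in STFT.forward
    print(y.shape)                    # [1, 8, 1, 13]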