PaddlePaddle / Parakeet
Commit 98841ee4
Authored Dec 02, 2019 by Kexin Zhao
Parent: b15c3134

clean code

Showing 9 changed files with 314 additions and 1322 deletions (+314 -1322)
Changed files:
  parakeet/data/sampler.py                                              +30   -1
  parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml    +32   -0
  parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml         +31   -0
  parakeet/models/wavenet/data.py                                       +3    -34
  parakeet/models/wavenet/ops.py                                        +0    -249
  parakeet/models/wavenet/wavenet.py                                    +8    -20
  parakeet/models/wavenet/wavenet_modules.py                            +56   -98
  parakeet/models/wavenet/weight_norm.py                                +0    -920
  parakeet/modules/modules.py                                           +154  -0
parakeet/data/sampler.py

@@ -163,6 +163,35 @@ class WeightedRandomSampler(Sampler):
         return self.num_samples
 
 
+class DistributedSampler(Sampler):
+    def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
+        self.dataset_size = dataset_size
+        self.num_trainers = num_trainers
+        self.rank = rank
+        self.num_samples = int(np.ceil(dataset_size / num_trainers))
+        self.total_size = self.num_samples * num_trainers
+        assert self.total_size >= self.dataset_size
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        indices = list(range(self.dataset_size))
+        if self.shuffle:
+            random.shuffle(indices)
+
+        # Append extra samples to make it evenly distributed on all trainers.
+        indices += indices[:(self.total_size - self.dataset_size)]
+        assert len(indices) == self.total_size
+
+        # Subset samples for each trainer.
+        indices = indices[self.rank:self.total_size:self.num_trainers]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+
 class BatchSampler(Sampler):
     r"""Wraps another sampler to yield a mini-batch of indices.
 
     Args:

@@ -206,4 +235,4 @@ class BatchSampler(Sampler):
         if self.drop_last:
             return len(self.sampler) // self.batch_size
         else:
-            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
\ No newline at end of file
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
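The added DistributedSampler pads the index list so every trainer gets the same number of samples, then takes a strided slice by rank. A minimal standalone sketch of that partitioning (the helper name `partition` is hypothetical, the arithmetic mirrors `__iter__` above):

    import random
    import numpy as np

    def partition(dataset_size, num_trainers, rank, shuffle=False, seed=None):
        # Mirrors DistributedSampler.__iter__: pad, then take one strided shard per rank.
        num_samples = int(np.ceil(dataset_size / num_trainers))
        total_size = num_samples * num_trainers
        indices = list(range(dataset_size))
        if shuffle:
            random.Random(seed).shuffle(indices)
        indices += indices[:(total_size - dataset_size)]   # wrap around to pad evenly
        return indices[rank:total_size:num_trainers]

    # 10 samples over 4 trainers -> each rank gets ceil(10/4) = 3 indices;
    # ranks 2 and 3 reuse indices 0 and 1 as padding.
    for rank in range(4):
        print(rank, partition(10, 4, rank))
    # 0 [0, 4, 8]
    # 1 [1, 5, 9]
    # 2 [2, 6, 0]
    # 3 [3, 7, 1]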
parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml (new file, 0 → 100644)

valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80

seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000

layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: mix-gaussian-pdf
num_mixtures: 10
log_scale_min: -9.0

conditioner:
  filter_sizes: [[32, 3], [32, 3]]
  upsample_factors: [16, 16]

learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
  every: 200000
  rate: 0.5
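The training code reads these values as attributes (config.sample_rate, config.loss_type, ...). A minimal sketch of loading the YAML into such an object, assuming plain PyYAML rather than whatever utility the repo actually uses:

    import yaml  # assumption: PyYAML; the project's own config loader may differ

    class Config(dict):
        """Dict that also exposes keys as attributes, e.g. config.sample_rate."""
        __getattr__ = dict.__getitem__

    with open("parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml") as f:
        config = Config(yaml.safe_load(f))

    print(config.loss_type)       # mix-gaussian-pdf
    print(config.num_mixtures)    # 10
    print(config["conditioner"])  # {'filter_sizes': [[32, 3], [32, 3]], 'upsample_factors': [16, 16]}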
parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml (new file, 0 → 100644)

valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80

seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000

layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: softmax
num_channels: 2048

conditioner:
  filter_sizes: [[32, 3], [32, 3]]
  upsample_factors: [16, 16]

learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
  every: 200000
  rate: 0.5
parakeet/models/wavenet/data.py

-import math
 import os
-import random
 
 import librosa

@@ -9,7 +7,7 @@ from paddle import fluid
 import utils
 from parakeet.datasets import ljspeech
 from parakeet.data import dataset
-from parakeet.data.sampler import Sampler, BatchSampler, SequentialSampler
+from parakeet.data.sampler import DistributedSampler, BatchSampler
 from parakeet.data.datacargo import DataCargo

@@ -20,7 +18,7 @@ class Dataset(ljspeech.LJSpeech):
         self.fft_window_shift = config.fft_window_shift
         # Calculate context frames.
         frames_per_second = config.sample_rate // self.fft_window_shift
-        train_clip_frames = int(math.ceil(
+        train_clip_frames = int(np.ceil(
             config.train_clip_second * frames_per_second))
         context_frames = config.context_size // self.fft_window_shift
         self.num_frames = train_clip_frames + context_frames

@@ -39,7 +37,7 @@ class Dataset(ljspeech.LJSpeech):
         assert loaded_sr == sr
 
         # Pad audio to the right size.
-        frames = math.ceil(float(audio.size) / fft_window_shift)
+        frames = int(np.ceil(float(audio.size) / fft_window_shift))
         fft_padding = (fft_size - fft_window_shift) // 2
         desired_length = frames * fft_window_shift + fft_padding * 2
         pad_amount = (desired_length - audio.size) // 2

@@ -125,35 +123,6 @@ class Subset(dataset.Dataset):
         return len(self.indices)
 
 
-class DistributedSampler(Sampler):
-    def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
-        self.dataset_size = dataset_size
-        self.num_trainers = num_trainers
-        self.rank = rank
-        self.num_samples = int(math.ceil(dataset_size / num_trainers))
-        self.total_size = self.num_samples * num_trainers
-        assert self.total_size >= self.dataset_size
-        self.shuffle = shuffle
-
-    def __iter__(self):
-        indices = list(range(self.dataset_size))
-        if self.shuffle:
-            random.shuffle(indices)
-
-        # Append extra samples to make it evenly distributed on all trainers.
-        indices += indices[:(self.total_size - self.dataset_size)]
-        assert len(indices) == self.total_size
-
-        # Subset samples for each trainer.
-        indices = indices[self.rank:self.total_size:self.num_trainers]
-        assert len(indices) == self.num_samples
-
-        return iter(indices)
-
-    def __len__(self):
-        return self.num_samples
-
-
 class LJSpeech:
     def __init__(self, config, nranks, rank):
         place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
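The padding logic in the @@ -39 hunk sizes each clip so the STFT frames line up with fft_window_shift-sized hops. A quick numeric check with the values from the configs (fft_window_shift=256, fft_size=2048) and a hypothetical clip length:

    import numpy as np

    fft_window_shift = 256        # hop size from the configs
    fft_size = 2048

    audio_size = 100_000          # hypothetical number of samples
    frames = int(np.ceil(float(audio_size) / fft_window_shift))     # 391
    fft_padding = (fft_size - fft_window_shift) // 2                # 896
    desired_length = frames * fft_window_shift + fft_padding * 2    # 101888
    pad_amount = (desired_length - audio_size) // 2                 # 944

    print(frames, fft_padding, desired_length, pad_amount)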
parakeet/models/wavenet/ops.py (deleted, 100644 → 0)

import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
import numpy as np

import weight_norm


def Embedding(name_scope,
              num_embeddings,
              embed_dim,
              padding_idx=None,
              std=0.1,
              dtype="float32"):
    # param attrs
    weight_attr = fluid.ParamAttr(
        initializer=fluid.initializer.Normal(scale=std))
    layer = dg.Embedding(
        name_scope, (num_embeddings, embed_dim),
        padding_idx=padding_idx,
        param_attr=weight_attr,
        dtype=dtype)
    return layer


def FC(name_scope,
       in_features,
       size,
       num_flatten_dims=1,
       relu=False,
       dropout=0.0,
       act=None,
       dtype="float32"):
    """
    A special Linear Layer, when it is used with dropout, the weight is
    initialized as normal(0, std=np.sqrt((1-dropout) / in_features))
    """
    # stds
    if isinstance(in_features, int):
        in_features = [in_features]

    stds = [np.sqrt((1.0 - dropout) / in_feature) for in_feature in in_features]
    if relu:
        stds = [std * np.sqrt(2.0) for std in stds]

    weight_inits = [
        fluid.initializer.NormalInitializer(scale=std) for std in stds
    ]
    bias_init = fluid.initializer.ConstantInitializer(0.0)

    # param attrs
    weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits]
    bias_attr = fluid.ParamAttr(initializer=bias_init)

    layer = weight_norm.FC(name_scope,
                           size,
                           num_flatten_dims=num_flatten_dims,
                           param_attr=weight_attrs,
                           bias_attr=bias_attr,
                           act=act,
                           dtype=dtype)
    return layer


def Conv1D(name_scope,
           in_channels,
           num_filters,
           filter_size=2,
           dilation=1,
           groups=None,
           causal=False,
           std_mul=1.0,
           dropout=0.0,
           use_cudnn=True,
           act=None,
           dtype="float32"):
    """
    A special Conv1D Layer, when it is used with dropout, the weight is
    initialized as
    normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_channels)))
    """
    # std
    std = np.sqrt((std_mul * (1.0 - dropout)) / (filter_size * in_channels))
    weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std)
    bias_init = fluid.initializer.ConstantInitializer(0.0)

    # param attrs
    weight_attr = fluid.ParamAttr(initializer=weight_init)
    bias_attr = fluid.ParamAttr(initializer=bias_init)

    layer = weight_norm.Conv1D(name_scope,
                               num_filters,
                               filter_size,
                               dilation,
                               groups=groups,
                               causal=causal,
                               param_attr=weight_attr,
                               bias_attr=bias_attr,
                               use_cudnn=use_cudnn,
                               act=act,
                               dtype=dtype)
    return layer


class Conv1D_GU(dg.Layer):
    def __init__(self,
                 name_scope,
                 conditioner_dim,
                 in_channels,
                 num_filters,
                 filter_size,
                 dilation,
                 causal=False,
                 residual=True,
                 dtype="float32"):
        super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)

        self.conditioner_dim = conditioner_dim
        self.in_channels = in_channels
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.dilation = dilation
        self.causal = causal
        self.residual = residual

        if residual:
            assert (in_channels == num_filters), "this block uses residual connection" \
                "the input_channels should equals num_filters"

        self.conv = Conv1D(
            self.full_name(),
            in_channels,
            2 * num_filters,
            filter_size,
            dilation,
            causal=causal,
            dtype=dtype)

        self.fc = Conv1D(
            self.full_name(),
            conditioner_dim,
            2 * num_filters,
            filter_size=1,
            dilation=1,
            causal=False,
            dtype=dtype)

    def forward(self, x, skip=None, conditioner=None):
        """
        Args:
            x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU
                layer, where B means batch_size, C_in means the input channels
                T means input time steps.
            conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
                conditioner, where C_con is conditioner hidden dim which
                equals the num of mel bands. Note that when using residual
                connection, the Conv1DGLU does not change the number of
                channels, so out channels equals input channels.
        Returns:
            x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where
                C_out means the output channels of Conv1DGLU.
        """
        residual = x
        x = self.conv(x)

        if conditioner is not None:
            cond_bias = self.fc(conditioner)
            x += cond_bias

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        # Gated Unit.
        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
                                         fluid.layers.tanh(content))

        if skip is None:
            skip = x
        else:
            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))

        if self.residual:
            x = fluid.layers.scale(residual + x, np.sqrt(0.5))

        return x, skip

    def add_input(self, x, skip=None, conditioner=None):
        """
        Inputs:
            x: shape(B, num_filters, 1, time_steps)
            conditioner: shape(B, conditioner_dim, 1, time_steps)
        Outputs:
            out: shape(B, num_filters, 1, time_steps), where time_steps = 1
        """
        residual = x

        # add step input and produce step output
        x = self.conv.add_input(x)

        if conditioner is not None:
            cond_bias = self.fc(conditioner)
            x += cond_bias

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        # Gated Unit.
        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
                                         fluid.layers.tanh(content))

        if skip is None:
            skip = x
        else:
            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))

        if self.residual:
            x = fluid.layers.scale(residual + x, np.sqrt(0.5))

        return x, skip


def Conv2DTranspose(name_scope,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    use_cudnn=True,
                    act=None,
                    dtype="float32"):
    val = 1.0 / (filter_size[0] * filter_size[1])
    weight_init = fluid.initializer.ConstantInitializer(val)
    weight_attr = fluid.ParamAttr(initializer=weight_init)

    layer = weight_norm.Conv2DTranspose(name_scope,
                                        num_filters,
                                        filter_size=filter_size,
                                        padding=padding,
                                        stride=stride,
                                        dilation=dilation,
                                        param_attr=weight_attr,
                                        use_cudnn=use_cudnn,
                                        act=act,
                                        dtype=dtype)
    return layer
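The deleted ops.py duplicated layer helpers that the model now takes from parakeet.modules (wavenet_modules.py below switches `ops.` to `modules.`, and Conv1D_GU / Conv2DTranspose are appended to parakeet/modules/modules.py at the end of this diff). Its core is the WaveNet gated activation unit: the dilated convolution doubles the channel count, the result is split into content and gate halves, and the output is sigmoid(gate) * tanh(content), with skip and residual sums rescaled by sqrt(0.5). A small NumPy sketch of that arithmetic (hypothetical helper, shapes only, no learned weights):

    import numpy as np

    def gated_unit(x2c, skip=None, residual=None):
        """x2c: conv output with 2*C channels, shape (B, 2*C, 1, T)."""
        content, gate = np.split(x2c, 2, axis=1)                  # two (B, C, 1, T) halves
        out = 1.0 / (1.0 + np.exp(-gate)) * np.tanh(content)      # sigmoid(gate) * tanh(content)
        skip = out if skip is None else (skip + out) * np.sqrt(0.5)
        if residual is not None:
            out = (residual + out) * np.sqrt(0.5)                 # residual sum, variance kept stable
        return out, skip

    B, C, T = 2, 128, 16
    x2c = np.random.randn(B, 2 * C, 1, T).astype("float32")
    res = np.random.randn(B, C, 1, T).astype("float32")
    out, skip = gated_unit(x2c, skip=None, residual=res)
    print(out.shape, skip.shape)   # (2, 128, 1, 16) (2, 128, 1, 16)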
parakeet/models/wavenet/wavenet.py

@@ -4,12 +4,12 @@ import time
 import librosa
 import numpy as np
-from paddle import fluid
 import paddle.fluid.dygraph as dg
+from paddle import fluid
 
 import utils
 from data import LJSpeech
-from wavenet_modules import WaveNetModule, debug
+from wavenet_modules import WaveNetModule
 
 
 class WaveNet():

@@ -33,18 +33,6 @@ class WaveNet():
         self.trainloader = dataset.trainloader
         self.validloader = dataset.validloader
 
-        # if self.rank == 0:
-        #     for i, (audios, mels, ids) in enumerate(self.validloader()):
-        #         print("audios {}, mels {}, ids {}".format(audios.dtype, mels.dtype, ids.dtype))
-        #         print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
-        #             i, self.rank, audios.shape, mels.shape, ids.shape,
-        #             ids.numpy()))
-        #
-        #     for i, (audios, mels, ids) in enumerate(self.trainloader):
-        #         print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
-        #             i, self.rank, audios.shape, mels.shape, ids.shape,
-        #             ids.numpy()))
-
         wavenet = WaveNetModule("wavenet", config, self.rank)
 
         # Dry run once to create and initalize all necessary parameters.

@@ -139,8 +127,8 @@ class WaveNet():
         self.wavenet.eval()
 
         total_loss = []
-        start_time = time.time()
         sample_audios = []
+        start_time = time.time()
         for audios, mels, audio_starts in self.validloader():
             loss, sample_audio = self.wavenet(audios, mels, audio_starts, True)
             total_loss.append(float(loss.numpy()))

@@ -160,11 +148,6 @@ class WaveNet():
             tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(),
                          iteration, sample_rate=self.config.sample_rate)
 
-    def save(self, iteration):
-        utils.save_latest_parameters(self.checkpoint_dir, iteration,
-                                     self.wavenet, self.optimizer)
-        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
-
     @dg.no_grad
     def infer(self, iteration):
         self.wavenet.eval()

@@ -186,3 +169,8 @@ class WaveNet():
                 syn_audio.shape, syn_time))
             librosa.output.write_wav(filename, syn_audio,
                                      sr=config.sample_rate)
+
+    def save(self, iteration):
+        utils.save_latest_parameters(self.checkpoint_dir, iteration,
+                                     self.wavenet, self.optimizer)
+        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
parakeet/models/wavenet/wavenet_modules.py

 import itertools
 import math
 import numpy as np
-from paddle import fluid
 import paddle.fluid.dygraph as dg
-import ops
-import weight_norm
+from paddle import fluid
+from parakeet.modules import conv, modules
 
 
 def get_padding(filter_size, stride, padding_type='same'):

@@ -16,22 +14,6 @@ def get_padding(filter_size, stride, padding_type='same'):
     return padding
 
 
-def debug(x, var_name, rank, verbose=False):
-    if not verbose and rank != 0:
-        return
-    dim = len(x.shape)
-    if not isinstance(x, np.ndarray):
-        x = x.numpy()
-    if dim == 1:
-        print("Rank {}".format(rank), var_name,
-              "shape {}, value {}".format(x.shape, x))
-    elif dim == 2:
-        print("Rank {}".format(rank), var_name,
-              "shape {}, value {}".format(x.shape, x[:, :5]))
-    elif dim == 3:
-        print("Rank {}".format(rank), var_name,
-              "shape {}, value {}".format(x.shape, x[:, :5, 0]))
-    else:
-        print("Rank", rank, var_name, "shape", x.shape)
-
-
 def extract_slices(x, audio_starts, audio_length, rank):
     slices = []
     for i in range(x.shape[0]):

@@ -58,7 +40,7 @@ class Conditioner(dg.Layer):
             stride = (up_scale, 1)
             padding = get_padding(filter_sizes[i], stride)
             self.deconvs.append(
-                ops.Conv2DTranspose(
+                modules.Conv2DTranspose(
                     self.full_name(),
                     num_filters=1,
                     filter_size=filter_sizes[i],

@@ -94,12 +76,13 @@ class WaveNetModule(dg.Layer):
         print("context_size", self.context_size)
 
         if config.loss_type == "softmax":
-            self.embedding_fc = ops.Embedding(
+            self.embedding_fc = modules.Embedding(
                 self.full_name(),
                 num_embeddings=config.num_channels,
-                embed_dim=config.residual_channels)
+                embed_dim=config.residual_channels,
+                std=0.1)
         elif config.loss_type == "mix-gaussian-pdf":
-            self.embedding_fc = ops.FC(
+            self.embedding_fc = modules.FC(
                 self.full_name(),
                 in_features=1,
                 size=config.residual_channels,

@@ -112,7 +95,7 @@ class WaveNetModule(dg.Layer):
         self.dilated_causal_convs = []
         for dilation in self.dilations:
             self.dilated_causal_convs.append(
-                ops.Conv1D_GU(
+                modules.Conv1D_GU(
                     self.full_name(),
                     conditioner_dim=config.mel_bands,
                     in_channels=config.residual_channels,

@@ -126,7 +109,7 @@ class WaveNetModule(dg.Layer):
         for i, layer in enumerate(self.dilated_causal_convs):
             self.add_sublayer("dilated_causal_conv_{}".format(i), layer)
 
-        self.fc1 = ops.FC(
+        self.fc1 = modules.FC(
             self.full_name(),
             in_features=config.residual_channels,
             size=config.skip_channels,

@@ -134,7 +117,7 @@ class WaveNetModule(dg.Layer):
             relu=True,
             act="relu")
 
-        self.fc2 = ops.FC(
+        self.fc2 = modules.FC(
             self.full_name(),
             in_features=config.skip_channels,
             size=config.skip_channels,

@@ -143,14 +126,14 @@ class WaveNetModule(dg.Layer):
             act="relu")
 
         if config.loss_type == "softmax":
-            self.fc3 = ops.FC(
+            self.fc3 = modules.FC(
                 self.full_name(),
                 in_features=config.skip_channels,
                 size=config.num_channels,
                 num_flatten_dims=2,
                 relu=False)
         elif config.loss_type == "mix-gaussian-pdf":
-            self.fc3 = ops.FC(
+            self.fc3 = modules.FC(
                 self.full_name(),
                 in_features=config.skip_channels,
                 size=3 * config.num_mixtures,

@@ -175,8 +158,8 @@ class WaveNetModule(dg.Layer):
         return samples
 
     def sample_mix_gaussian(self, mix_parameters):
-        # mix_parameters reshape from [bs, 13799, 3 * num_mixtures]
-        # to [bs * 13799, 3 * num_mixtures].
+        # mix_parameters reshape from [bs, len, 3 * num_mixtures]
+        # to [bs * len, 3 * num_mixtures].
         batch, length, hidden = mix_parameters.shape
         mix_param_2d = fluid.layers.reshape(mix_parameters,
                                             [batch * length, hidden])

@@ -197,7 +180,7 @@ class WaveNetModule(dg.Layer):
         mu_comp = fluid.layers.gather_nd(mu, comp_samples)
         s_comp = fluid.layers.gather_nd(s, comp_samples)
 
-        # N(0, 1) Normal Sample.
+        # N(0, 1) normal sample.
         u = fluid.layers.gaussian_random(shape=[batch * length])
         samples = mu_comp + u * s_comp
         samples = fluid.layers.clip(samples, min=-1.0, max=1.0)
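sample_mix_gaussian draws a teacher-forced sample per time step by picking a mixture component and then computing mu + u * s with u ~ N(0, 1), clipping to [-1, 1]. A NumPy sketch of the same procedure for a single step, with hypothetical mixture parameters (3 components here instead of the config's 10):

    import numpy as np

    rng = np.random.default_rng(0)

    logits_pi = np.array([0.2, -0.3, 1.0])   # hypothetical mixture logits
    mu        = np.array([0.0, 0.2, -0.1])
    log_s     = np.array([-2.0, -3.0, -1.5])

    pi = np.exp(logits_pi) / np.exp(logits_pi).sum()   # mixture weights
    k = rng.choice(len(pi), p=pi)                      # pick a component by its weight
    u = rng.standard_normal()                          # N(0, 1) normal sample
    sample = np.clip(mu[k] + u * np.exp(log_s[k]), -1.0, 1.0)
    print(k, sample)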
@@ -205,8 +188,6 @@ class WaveNetModule(dg.Layer):
         return samples
 
     def softmax_loss(self, targets, mix_parameters):
-        # targets: [bs, 13799] -> [bs, 11752]
-        # mix_params: [bs, 13799, 3] -> [bs, 11752, 3]
         targets = targets[:, self.context_size:]
         mix_parameters = mix_parameters[:, self.context_size:, :]

@@ -216,22 +197,22 @@ class WaveNetModule(dg.Layer):
         quantized = fluid.layers.cast(
             (targets + 1.0) / 2.0 * num_channels, dtype="int64")
 
-        # per_sample_loss shape: [bs, 17952, 1]
+        # per_sample_loss shape: [bs, len, 1]
         per_sample_loss = fluid.layers.softmax_with_cross_entropy(
             logits=mix_parameters, label=fluid.layers.unsqueeze(quantized, 2))
         loss = fluid.layers.reduce_mean(per_sample_loss)
-        #debug(loss, "softmax loss", self.rank)
 
         return loss
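softmax_loss linearly quantizes the waveform (already clipped to [-1, 0.99999] in forward) into num_channels integer bins and trains with cross-entropy over those bins. A NumPy sketch of the quantization step, using num_channels = 2048 from the softmax config:

    import numpy as np

    num_channels = 2048     # from wavenet_ljspeech_softmax.yaml
    audio = np.array([-1.0, -0.5, 0.0, 0.5, 0.99999], dtype="float32")

    # Same formula as the cast above: map [-1, 1) onto integer bins [0, num_channels).
    quantized = ((audio + 1.0) / 2.0 * num_channels).astype("int64")
    print(quantized)        # [   0  512 1024 1536 2047]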
     def mixture_density_loss(self, targets, mix_parameters, log_scale_min):
-        # targets: [bs, 13799] -> [bs, 11752]
-        # mix_params: [bs, 13799, 3] -> [bs, 11752, 3]
+        # targets: [bs, len]
+        # mix_params: [bs, len, 3 * num_mixture]
         targets = targets[:, self.context_size:]
         mix_parameters = mix_parameters[:, self.context_size:, :]
 
-        # log_s: [bs, 11752, num_mixture]
+        # log_s: [bs, len, num_mixture]
         logits_pi, mu, log_s = fluid.layers.split(
             mix_parameters, num_or_sections=3, dim=-1)
 
         pi = fluid.layers.softmax(logits_pi, axis=-1)
         log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0)

@@ -242,10 +223,9 @@ class WaveNetModule(dg.Layer):
         targets = fluid.layers.expand(targets, [1, 1, self.config.num_mixtures])
         x_std = inv_s * (targets - mu)
         exponent = fluid.layers.exp(-0.5 * x_std * x_std)
-        # pdf_x: [bs, 11752, 1]
         pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_s * exponent
         pdf_x = pi * pdf_x
-        # pdf_x: [bs, 11752]
+        # pdf_x: [bs, len]
         pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1)
         per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9)
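mixture_density_loss evaluates the likelihood of each target sample under a mixture of Gaussians whose weights, means, and log scales come from fc3, then averages the negative log likelihood. A NumPy version of the same density computation for one time step (hypothetical parameter values, 3 components for brevity):

    import numpy as np

    log_scale_min = -9.0                    # from wavenet_ljspeech_mix_gaussian.yaml
    target = 0.1                            # one audio sample in [-1, 1]

    logits_pi = np.array([0.2, -0.3, 1.0])  # hypothetical mixture parameters
    mu        = np.array([0.0, 0.2, -0.1])
    log_s     = np.array([-2.0, -3.0, -1.5])

    pi = np.exp(logits_pi) / np.exp(logits_pi).sum()   # softmax over mixture weights
    log_s = np.clip(log_s, log_scale_min, 100.0)
    inv_s = np.exp(-log_s)

    x_std = inv_s * (target - mu)
    pdf_x = (pi * (1.0 / np.sqrt(2.0 * np.pi)) * inv_s * np.exp(-0.5 * x_std * x_std)).sum()
    loss = -np.log(pdf_x + 1e-9)            # per-sample negative log likelihood
    print(pdf_x, loss)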
@@ -254,8 +234,6 @@ class WaveNetModule(dg.Layer):
         return loss
 
     def forward(self, audios, mels, audio_starts, sample=False):
-        # audios: [bs, 13800], mels: [bs, full_frame_length, 80]
-        # audio_starts: [bs]
         # Build conditioner based on mels.
         full_conditioner = self.conditioner(mels)

@@ -264,15 +242,14 @@ class WaveNetModule(dg.Layer):
         conditioner = extract_slices(full_conditioner,
                                      audio_starts, audio_length, self.rank)
 
-        # input_audio, target_audio: [bs, 13799]
+        # input_audio, target_audio: [bs, len]
         input_audios = audios[:, :-1]
         target_audios = audios[:, 1:]
 
-        # conditioner: [bs, 13799, 80]
+        # conditioner: [bs, len, mel_bands]
         conditioner = conditioner[:, 1:, :]
 
         loss_type = self.config.loss_type
 
-        # layer_input: [bs, 13799, 128]
         if loss_type == "softmax":
             input_audios = fluid.layers.clip(
                 input_audios, min=-1.0, max=0.99999)

@@ -280,31 +257,31 @@ class WaveNetModule(dg.Layer):
             quantized = fluid.layers.cast(
                 (input_audios + 1.0) / 2.0 * self.config.num_channels,
                 dtype="int64")
             layer_input = self.embedding_fc(
                 fluid.layers.unsqueeze(quantized, 2))
         elif loss_type == "mix-gaussian-pdf":
             layer_input = self.embedding_fc(
                 fluid.layers.unsqueeze(input_audios, 2))
         else:
             raise ValueError("loss_type {} is unsupported!".format(loss_type))
 
-        # layer_input: [bs, res_channel, 1, 13799]
+        # layer_input: [bs, res_channel, 1, len]
         layer_input = fluid.layers.unsqueeze(
             fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2)
-        # conditioner: [bs, mel_bands, 1, 13799]
+        # conditioner: [bs, mel_bands, 1, len]
         conditioner = fluid.layers.unsqueeze(
             fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2)
 
-        # layer_input: [bs, res_channel, 1, 13799]
-        # skip: [bs, res_channel, 1, 13799]
         skip = None
         for i, layer in enumerate(self.dilated_causal_convs):
+            # layer_input: [bs, res_channel, 1, len]
+            # skip: [bs, res_channel, 1, len]
             layer_input, skip = layer(layer_input, skip, conditioner)
-            #debug(layer_input, "layer_input_" + str(i), self.rank)
-            #debug(skip, "skip_" + str(i), self.rank)
 
-        # Reshape skip to [bs, 13799, res_channel]
+        # Reshape skip to [bs, len, res_channel]
         skip = fluid.layers.transpose(
             fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
-        #debug(skip, "skip", self.rank)
-        # mix_param: [bs, 13799, 3 * num_mixtures]
         mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
 
         # Sample teacher-forced audio.

@@ -317,12 +294,7 @@ class WaveNetModule(dg.Layer):
         else:
             raise ValueError("loss_type {} is unsupported!".format(loss_type))
 
-        #debug(sample_audios, "sample_audios", self.rank)
-        # Calculate mix-gaussian density loss.
-        # padding is all zero.
-        # target_audio: [bs, 13799].
-        # mix_params: [bs, 13799, 3].
         if loss_type == "softmax":
             loss = self.softmax_loss(target_audios, mix_parameters)
         elif loss_type == "mix-gaussian-pdf":

@@ -332,27 +304,16 @@ class WaveNetModule(dg.Layer):
             raise ValueError("loss_type {} is unsupported!".format(loss_type))
 
-        #print("Rank {}, loss {}".format(self.rank, loss.numpy()))
         return loss, sample_audios
 
     def synthesize(self, mels):
         self.start_new_sequence()
-        print("input mels shape", mels.shape)
-        # mels: [bs=1, n_frames, 80]
-        # conditioner: [1, n_frames * samples_per_frame, 80]
-        # Should I move forward by one sample? No difference
-        # Append context frame to mels
         bs, n_frames, mel_bands = mels.shape
-        #num_pad_frames = int(np.ceil(self.context_size / self.config.fft_window_shift))
-        #silence = fluid.layers.zeros(shape=[bs, num_pad_frames, mel_bands], dtype="float32")
-        #inf_mels = fluid.layers.concat([silence, mels], axis=1)
-        #print("padded mels shape", inf_mels.shape)
-        #conditioner = self.conditioner(inf_mels)[:, self.context_size:, :]
         conditioner = self.conditioner(mels)
         time_steps = conditioner.shape[1]
-        print("Total steps", time_steps)
+
+        print("input mels shape", mels.shape)
+        print("Total synthesis steps", time_steps)
 
         loss_type = self.config.loss_type
         audio_samples = []

@@ -361,8 +322,8 @@ class WaveNetModule(dg.Layer):
             if i % 100 == 0:
                 print("Step", i)
 
-            # convert from real value sample to audio embedding.
-            # [bs, 1, 128]
+            # Convert from real value sample to audio embedding.
+            # audio_input: [bs, 1, channel]
             if loss_type == "softmax":
                 current_sample = fluid.layers.clip(
                     current_sample, min=-1.0, max=0.99999)

@@ -377,21 +338,23 @@ class WaveNetModule(dg.Layer):
                 raise ValueError("loss_type {} is unsupported!".format(loss_type))
 
-            # [bs, 128, 1, 1]
+            # [bs, channel, 1, 1]
             audio_input = fluid.layers.unsqueeze(
                 fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2)
-            # [bs, 80]
+            # [bs, mel_bands]
             cond_input = conditioner[:, i, :]
-            # [bs, 80, 1, 1]
+            # [bs, mel_bands, 1, 1]
             cond_input = fluid.layers.reshape(cond_input, cond_input.shape + [1, 1])
 
             skip = None
             for layer in self.dilated_causal_convs:
                 audio_input, skip = layer.add_input(audio_input, skip, cond_input)
 
-            # [bs, 1, 128]
+            # [bs, 1, channel]
             skip = fluid.layers.transpose(
                 fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
-            # [bs, 1, 3]
             mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
 
             if loss_type == "softmax":
                 sample = self.sample_softmax(mix_parameters)

@@ -407,17 +370,12 @@ class WaveNetModule(dg.Layer):
             current_sample = fluid.layers.reshape(
                 current_sample, current_sample.shape + [1, 1])
 
-        # syn_audio: (num_samples,)
+        # syn_audio: [num_samples]
         syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy()
 
         return syn_audio
 
     def start_new_sequence(self):
         for layer in self.sublayers():
-            if isinstance(layer, weight_norm.Conv1D):
+            if isinstance(layer, conv.Conv1D):
                 layer.start_new_sequence()
-
-    def save(self, iteration):
-        utils.save_latest_parameters(self.checkpoint_dir, iteration,
-                                     self.wavenet, self.optimizer)
-        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
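The Conditioner's transposed convolutions upsample the mel spectrogram in time so that one conditioner vector lines up with one audio sample (the forward pass indexes conditioner[:, i, :] per synthesis step). With the config values this works out exactly; a quick check with a hypothetical frame count:

    upsample_factors = [16, 16]   # from the configs
    fft_window_shift = 256        # hop size: audio samples per mel frame

    total_upsample = 1
    for f in upsample_factors:
        total_upsample *= f
    assert total_upsample == fft_window_shift   # 16 * 16 == 256

    n_frames = 500                               # hypothetical mel length
    print(n_frames * total_upsample)             # 128000 conditioner steps, one per audio sample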
parakeet/models/wavenet/weight_norm.py (deleted, 100644 → 0)

This diff is collapsed in the web view (920 lines deleted).
parakeet/modules/modules.py

@@ -26,6 +26,7 @@ def FC(name_scope,
       in_features,
       size,
       num_flatten_dims=1,
       relu=False,
       dropout=0.0,
       epsilon=1e-30,
       act=None,

@@ -39,7 +40,11 @@ def FC(name_scope,
    # stds
    if isinstance(in_features, int):
        in_features = [in_features]
    stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features]
    if relu:
        stds = [std * np.sqrt(2.0) for std in stds]
    weight_inits = [
        fluid.initializer.NormalInitializer(scale=std) for std in stds
    ]

@@ -456,3 +461,152 @@ class PositionEmbedding(dg.Layer):
            return out
        else:
            raise Exception("Then you can just use position rate at init")


class Conv1D_GU(dg.Layer):
    def __init__(self,
                 name_scope,
                 conditioner_dim,
                 in_channels,
                 num_filters,
                 filter_size,
                 dilation,
                 causal=False,
                 residual=True,
                 dtype="float32"):
        super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)

        self.conditioner_dim = conditioner_dim
        self.in_channels = in_channels
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.dilation = dilation
        self.causal = causal
        self.residual = residual

        if residual:
            assert (in_channels == num_filters), "this block uses residual connection" \
                "the input_channels should equals num_filters"

        self.conv = Conv1D(
            self.full_name(),
            in_channels,
            2 * num_filters,
            filter_size,
            dilation,
            causal=causal,
            dtype=dtype)

        self.fc = Conv1D(
            self.full_name(),
            conditioner_dim,
            2 * num_filters,
            filter_size=1,
            dilation=1,
            causal=False,
            dtype=dtype)

    def forward(self, x, skip=None, conditioner=None):
        """
        Args:
            x (Variable): Shape(B, C_in, 1, T), the input of Conv1D_GU
                layer, where B means batch_size, C_in means the input channels
                T means input time steps.
            skip (Variable): Shape(B, C_in, 1, T), skip connection.
            conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
                conditioner, where C_con is conditioner hidden dim which
                equals the num of mel bands. Note that when using residual
                connection, the Conv1D_GU does not change the number of
                channels, so out channels equals input channels.
        Returns:
            x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where
                C_out means the output channels of Conv1D_GU.
            skip (Variable): Shape(B, C_out, 1, T), skip connection.
        """
        residual = x
        x = self.conv(x)

        if conditioner is not None:
            cond_bias = self.fc(conditioner)
            x += cond_bias

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        # Gated Unit.
        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
                                         fluid.layers.tanh(content))

        if skip is None:
            skip = x
        else:
            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))

        if self.residual:
            x = fluid.layers.scale(residual + x, np.sqrt(0.5))

        return x, skip

    def add_input(self, x, skip=None, conditioner=None):
        """
        Inputs:
            x: shape(B, num_filters, 1, time_steps)
            skip: shape(B, num_filters, 1, time_steps), skip connection
            conditioner: shape(B, conditioner_dim, 1, time_steps)
        Outputs:
            x: shape(B, num_filters, 1, time_steps), where time_steps = 1
            skip: skip connection, same shape as x
        """
        residual = x

        # add step input and produce step output
        x = self.conv.add_input(x)

        if conditioner is not None:
            cond_bias = self.fc(conditioner)
            x += cond_bias

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        # Gated Unit.
        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
                                         fluid.layers.tanh(content))

        if skip is None:
            skip = x
        else:
            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))

        if self.residual:
            x = fluid.layers.scale(residual + x, np.sqrt(0.5))

        return x, skip


def Conv2DTranspose(name_scope,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    use_cudnn=True,
                    act=None,
                    dtype="float32"):
    val = 1.0 / (filter_size[0] * filter_size[1])
    weight_init = fluid.initializer.ConstantInitializer(val)
    weight_attr = fluid.ParamAttr(initializer=weight_init)

    layer = weight_norm.Conv2DTranspose(name_scope,
                                        num_filters,
                                        filter_size=filter_size,
                                        padding=padding,
                                        stride=stride,
                                        dilation=dilation,
                                        param_attr=weight_attr,
                                        use_cudnn=use_cudnn,
                                        act=act,
                                        dtype=dtype)
    return layer
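The FC helper keeps pre-activation magnitudes roughly constant: weights are drawn from normal(0, std) with std = sqrt((1 - dropout) / in_features), scaled by sqrt(2) when a ReLU follows (He-style, the relu=True branch above). A quick numerical illustration of why, assuming standard-normal inputs and a plain NumPy matmul in place of the Paddle layer:

    import numpy as np

    rng = np.random.default_rng(0)
    in_features, size, dropout = 128, 128, 0.0

    std = np.sqrt((1.0 - dropout) / in_features) * np.sqrt(2.0)   # ReLU variant
    W = rng.normal(0.0, std, size=(in_features, size))
    x = rng.normal(size=(10_000, in_features))

    y = np.maximum(x @ W, 0.0)             # linear layer followed by ReLU
    print(np.mean(x**2), np.mean(y**2))    # both ≈ 1.0: sqrt(2) offsets the halving by ReLU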