PaddlePaddle / Parakeet
98841ee4
编写于
12月 02, 2019
作者:
K
Kexin Zhao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean code
上级
b15c3134
Showing 9 changed files with 314 additions and 1322 deletions (+314 −1322)
parakeet/data/sampler.py                                              +30   −1
parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml    +32   −0
parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml         +31   −0
parakeet/models/wavenet/data.py                                       +3    −34
parakeet/models/wavenet/ops.py                                        +0    −249
parakeet/models/wavenet/wavenet.py                                    +8    −20
parakeet/models/wavenet/wavenet_modules.py                            +56   −98
parakeet/models/wavenet/weight_norm.py                                +0    −920
parakeet/modules/modules.py                                           +154  −0
parakeet/data/sampler.py (view file @ 98841ee4)

@@ -163,6 +163,35 @@ class WeightedRandomSampler(Sampler):
         return self.num_samples
 
 
+class DistributedSampler(Sampler):
+    def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
+        self.dataset_size = dataset_size
+        self.num_trainers = num_trainers
+        self.rank = rank
+        self.num_samples = int(np.ceil(dataset_size / num_trainers))
+        self.total_size = self.num_samples * num_trainers
+        assert self.total_size >= self.dataset_size
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        indices = list(range(self.dataset_size))
+        if self.shuffle:
+            random.shuffle(indices)
+
+        # Append extra samples to make it evenly distributed on all trainers.
+        indices += indices[:(self.total_size - self.dataset_size)]
+        assert len(indices) == self.total_size
+
+        # Subset samples for each trainer.
+        indices = indices[self.rank:self.total_size:self.num_trainers]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+
 class BatchSampler(Sampler):
     r"""Wraps another sampler to yield a mini-batch of indices.
 
     Args:
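The added sampler pads the shuffled index list so it divides evenly, then hands each trainer a strided slice. The following is a minimal standalone sketch (not from the repository) of that partitioning logic, using made-up sizes:

# Illustrative only; dataset_size=10 and num_trainers=4 are made-up numbers.
import numpy as np

dataset_size, num_trainers = 10, 4
num_samples = int(np.ceil(dataset_size / num_trainers))   # 3 indices per trainer
total_size = num_samples * num_trainers                    # 12 after padding

indices = list(range(dataset_size))
indices += indices[:total_size - dataset_size]             # repeat the head to pad

for rank in range(num_trainers):
    shard = indices[rank:total_size:num_trainers]           # strided, disjoint subset
    print(rank, shard)                                      # every trainer gets num_samples indices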
parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml (new file, 0 → 100644; view file @ 98841ee4)

valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80

seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000

layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: mix-gaussian-pdf
num_mixtures: 10
log_scale_min: -9.0

conditioner:
  filter_sizes: [[32, 3], [32, 3]]
  upsample_factors: [16, 16]

learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
  every: 200000
  rate: 0.5
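In this config the two conditioner deconvolutions upsample by 16 each, so one mel frame expands to 16 * 16 = 256 audio samples, matching fft_window_shift. A small sketch (assuming PyYAML is available; it is not referenced by the commit itself) that loads the file and checks this relation:

# Illustrative sketch, not part of the commit.
import numpy as np
import yaml

with open("wavenet_ljspeech_mix_gaussian.yaml") as f:
    config = yaml.safe_load(f)

total_upsample = int(np.prod(config["conditioner"]["upsample_factors"]))
assert total_upsample == config["fft_window_shift"]   # 16 * 16 == 256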
parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml (new file, 0 → 100644; view file @ 98841ee4)

valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80

seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000

layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: softmax
num_channels: 2048

conditioner:
  filter_sizes: [[32, 3], [32, 3]]
  upsample_factors: [16, 16]

learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
  every: 200000
  rate: 0.5
parakeet/models/wavenet/data.py (view file @ 98841ee4)

-import math
 import os
-import random
 
 import librosa

@@ -9,7 +7,7 @@ from paddle import fluid
 import utils
 from parakeet.datasets import ljspeech
 from parakeet.data import dataset
-from parakeet.data.sampler import Sampler, BatchSampler, SequentialSampler
+from parakeet.data.sampler import DistributedSampler, BatchSampler
 from parakeet.data.datacargo import DataCargo

@@ -20,7 +18,7 @@ class Dataset(ljspeech.LJSpeech):
         self.fft_window_shift = config.fft_window_shift
         # Calculate context frames.
         frames_per_second = config.sample_rate // self.fft_window_shift
-        train_clip_frames = int(math.ceil(
+        train_clip_frames = int(np.ceil(
             config.train_clip_second * frames_per_second))
         context_frames = config.context_size // self.fft_window_shift
         self.num_frames = train_clip_frames + context_frames

@@ -39,7 +37,7 @@ class Dataset(ljspeech.LJSpeech):
         assert loaded_sr == sr
 
         # Pad audio to the right size.
-        frames = math.ceil(float(audio.size) / fft_window_shift)
+        frames = int(np.ceil(float(audio.size) / fft_window_shift))
         fft_padding = (fft_size - fft_window_shift) // 2
         desired_length = frames * fft_window_shift + fft_padding * 2
         pad_amount = (desired_length - audio.size) // 2

@@ -125,35 +123,6 @@ class Subset(dataset.Dataset):
         return len(self.indices)
 
 
-class DistributedSampler(Sampler):
-    def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
-        self.dataset_size = dataset_size
-        self.num_trainers = num_trainers
-        self.rank = rank
-        self.num_samples = int(math.ceil(dataset_size / num_trainers))
-        self.total_size = self.num_samples * num_trainers
-        assert self.total_size >= self.dataset_size
-        self.shuffle = shuffle
-
-    def __iter__(self):
-        indices = list(range(self.dataset_size))
-        if self.shuffle:
-            random.shuffle(indices)
-
-        # Append extra samples to make it evenly distributed on all trainers.
-        indices += indices[:(self.total_size - self.dataset_size)]
-        assert len(indices) == self.total_size
-
-        # Subset samples for each trainer.
-        indices = indices[self.rank:self.total_size:self.num_trainers]
-        assert len(indices) == self.num_samples
-
-        return iter(indices)
-
-    def __len__(self):
-        return self.num_samples
-
-
 class LJSpeech:
     def __init__(self, config, nranks, rank):
         place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
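The padding step above rounds the clip up to a whole number of hops and adds symmetric FFT padding. A standalone numeric sketch (not from the commit) with a made-up clip length of 1000 samples and the hop/FFT sizes from the LJSpeech configs (256 / 2048):

# Illustrative arithmetic only.
import numpy as np

audio_size, fft_window_shift, fft_size = 1000, 256, 2048
frames = int(np.ceil(float(audio_size) / fft_window_shift))        # 4 frames
fft_padding = (fft_size - fft_window_shift) // 2                   # 896
desired_length = frames * fft_window_shift + fft_padding * 2       # 2816
pad_amount = (desired_length - audio_size) // 2                    # 908
print(frames, fft_padding, desired_length, pad_amount)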
parakeet/models/wavenet/ops.py (deleted in this commit, 100644 → 0; view file @ b15c3134)

import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
import numpy as np

import weight_norm


def Embedding(name_scope,
              num_embeddings,
              embed_dim,
              padding_idx=None,
              std=0.1,
              dtype="float32"):
    # param attrs
    weight_attr = fluid.ParamAttr(
        initializer=fluid.initializer.Normal(scale=std))
    layer = dg.Embedding(
        name_scope, (num_embeddings, embed_dim),
        padding_idx=padding_idx,
        param_attr=weight_attr,
        dtype=dtype)
    return layer


def FC(name_scope,
       in_features,
       size,
       num_flatten_dims=1,
       relu=False,
       dropout=0.0,
       act=None,
       dtype="float32"):
    """
    A special Linear Layer, when it is used with dropout, the weight is
    initialized as normal(0, std=np.sqrt((1-dropout) / in_features))
    """
    # stds
    if isinstance(in_features, int):
        in_features = [in_features]

    stds = [np.sqrt((1.0 - dropout) / in_feature) for in_feature in in_features]
    if relu:
        stds = [std * np.sqrt(2.0) for std in stds]

    weight_inits = [
        fluid.initializer.NormalInitializer(scale=std) for std in stds
    ]
    bias_init = fluid.initializer.ConstantInitializer(0.0)

    # param attrs
    weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits]
    bias_attr = fluid.ParamAttr(initializer=bias_init)

    layer = weight_norm.FC(name_scope,
                           size,
                           num_flatten_dims=num_flatten_dims,
                           param_attr=weight_attrs,
                           bias_attr=bias_attr,
                           act=act,
                           dtype=dtype)
    return layer


def Conv1D(name_scope,
           in_channels,
           num_filters,
           filter_size=2,
           dilation=1,
           groups=None,
           causal=False,
           std_mul=1.0,
           dropout=0.0,
           use_cudnn=True,
           act=None,
           dtype="float32"):
    """
    A special Conv1D Layer, when it is used with dropout, the weight is
    initialized as
    normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_channels)))
    """
    # std
    std = np.sqrt((std_mul * (1.0 - dropout)) / (filter_size * in_channels))
    weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std)
    bias_init = fluid.initializer.ConstantInitializer(0.0)

    # param attrs
    weight_attr = fluid.ParamAttr(initializer=weight_init)
    bias_attr = fluid.ParamAttr(initializer=bias_init)

    layer = weight_norm.Conv1D(name_scope,
                               num_filters,
                               filter_size,
                               dilation,
                               groups=groups,
                               causal=causal,
                               param_attr=weight_attr,
                               bias_attr=bias_attr,
                               use_cudnn=use_cudnn,
                               act=act,
                               dtype=dtype)
    return layer


class Conv1D_GU(dg.Layer):
    def __init__(self,
                 name_scope,
                 conditioner_dim,
                 in_channels,
                 num_filters,
                 filter_size,
                 dilation,
                 causal=False,
                 residual=True,
                 dtype="float32"):
        super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)

        self.conditioner_dim = conditioner_dim
        self.in_channels = in_channels
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.dilation = dilation
        self.causal = causal
        self.residual = residual

        if residual:
            assert (in_channels == num_filters), "this block uses residual connection" \
                "the input_channels should equals num_filters"

        self.conv = Conv1D(self.full_name(),
                           in_channels,
                           2 * num_filters,
                           filter_size,
                           dilation,
                           causal=causal,
                           dtype=dtype)

        self.fc = Conv1D(self.full_name(),
                         conditioner_dim,
                         2 * num_filters,
                         filter_size=1,
                         dilation=1,
                         causal=False,
                         dtype=dtype)

    def forward(self, x, skip=None, conditioner=None):
        """
        Args:
            x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU
                layer, where B means batch_size, C_in means the input channels
                T means input time steps.
            conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
                conditioner, where C_con is conditioner hidden dim which
                equals the num of mel bands. Note that when using residual
                connection, the Conv1DGLU does not change the number of
                channels, so out channels equals input channels.
        Returns:
            x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where
                C_out means the output channels of Conv1DGLU.
        """
        residual = x
        x = self.conv(x)

        if conditioner is not None:
            cond_bias = self.fc(conditioner)
            x += cond_bias

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        # Gated Unit.
        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
                                         fluid.layers.tanh(content))

        if skip is None:
            skip = x
        else:
            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))

        if self.residual:
            x = fluid.layers.scale(residual + x, np.sqrt(0.5))

        return x, skip

    def add_input(self, x, skip=None, conditioner=None):
        """
        Inputs:
            x: shape(B, num_filters, 1, time_steps)
            conditioner: shape(B, conditioner_dim, 1, time_steps)
        Outputs:
            out: shape(B, num_filters, 1, time_steps), where time_steps = 1
        """
        residual = x

        # add step input and produce step output
        x = self.conv.add_input(x)

        if conditioner is not None:
            cond_bias = self.fc(conditioner)
            x += cond_bias

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        # Gated Unit.
        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
                                         fluid.layers.tanh(content))

        if skip is None:
            skip = x
        else:
            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))

        if self.residual:
            x = fluid.layers.scale(residual + x, np.sqrt(0.5))

        return x, skip


def Conv2DTranspose(name_scope,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    use_cudnn=True,
                    act=None,
                    dtype="float32"):
    val = 1.0 / (filter_size[0] * filter_size[1])
    weight_init = fluid.initializer.ConstantInitializer(val)
    weight_attr = fluid.ParamAttr(initializer=weight_init)

    layer = weight_norm.Conv2DTranspose(name_scope,
                                        num_filters,
                                        filter_size=filter_size,
                                        padding=padding,
                                        stride=stride,
                                        dilation=dilation,
                                        param_attr=weight_attr,
                                        use_cudnn=use_cudnn,
                                        act=act,
                                        dtype=dtype)
    return layer
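The gated unit in the deleted Conv1D_GU splits the conv output in half along the channel axis, uses one half to gate the other, and rescales residual/skip sums by sqrt(0.5) to keep variance roughly constant. A minimal numpy sketch (illustrative only, not from the repository):

# Illustrative sketch of the gated-unit math; shapes are made up.
import numpy as np

def gated_unit(conv_out, residual):
    content, gate = np.split(conv_out, 2, axis=1)            # halve the channels
    x = (1.0 / (1.0 + np.exp(-gate))) * np.tanh(content)     # sigmoid(gate) * tanh(content)
    return (residual + x) * np.sqrt(0.5)                      # variance-preserving residual sum

out = gated_unit(np.random.randn(2, 8, 1, 5), np.random.randn(2, 4, 1, 5))
print(out.shape)   # (2, 4, 1, 5)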
parakeet/models/wavenet/wavenet.py (view file @ 98841ee4)

@@ -4,12 +4,12 @@ import time
 import librosa
 import numpy as np
-from paddle import fluid
 import paddle.fluid.dygraph as dg
+from paddle import fluid
 
 import utils
 from data import LJSpeech
-from wavenet_modules import WaveNetModule, debug
+from wavenet_modules import WaveNetModule
 
 
 class WaveNet():

@@ -33,18 +33,6 @@ class WaveNet():
         self.trainloader = dataset.trainloader
         self.validloader = dataset.validloader
 
-        # if self.rank == 0:
-        #     for i, (audios, mels, ids) in enumerate(self.validloader()):
-        #         print("audios {}, mels {}, ids {}".format(audios.dtype, mels.dtype, ids.dtype))
-        #         print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
-        #             i, self.rank, audios.shape, mels.shape, ids.shape,
-        #             ids.numpy()))
-        #
-        #     for i, (audios, mels, ids) in enumerate(self.trainloader):
-        #         print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
-        #             i, self.rank, audios.shape, mels.shape, ids.shape,
-        #             ids.numpy()))
-
         wavenet = WaveNetModule("wavenet", config, self.rank)
 
         # Dry run once to create and initalize all necessary parameters.

@@ -139,8 +127,8 @@ class WaveNet():
         self.wavenet.eval()
 
         total_loss = []
-        start_time = time.time()
         sample_audios = []
+        start_time = time.time()
         for audios, mels, audio_starts in self.validloader():
             loss, sample_audio = self.wavenet(audios, mels, audio_starts, True)
             total_loss.append(float(loss.numpy()))

@@ -160,11 +148,6 @@ class WaveNet():
             tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(),
                          iteration, sample_rate=self.config.sample_rate)
 
-    def save(self, iteration):
-        utils.save_latest_parameters(self.checkpoint_dir, iteration,
-                                     self.wavenet, self.optimizer)
-        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
-
     @dg.no_grad
     def infer(self, iteration):
         self.wavenet.eval()

@@ -186,3 +169,8 @@ class WaveNet():
                 syn_audio.shape, syn_time))
             librosa.output.write_wav(filename, syn_audio,
                                      sr=config.sample_rate)
+
+    def save(self, iteration):
+        utils.save_latest_parameters(self.checkpoint_dir, iteration,
+                                     self.wavenet, self.optimizer)
+        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
parakeet/models/wavenet/wavenet_modules.py (view file @ 98841ee4)

 import itertools
 import math
 
 import numpy as np
-from paddle import fluid
 import paddle.fluid.dygraph as dg
-import ops
-import weight_norm
+from paddle import fluid
+
+from parakeet.modules import conv, modules
 
 
 def get_padding(filter_size, stride, padding_type='same'):

@@ -16,22 +14,6 @@ def get_padding(filter_size, stride, padding_type='same'):
     return padding
 
 
-def debug(x, var_name, rank, verbose=False):
-    if not verbose and rank != 0:
-        return
-    dim = len(x.shape)
-    if not isinstance(x, np.ndarray):
-        x = x.numpy()
-    if dim == 1:
-        print("Rank {}".format(rank), var_name,
-              "shape {}, value {}".format(x.shape, x))
-    elif dim == 2:
-        print("Rank {}".format(rank), var_name,
-              "shape {}, value {}".format(x.shape, x[:, :5]))
-    elif dim == 3:
-        print("Rank {}".format(rank), var_name,
-              "shape {}, value {}".format(x.shape, x[:, :5, 0]))
-    else:
-        print("Rank", rank, var_name, "shape", x.shape)
-
-
 def extract_slices(x, audio_starts, audio_length, rank):
     slices = []
     for i in range(x.shape[0]):

@@ -58,7 +40,7 @@ class Conditioner(dg.Layer):
             stride = (up_scale, 1)
             padding = get_padding(filter_sizes[i], stride)
             self.deconvs.append(
-                ops.Conv2DTranspose(
+                modules.Conv2DTranspose(
                     self.full_name(),
                     num_filters=1,
                     filter_size=filter_sizes[i],

@@ -94,12 +76,13 @@ class WaveNetModule(dg.Layer):
         print("context_size", self.context_size)
 
         if config.loss_type == "softmax":
-            self.embedding_fc = ops.Embedding(
+            self.embedding_fc = modules.Embedding(
                 self.full_name(),
                 num_embeddings=config.num_channels,
-                embed_dim=config.residual_channels)
+                embed_dim=config.residual_channels,
+                std=0.1)
         elif config.loss_type == "mix-gaussian-pdf":
-            self.embedding_fc = ops.FC(
+            self.embedding_fc = modules.FC(
                 self.full_name(),
                 in_features=1,
                 size=config.residual_channels,

@@ -112,7 +95,7 @@ class WaveNetModule(dg.Layer):
         self.dilated_causal_convs = []
         for dilation in self.dilations:
             self.dilated_causal_convs.append(
-                ops.Conv1D_GU(
+                modules.Conv1D_GU(
                     self.full_name(),
                     conditioner_dim=config.mel_bands,
                     in_channels=config.residual_channels,

@@ -126,7 +109,7 @@ class WaveNetModule(dg.Layer):
         for i, layer in enumerate(self.dilated_causal_convs):
             self.add_sublayer("dilated_causal_conv_{}".format(i), layer)
 
-        self.fc1 = ops.FC(
+        self.fc1 = modules.FC(
             self.full_name(),
             in_features=config.residual_channels,
             size=config.skip_channels,

@@ -134,7 +117,7 @@ class WaveNetModule(dg.Layer):
             relu=True,
             act="relu")
 
-        self.fc2 = ops.FC(
+        self.fc2 = modules.FC(
             self.full_name(),
             in_features=config.skip_channels,
             size=config.skip_channels,

@@ -143,14 +126,14 @@ class WaveNetModule(dg.Layer):
             act="relu")
 
         if config.loss_type == "softmax":
-            self.fc3 = ops.FC(
+            self.fc3 = modules.FC(
                 self.full_name(),
                 in_features=config.skip_channels,
                 size=config.num_channels,
                 num_flatten_dims=2,
                 relu=False)
         elif config.loss_type == "mix-gaussian-pdf":
-            self.fc3 = ops.FC(
+            self.fc3 = modules.FC(
                 self.full_name(),
                 in_features=config.skip_channels,
                 size=3 * config.num_mixtures,

@@ -175,8 +158,8 @@ class WaveNetModule(dg.Layer):
         return samples
 
     def sample_mix_gaussian(self, mix_parameters):
-        # mix_parameters reshape from [bs, 13799, 3 * num_mixtures]
-        # to [bs * 13799, 3 * num_mixtures].
+        # mix_parameters reshape from [bs, len, 3 * num_mixtures]
+        # to [bs * len, 3 * num_mixtures].
         batch, length, hidden = mix_parameters.shape
         mix_param_2d = fluid.layers.reshape(mix_parameters,
                                             [batch * length, hidden])

@@ -197,7 +180,7 @@ class WaveNetModule(dg.Layer):
         mu_comp = fluid.layers.gather_nd(mu, comp_samples)
         s_comp = fluid.layers.gather_nd(s, comp_samples)
 
-        # N(0, 1) Normal Sample.
+        # N(0, 1) normal sample.
         u = fluid.layers.gaussian_random(shape=[batch * length])
         samples = mu_comp + u * s_comp
         samples = fluid.layers.clip(samples, min=-1.0, max=1.0)

@@ -205,8 +188,6 @@ class WaveNetModule(dg.Layer):
         return samples
 
     def softmax_loss(self, targets, mix_parameters):
-        # targets: [bs, 13799] -> [bs, 11752]
-        # mix_params: [bs, 13799, 3] -> [bs, 11752, 3]
         targets = targets[:, self.context_size:]
         mix_parameters = mix_parameters[:, self.context_size:, :]

@@ -216,22 +197,22 @@ class WaveNetModule(dg.Layer):
         quantized = fluid.layers.cast(
             (targets + 1.0) / 2.0 * num_channels, dtype="int64")
 
-        # per_sample_loss shape: [bs, 17952, 1]
+        # per_sample_loss shape: [bs, len, 1]
         per_sample_loss = fluid.layers.softmax_with_cross_entropy(
             logits=mix_parameters, label=fluid.layers.unsqueeze(quantized, 2))
         loss = fluid.layers.reduce_mean(per_sample_loss)
-        #debug(loss, "softmax loss", self.rank)
 
         return loss
 
     def mixture_density_loss(self, targets, mix_parameters, log_scale_min):
-        # targets: [bs, 13799] -> [bs, 11752]
-        # mix_params: [bs, 13799, 3] -> [bs, 11752, 3]
+        # targets: [bs, len]
+        # mix_params: [bs, len, 3 * num_mixture]
         targets = targets[:, self.context_size:]
         mix_parameters = mix_parameters[:, self.context_size:, :]
 
-        # log_s: [bs, 11752, num_mixture]
+        # log_s: [bs, len, num_mixture]
         logits_pi, mu, log_s = fluid.layers.split(
             mix_parameters, num_or_sections=3, dim=-1)
 
         pi = fluid.layers.softmax(logits_pi, axis=-1)
         log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0)
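The cast above is the quantization used by softmax_loss: waveform samples in [-1, 1) are mapped to integer class labels in [0, num_channels). A tiny numpy sketch (illustrative, not from the commit) with num_channels = 2048 as in wavenet_ljspeech_softmax.yaml:

# Illustrative only; the sample values are made up.
import numpy as np

num_channels = 2048
targets = np.array([-1.0, -0.5, 0.0, 0.5, 0.99999])
quantized = ((targets + 1.0) / 2.0 * num_channels).astype(np.int64)
print(quantized)   # [   0  512 1024 1536 2047]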
@@ -242,10 +223,9 @@ class WaveNetModule(dg.Layer):
         targets = fluid.layers.expand(targets,
                                       [1, 1, self.config.num_mixtures])
         x_std = inv_s * (targets - mu)
         exponent = fluid.layers.exp(-0.5 * x_std * x_std)
-        # pdf_x: [bs, 11752, 1]
         pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_s * exponent
         pdf_x = pi * pdf_x
-        # pdf_x: [bs, 11752]
+        # pdf_x: [bs, len]
         pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1)
         per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9)
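For reference, mixture_density_loss computes the negative log of a mixture-of-Gaussians density per sample. A numpy sketch of the same arithmetic (illustrative only, not from the commit; inv_s is assumed to be exp(-log_s), since its definition falls outside the shown hunk):

# Illustrative sketch of the per-sample mixture-of-Gaussians NLL.
import numpy as np

def mog_nll(targets, logits_pi, mu, log_s, log_scale_min=-9.0):
    # targets: [bs, len]; parameter tensors: [bs, len, num_mixtures]
    pi = np.exp(logits_pi) / np.exp(logits_pi).sum(-1, keepdims=True)   # softmax over components
    log_s = np.clip(log_s, log_scale_min, 100.0)
    inv_s = np.exp(-log_s)                                              # assumed inverse scale
    x_std = inv_s * (targets[..., None] - mu)
    pdf_x = (pi * inv_s * np.exp(-0.5 * x_std * x_std)
             / np.sqrt(2.0 * np.pi)).sum(-1)                            # mixture density, [bs, len]
    return np.mean(-np.log(pdf_x + 1e-9))

bs, length, k = 2, 5, 10
print(mog_nll(np.random.uniform(-1, 1, (bs, length)),
              np.random.randn(bs, length, k),
              np.random.randn(bs, length, k),
              np.random.randn(bs, length, k)))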
@@ -254,8 +234,6 @@ class WaveNetModule(dg.Layer):
         return loss
 
     def forward(self, audios, mels, audio_starts, sample=False):
-        # audios: [bs, 13800], mels: [bs, full_frame_length, 80]
-        # audio_starts: [bs]
         # Build conditioner based on mels.
         full_conditioner = self.conditioner(mels)

@@ -264,15 +242,14 @@ class WaveNetModule(dg.Layer):
         conditioner = extract_slices(full_conditioner,
                                      audio_starts, audio_length, self.rank)
 
-        # input_audio, target_audio: [bs, 13799]
+        # input_audio, target_audio: [bs, len]
         input_audios = audios[:, :-1]
         target_audios = audios[:, 1:]
-        # conditioner: [bs, 13799, 80]
+        # conditioner: [bs, len, mel_bands]
         conditioner = conditioner[:, 1:, :]
 
         loss_type = self.config.loss_type
 
-        # layer_input: [bs, 13799, 128]
         if loss_type == "softmax":
             input_audios = fluid.layers.clip(
                 input_audios, min=-1.0, max=0.99999)

@@ -280,31 +257,31 @@ class WaveNetModule(dg.Layer):
             quantized = fluid.layers.cast(
                 (input_audios + 1.0) / 2.0 * self.config.num_channels,
                 dtype="int64")
             layer_input = self.embedding_fc(
                 fluid.layers.unsqueeze(quantized, 2))
         elif loss_type == "mix-gaussian-pdf":
             layer_input = self.embedding_fc(
                 fluid.layers.unsqueeze(input_audios, 2))
         else:
             raise ValueError("loss_type {} is unsupported!".format(loss_type))
 
-        # layer_input: [bs, res_channel, 1, 13799]
+        # layer_input: [bs, res_channel, 1, len]
         layer_input = fluid.layers.unsqueeze(
             fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2)
-        # conditioner: [bs, mel_bands, 1, 13799]
+        # conditioner: [bs, mel_bands, 1, len]
         conditioner = fluid.layers.unsqueeze(
             fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2)
 
-        # layer_input: [bs, res_channel, 1, 13799]
-        # skip: [bs, res_channel, 1, 13799]
         skip = None
         for i, layer in enumerate(self.dilated_causal_convs):
+            # layer_input: [bs, res_channel, 1, len]
+            # skip: [bs, res_channel, 1, len]
             layer_input, skip = layer(layer_input, skip, conditioner)
-            #debug(layer_input, "layer_input_" + str(i), self.rank)
-            #debug(skip, "skip_" + str(i), self.rank)
 
-        # Reshape skip to [bs, 13799, res_channel]
+        # Reshape skip to [bs, len, res_channel]
         skip = fluid.layers.transpose(
             fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
-        #debug(skip, "skip", self.rank)
-
-        # mix_param: [bs, 13799, 3 * num_mixtures]
         mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
 
         # Sample teacher-forced audio.

@@ -317,12 +294,7 @@ class WaveNetModule(dg.Layer):
         else:
             raise ValueError("loss_type {} is unsupported!".format(loss_type))
-        #debug(sample_audios, "sample_audios", self.rank)
 
-        # Calculate mix-gaussian density loss.
-        # padding is all zero.
-        # target_audio: [bs, 13799].
-        # mix_params: [bs, 13799, 3].
         if loss_type == "softmax":
             loss = self.softmax_loss(target_audios, mix_parameters)
         elif loss_type == "mix-gaussian-pdf":

@@ -332,27 +304,16 @@ class WaveNetModule(dg.Layer):
             raise ValueError("loss_type {} is unsupported!".format(loss_type))
-        #print("Rank {}, loss {}".format(self.rank, loss.numpy()))
 
         return loss, sample_audios
 
     def synthesize(self, mels):
         self.start_new_sequence()
-        print("input mels shape", mels.shape)
-        # mels: [bs=1, n_frames, 80]
-        # conditioner: [1, n_frames * samples_per_frame, 80]
-        # Should I move forward by one sample? No difference
-        # Append context frame to mels
         bs, n_frames, mel_bands = mels.shape
-        #num_pad_frames = int(np.ceil(self.context_size / self.config.fft_window_shift))
-        #silence = fluid.layers.zeros(shape=[bs, num_pad_frames, mel_bands], dtype="float32")
-        #inf_mels = fluid.layers.concat([silence, mels], axis=1)
-        #print("padded mels shape", inf_mels.shape)
-        #conditioner = self.conditioner(inf_mels)[:, self.context_size:, :]
         conditioner = self.conditioner(mels)
         time_steps = conditioner.shape[1]
-        print("Total steps", time_steps)
+
+        print("input mels shape", mels.shape)
+        print("Total synthesis steps", time_steps)
 
         loss_type = self.config.loss_type
         audio_samples = []

@@ -361,8 +322,8 @@ class WaveNetModule(dg.Layer):
             if i % 100 == 0:
                 print("Step", i)
 
-            # convert from real value sample to audio embedding.
-            # [bs, 1, 128]
+            # Convert from real value sample to audio embedding.
+            # audio_input: [bs, 1, channel]
             if loss_type == "softmax":
                 current_sample = fluid.layers.clip(
                     current_sample, min=-1.0, max=0.99999)

@@ -377,21 +338,23 @@ class WaveNetModule(dg.Layer):
                 raise ValueError("loss_type {} is unsupported!".format(loss_type))
 
-            # [bs, 128, 1, 1]
+            # [bs, channel, 1, 1]
             audio_input = fluid.layers.unsqueeze(
                 fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2)
-            # [bs, 80]
+            # [bs, mel_bands]
             cond_input = conditioner[:, i, :]
-            # [bs, 80, 1, 1]
+            # [bs, mel_bands, 1, 1]
             cond_input = fluid.layers.reshape(
                 cond_input, cond_input.shape + [1, 1])
 
             skip = None
             for layer in self.dilated_causal_convs:
                 audio_input, skip = layer.add_input(
                     audio_input, skip, cond_input)
 
-            # [bs, 1, 128]
+            # [bs, 1, channel]
             skip = fluid.layers.transpose(
                 fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
-            # [bs, 1, 3]
             mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
             if loss_type == "softmax":
                 sample = self.sample_softmax(mix_parameters)

@@ -407,17 +370,12 @@ class WaveNetModule(dg.Layer):
             current_sample = fluid.layers.reshape(
                 current_sample, current_sample.shape + [1, 1])
 
-        # syn_audio: (num_samples,)
+        # syn_audio: [num_samples]
         syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy()
 
         return syn_audio
 
     def start_new_sequence(self):
         for layer in self.sublayers():
-            if isinstance(layer, weight_norm.Conv1D):
+            if isinstance(layer, conv.Conv1D):
                 layer.start_new_sequence()
-
-    def save(self, iteration):
-        utils.save_latest_parameters(self.checkpoint_dir, iteration,
-                                     self.wavenet, self.optimizer)
-        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
parakeet/models/wavenet/weight_norm.py
已删除
100644 → 0
浏览文件 @
b15c3134
import
math
from
copy
import
deepcopy
import
numpy
as
np
import
paddle.fluid.dygraph
as
dg
from
paddle
import
fluid
from
paddle.fluid
import
core
from
paddle.fluid.framework
import
Variable
from
paddle.fluid.initializer
import
Normal
,
Constant
,
NumpyArrayInitializer
from
paddle.fluid.layers
import
utils
from
six.moves
import
reduce
def
_norm
(
p
,
dim
):
"""Computes the norm over all dimensions except dim.
It differs from pytorch implementation that it does not keep dim.
This difference is related with the broadcast mechanism in paddle.
Read elementeise_mul for more.
"""
if
dim
is
None
:
return
np
.
linalg
.
norm
(
p
,
ord
=
2
,
axis
=
None
)
elif
dim
==
0
:
p
=
np
.
reshape
(
p
,
newshape
=
(
p
.
shape
[
0
],
-
1
))
return
np
.
linalg
.
norm
(
p
,
ord
=
2
,
axis
=
1
)
elif
dim
==
p
.
ndim
-
1
:
p
=
np
.
reshape
(
p
,
newshape
=
(
-
1
,
p
.
shape
[
-
1
]))
return
np
.
linalg
.
norm
(
p
,
ord
=
2
,
axis
=
0
)
else
:
perm
=
list
(
range
(
p
.
ndim
))
perm
[
0
]
=
dim
perm
[
dim
]
=
0
return
_norm
(
np
.
transpose
(
p
,
axes
=
perm
))
class
Conv1D
(
dg
.
Layer
):
"""
A convolution 1D block implemented with Conv2D. Form simplicity and
ensuring the output has the same length as the input, it does not allow
stride > 1.
"""
def
__init__
(
self
,
name_scope
,
num_filters
,
filter_size
=
3
,
dilation
=
1
,
groups
=
None
,
causal
=
False
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
True
,
act
=
None
,
dtype
=
"float32"
):
super
(
Conv1D
,
self
).
__init__
(
name_scope
,
dtype
=
dtype
)
if
causal
:
padding
=
dilation
*
(
filter_size
-
1
)
else
:
padding
=
(
dilation
*
(
filter_size
-
1
))
//
2
self
.
num_filters
=
num_filters
self
.
filter_size
=
filter_size
self
.
dilation
=
dilation
self
.
causal
=
causal
self
.
padding
=
padding
self
.
act
=
act
self
.
conv
=
Conv2D
(
self
.
full_name
(),
num_filters
=
num_filters
,
filter_size
=
(
1
,
filter_size
),
stride
=
(
1
,
1
),
dilation
=
(
1
,
dilation
),
padding
=
(
0
,
padding
),
groups
=
groups
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
use_cudnn
=
use_cudnn
,
act
=
act
,
dtype
=
dtype
)
def
forward
(
self
,
x
):
"""
Args:
x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
input channels.
Returns:
x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means
output channels (num_filters).
"""
x
=
self
.
conv
(
x
)
if
self
.
filter_size
>
1
:
if
self
.
causal
:
x
=
fluid
.
layers
.
slice
(
x
,
axes
=
[
3
],
starts
=
[
0
],
ends
=
[
-
self
.
padding
])
elif
self
.
filter_size
%
2
==
0
:
x
=
fluid
.
layers
.
slice
(
x
,
axes
=
[
3
],
starts
=
[
0
],
ends
=
[
-
1
])
return
x
def
start_new_sequence
(
self
):
self
.
temp_weight
=
None
self
.
input_buffer
=
None
def
add_input
(
self
,
x
):
"""
Adding input for a time step and compute an output for a time step.
Args:
x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
input channels, and T = 1.
Returns:
out (Variable): Shape(B, C_out, 1, T), the outputs, where C_out
means output channels (num_filters), and T = 1.
"""
if
self
.
temp_weight
is
None
:
self
.
temp_weight
=
self
.
_reshaped_weight
()
window_size
=
1
+
(
self
.
filter_size
-
1
)
*
self
.
dilation
batch_size
=
x
.
shape
[
0
]
in_channels
=
x
.
shape
[
1
]
if
self
.
filter_size
>
1
:
if
self
.
input_buffer
is
None
:
self
.
input_buffer
=
fluid
.
layers
.
fill_constant
(
[
batch_size
,
in_channels
,
1
,
window_size
-
1
],
dtype
=
x
.
dtype
,
value
=
0.0
)
else
:
self
.
input_buffer
=
self
.
input_buffer
[:,
:,
:,
1
:]
self
.
input_buffer
=
fluid
.
layers
.
concat
(
[
self
.
input_buffer
,
x
],
axis
=
3
)
x
=
self
.
input_buffer
if
self
.
dilation
>
1
:
if
not
hasattr
(
self
,
"indices"
):
self
.
indices
=
dg
.
to_variable
(
np
.
arange
(
0
,
window_size
,
self
.
dilation
))
tmp
=
fluid
.
layers
.
transpose
(
self
.
input_buffer
,
perm
=
[
3
,
1
,
2
,
0
])
tmp
=
fluid
.
layers
.
gather
(
tmp
,
index
=
self
.
indices
)
tmp
=
fluid
.
layers
.
transpose
(
tmp
,
perm
=
[
3
,
1
,
2
,
0
])
x
=
tmp
inputs
=
fluid
.
layers
.
reshape
(
x
,
shape
=
[
batch_size
,
in_channels
*
1
*
self
.
filter_size
])
out
=
fluid
.
layers
.
matmul
(
inputs
,
self
.
temp_weight
,
transpose_y
=
True
)
out
=
fluid
.
layers
.
elementwise_add
(
out
,
self
.
conv
.
_bias_param
,
axis
=-
1
)
out
=
fluid
.
layers
.
reshape
(
out
,
out
.
shape
+
[
1
,
1
])
out
=
self
.
_helper
.
append_activation
(
out
,
act
=
self
.
act
)
return
out
def
_reshaped_weight
(
self
):
"""
Get the linearized weight of convolution filter, cause it is by nature
a matmul weight. And because the model uses weight norm, compute the
weight by weight_v * weight_g to make it faster.
Returns:
weight_matrix (Variable): Shape(C_out, C_in * 1 * kernel_size)
"""
shape
=
self
.
conv
.
_filter_param_v
.
shape
matrix_shape
=
[
shape
[
0
],
np
.
prod
(
shape
[
1
:])]
weight_matrix
=
fluid
.
layers
.
reshape
(
self
.
conv
.
_filter_param_v
,
shape
=
matrix_shape
)
weight_matrix
=
fluid
.
layers
.
elementwise_mul
(
fluid
.
layers
.
l2_normalize
(
weight_matrix
,
axis
=
1
),
self
.
conv
.
_filter_param_g
,
axis
=
0
)
return
weight_matrix
class
FC
(
dg
.
Layer
):
"""
**Fully Connected Layer**
This function creates a fully connected layer in the network. It can take
one or multiple tensors as its inputs(input can be a list of Variable, see
Args in detail). It creates a pair of variables called (magnitude(g),
direction(V)) for each input tensor. Elementwise_mul(V, g) represents a fully connected
weight matrix from each input unit to each output unit.
The fully connected layer multiplies each input tensor
with its corresponding weight to produce an output Tensor with shape [M, `size`],
where M is batch size. If multiple input tensors are given, the results of
multiple output tensors with shape [M, `size`] will be summed up. If bias_attr
is not None, a bias variable will be created and added to the output.
Finally, if activation is not None, it will be applied to the output as well.
When the input is single tensor:
.. math::
Out = Act({X(normalize(V)g) + b})
When the input are multiple tensors:
.. math::
Out = Act({\sum_{i=0}^{N-1}X_i(V_ig_i) + b})
In the above equation:
* :math:`N`: Number of the input. N equals to len(input) if input is list of Variable.
* :math:`X_i`: The i-th input tensor.
* :math:`V_i`: The i-th direction matrix corresponding i-th input tensor.
* :math:`g_i`: The i-th magnitude vector corresponding i-th input tensor.
* :math:`b`: The bias parameter created by this layer (if needed).
* :math:`Act`: The activation function.
* :math:`Out`: The output tensor.
See below for an example.
.. code-block:: text
Given:
data_1.data = [[[0.1, 0.2],
[0.3, 0.4]]]
data_1.shape = (1, 2, 2) # 1 is batch_size
data_2 = [[[0.1, 0.2, 0.3]]]
data_2.shape = (1, 1, 3)
out = fluid.layers.fc(input=[data_1, data_2], size=2)
Then:
out.data = [[0.18669507, 0.1893476]]
out.shape = (1, 2)
Args:
name_scope(str): The name of this class.
size(int): The number of output units in this layer.
num_flatten_dims (int): The fc layer can accept an input tensor with more than
two dimensions. If this happens, the multidimensional tensor will first be flattened
into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input
tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1)
dimensions will be flatten to form the first dimension of the final matrix (height of
the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to
form the second dimension of the final matrix (width of the matrix). For example, suppose
`X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3.
Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1
param_attr (ParamAttr|list of ParamAttr|None): The parameter attribute for learnable
parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized zero. Default: None.
act (str|None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase. Default: False
dtype(str): Dtype used for weight
Raises:
ValueError: If rank of the input tensor is less than 2.
Examples:
.. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform( -1, 1, [30, 10, 32] ).astype('float32')
with fluid.dygraph.guard():
fc = FC( "fc", 64, num_flatten_dims=2)
data = to_variable( data )
conv = fc( data )
"""
def
__init__
(
self
,
name_scope
,
size
,
num_flatten_dims
=
1
,
epsilon
=
1e-30
,
param_attr
=
None
,
bias_attr
=
None
,
act
=
None
,
is_test
=
False
,
dtype
=
"float32"
):
super
(
FC
,
self
).
__init__
(
name_scope
,
dtype
)
self
.
_size
=
size
self
.
_num_flatten_dims
=
num_flatten_dims
self
.
_epsilon
=
epsilon
self
.
_dtype
=
dtype
self
.
_param_attr
=
param_attr
self
.
_bias_attr
=
bias_attr
self
.
_act
=
act
self
.
__g
=
list
()
self
.
__v
=
list
()
@
property
def
_v
(
self
,
i
=
0
):
return
self
.
__v
[
i
]
@
property
def
_g
(
self
,
i
=
0
):
return
self
.
__g
[
i
]
@
_v
.
setter
def
_v
(
self
,
value
,
i
=
0
):
assert
isinstance
(
value
,
Parameter
)
self
.
__v
[
i
]
=
value
@
_g
.
setter
def
_g
(
self
,
value
,
i
=
0
):
assert
isinstance
(
value
,
Parameter
)
self
.
__g
[
i
]
=
value
def
_build_once
(
self
,
input
):
i
=
0
for
inp
,
param
in
self
.
_helper
.
iter_inputs_and_params
(
input
,
self
.
_param_attr
):
input_shape
=
inp
.
shape
param_shape
=
[
reduce
(
lambda
a
,
b
:
a
*
b
,
input_shape
[
self
.
_num_flatten_dims
:],
1
)
]
+
[
self
.
_size
]
self
.
__v
.
append
(
self
.
add_parameter
(
"_v%d"
%
i
,
self
.
create_parameter
(
attr
=
param
,
shape
=
param_shape
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)))
magnitude_shape
=
param_shape
[
1
:]
magnitude_value
=
np
.
linalg
.
norm
(
self
.
__v
[
i
].
numpy
(),
ord
=
2
,
axis
=
0
)
self
.
__g
.
append
(
self
.
add_parameter
(
"_g%d"
%
i
,
self
.
create_parameter
(
attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NumpyArrayInitializer
(
magnitude_value
)),
shape
=
magnitude_shape
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)))
i
+=
1
size
=
list
([
self
.
_size
])
self
.
_b
=
self
.
create_parameter
(
attr
=
self
.
_bias_attr
,
shape
=
size
,
dtype
=
self
.
_dtype
,
is_bias
=
True
)
def
forward
(
self
,
input
):
mul_results
=
list
()
i
=
0
for
inp
,
param
in
self
.
_helper
.
iter_inputs_and_params
(
input
,
self
.
_param_attr
):
v_norm
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
v_normalized
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"norm"
,
inputs
=
{
"X"
:
self
.
__v
[
i
]},
outputs
=
{
"Out"
:
v_normalized
,
"Norm"
:
v_norm
},
attrs
=
{
"axis"
:
0
,
"epsilon"
:
self
.
_epsilon
})
weight
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"elementwise_mul"
,
inputs
=
{
"X"
:
[
v_normalized
],
"Y"
:
[
self
.
__g
[
i
]]},
outputs
=
{
"Out"
:
[
weight
]},
attrs
=
{
"axis"
:
1
})
tmp
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"mul"
,
inputs
=
{
"X"
:
inp
,
"Y"
:
weight
},
outputs
=
{
"Out"
:
tmp
},
attrs
=
{
"x_num_col_dims"
:
self
.
_num_flatten_dims
,
"y_num_col_dims"
:
1
})
i
+=
1
mul_results
.
append
(
tmp
)
if
len
(
mul_results
)
==
1
:
pre_bias
=
mul_results
[
0
]
else
:
pre_bias
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"sum"
,
inputs
=
{
"X"
:
mul_results
},
outputs
=
{
"Out"
:
pre_bias
},
attrs
=
{
"use_mkldnn"
:
False
})
if
self
.
_b
:
pre_activation
=
self
.
_helper
.
create_variable_for_type_inference
(
dtype
=
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"elementwise_add"
,
inputs
=
{
"X"
:
[
pre_bias
],
"Y"
:
[
self
.
_b
]},
outputs
=
{
"Out"
:
[
pre_activation
]},
attrs
=
{
"axis"
:
self
.
_num_flatten_dims
})
else
:
pre_activation
=
pre_bias
# Currently, we don't support inplace in dygraph mode
return
self
.
_helper
.
append_activation
(
pre_activation
,
act
=
self
.
_act
)
class
Conv2D
(
dg
.
Layer
):
"""
The convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input and
Output are in NCHW format, where N is batch size, C is the number of
channels, H is the height of the feature, and W is the width of the feature.
Filter is in MCHW format, where M is the number of output image channels,
C is the number of input image channels, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input image channels divided by the groups.
Please refer to UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`
for more detials.
If bias attribution and activation type are provided, bias is added to the
output of the convolution, and the corresponding activation function is
applied to the final result.
For each input :math:`X`, the equation is:
.. math::
Out = \sigma ((Vg)
\\
ast X + b)
Where:
* :math:`X`: Input value, a tensor with NCHW format.
* :math:`V`: Filter direction value, a tensor with MCHW format.
* :math:`g`: Filter magnitude value, a tensor with M format.
* :math:`
\\
ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
* :math:`
\\
sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&=
\\
frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1
\\\\
W_{out}&=
\\
frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
name_scope(str) : The name for this class.
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups (int): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(
\\
frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
import numpy as np
data = np.random.uniform( -1, 1, [10, 3, 32, 32] ).astype('float32')
with fluid.dygraph.guard():
conv2d = Conv2D( "conv2d", 2, 3)
data = to_variable( data )
conv = conv2d( data )
"""
def
__init__
(
self
,
name_scope
,
num_filters
,
filter_size
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
None
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
True
,
act
=
None
,
epsilon
=
1e-30
,
dtype
=
"float32"
):
assert
param_attr
is
not
False
,
"param_attr should not be False here."
super
(
Conv2D
,
self
).
__init__
(
name_scope
,
dtype
)
self
.
_groups
=
groups
self
.
_stride
=
utils
.
convert_to_list
(
stride
,
2
,
"stride"
)
self
.
_padding
=
utils
.
convert_to_list
(
padding
,
2
,
"padding"
)
self
.
_dilation
=
utils
.
convert_to_list
(
dilation
,
2
,
"dilation"
)
self
.
_act
=
act
if
not
isinstance
(
use_cudnn
,
bool
):
raise
ValueError
(
"use_cudnn should be True or False"
)
self
.
_use_cudnn
=
use_cudnn
self
.
_filter_size
=
filter_size
self
.
_num_filters
=
num_filters
self
.
_param_attr
=
param_attr
self
.
_bias_attr
=
bias_attr
self
.
_epsilon
=
epsilon
self
.
_dtype
=
dtype
# if (self._num_channels == self._groups and
# num_filters % self._num_channels == 0 and not self._use_cudnn):
# self._l_type = 'depthwise_conv2d'
# else:
# TODO(jiabin): recover the usage of depthwise_conv2d when it's
# kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17275
self
.
_l_type
=
"conv2d"
def
_build_once
(
self
,
input
):
self
.
_num_channels
=
input
.
shape
[
1
]
if
self
.
_groups
is
None
:
num_filter_channels
=
self
.
_num_channels
else
:
if
self
.
_num_channels
%
self
.
_groups
!=
0
:
raise
ValueError
(
"num_channels must be divisible by groups."
)
num_filter_channels
=
self
.
_num_channels
//
self
.
_groups
filter_size
=
utils
.
convert_to_list
(
self
.
_filter_size
,
2
,
"filter_size"
)
filter_shape
=
[
self
.
_num_filters
,
int
(
num_filter_channels
)
]
+
filter_size
def
_get_default_param_initializer
():
filter_elem_num
=
filter_size
[
0
]
*
filter_size
[
1
]
*
self
.
_num_channels
std
=
(
2.0
/
filter_elem_num
)
**
0.5
return
Normal
(
0.0
,
std
,
0
)
# weight_v
self
.
_filter_param_v
=
self
.
create_parameter
(
attr
=
self
.
_param_attr
,
shape
=
filter_shape
,
dtype
=
self
.
_dtype
,
default_initializer
=
_get_default_param_initializer
())
# weight_g
norm_value
=
_norm
(
self
.
_filter_param_v
.
numpy
(),
dim
=
0
)
# CAUTION: hard-code
self
.
_filter_param_g
=
self
.
create_parameter
(
attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NumpyArrayInitializer
(
norm_value
)),
shape
=
norm_value
.
shape
,
dtype
=
self
.
_dtype
,
default_initializer
=
_get_default_param_initializer
())
if
self
.
_use_cudnn
:
self
.
create_variable
(
name
=
"kCUDNNFwdAlgoCache"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
self
.
create_variable
(
name
=
"kCUDNNBwdDataAlgoCache"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
self
.
create_variable
(
name
=
"kCUDNNBwdFilterAlgoCache"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
self
.
_bias_param
=
self
.
create_parameter
(
attr
=
self
.
_bias_attr
,
shape
=
[
self
.
_num_filters
],
dtype
=
self
.
_dtype
,
is_bias
=
True
)
def
forward
(
self
,
input
):
matrix
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
tmp
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
new_shape
=
[
self
.
_filter_param_v
.
shape
[
0
],
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
_filter_param_v
.
shape
[
1
:],
1
),
]
self
.
_helper
.
append_op
(
type
=
"reshape2"
,
inputs
=
{
"X"
:
self
.
_filter_param_v
},
attrs
=
{
"shape"
:
new_shape
},
outputs
=
{
"Out"
:
matrix
,
"XShape"
:
tmp
})
m_norm
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
m_normalized
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"norm"
,
inputs
=
{
"X"
:
matrix
},
outputs
=
{
"Out"
:
m_normalized
,
"Norm"
:
m_norm
},
attrs
=
{
"axis"
:
1
,
"epsilon"
:
self
.
_epsilon
})
v_normalized
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
tmp2
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"reshape2"
,
inputs
=
{
"X"
:
m_normalized
},
attrs
=
{
"shape"
:
self
.
_filter_param_v
.
shape
},
outputs
=
{
"Out"
:
v_normalized
,
"XShape"
:
tmp2
})
filter_param
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"elementwise_mul"
,
inputs
=
{
"X"
:
[
v_normalized
],
"Y"
:
[
self
.
_filter_param_g
]},
outputs
=
{
"Out"
:
[
filter_param
]},
attrs
=
{
"axis"
:
0
},
# CAUTION: hard-code
)
pre_bias
=
self
.
_helper
.
create_variable_for_type_inference
(
dtype
=
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
self
.
_l_type
,
inputs
=
{
"Input"
:
input
,
"Filter"
:
filter_param
},
outputs
=
{
"Output"
:
pre_bias
},
attrs
=
{
"strides"
:
self
.
_stride
,
"paddings"
:
self
.
_padding
,
"dilations"
:
self
.
_dilation
,
"groups"
:
self
.
_groups
if
self
.
_groups
else
1
,
"use_cudnn"
:
self
.
_use_cudnn
,
"use_mkldnn"
:
False
,
})
if
self
.
_bias_param
is
not
None
:
pre_act
=
self
.
_helper
.
create_variable_for_type_inference
(
dtype
=
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"elementwise_add"
,
inputs
=
{
"X"
:
[
pre_bias
],
"Y"
:
[
self
.
_bias_param
]},
outputs
=
{
"Out"
:
[
pre_act
]},
attrs
=
{
"axis"
:
1
})
else
:
pre_act
=
pre_bias
# Currently, we don't support inplace in dygraph mode
return
self
.
_helper
.
append_activation
(
pre_act
,
act
=
self
.
_act
)
class
Conv2DTranspose
(
dg
.
Layer
):
"""
**Convlution2D transpose layer**
The convolution2D transpose layer calculates the output based on the input,
filter, and dilations, strides, paddings. Input(Input) and output(Output)
are in NCHW format. Where N is batch size, C is the number of channels,
H is the height of the feature, and W is the width of the feature.
Parameters(dilations, strides, paddings) are two elements. These two elements
represent height and width, respectively. The details of convolution transpose
layer, please refer to the following explanation and references
`therein <http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf>`_.
If bias attribution and activation type are provided, bias is added to
the output of the convolution, and the corresponding activation function
is applied to the final result.
For each input :math:`X`, the equation is:
.. math::
Out = \sigma ((Vg)
\\
ast X + b)
Where:
* :math:`X`: Input value, a tensor with NCHW format.
* :math:`V`: Filter value, a tensor with MCHW format.
* :math:`g`: Filter value, a tensor with M format.
* :math:`
\\
ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
* :math:`
\\
sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1
\\\\
W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1
\\\\
H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] )
\\\\
W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )
Args:
name_scope(str): The name of this class.
num_filters(int): The number of the filter. It is as same as the output
image channel.
output_size(int|tuple|None): The output image size. If output size is a
tuple, it must contain two integers, (image_H, image_W). None if use
filter_size, padding, and stride to calculate output_size.
if output_size and filter_size are specified at the same time, They
should follow the formula above. Default: None.
filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square. None if use output size to
calculate filter_size. Default: None.
padding(int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
stride(int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
dilation(int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups(int): The groups number of the Conv2d transpose layer. Inspired by
grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
when group=2, the first half of the filters is only connected to the
first half of the input channels, while the second half of the
filters is only connected to the second half of the input channels.
Default: groups = 1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
act (str): Activation type, if it is set to None, activation is not appended.
Default: None.
Returns:
Variable: The tensor variable storing the convolution transpose result.
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy
with fluid.dygraph.guard():
data = numpy.random.random((3, 32, 32)).astype('float32')
conv2DTranspose = fluid.dygraph.nn.Conv2DTranspose(
'Conv2DTranspose', num_filters=2, filter_size=3)
ret = conv2DTranspose(fluid.dygraph.base.to_variable(data))
"""
def
__init__
(
self
,
name_scope
,
num_filters
,
output_size
=
None
,
filter_size
=
None
,
padding
=
0
,
stride
=
1
,
dilation
=
1
,
groups
=
None
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
True
,
epsilon
=
1e-30
,
act
=
None
,
dtype
=
"float32"
):
super
(
Conv2DTranspose
,
self
).
__init__
(
name_scope
,
dtype
)
assert
(
param_attr
is
not
False
),
"param_attr should not be False in conv2d_transpose."
self
.
_param_attr
=
param_attr
self
.
_bias_attr
=
bias_attr
self
.
_groups
=
groups
self
.
_num_filters
=
num_filters
self
.
_use_cudnn
=
use_cudnn
self
.
_padding
=
padding
self
.
_stride
=
stride
self
.
_dilation
=
dilation
self
.
_filter_size
=
filter_size
self
.
_output_size
=
output_size
self
.
_op_type
=
"conv2d_transpose"
self
.
_epsilon
=
epsilon
    def _build_once(self, input):
        input_channel = input.shape[1]
        if (input_channel == self._groups and
                self._num_filters == input_channel and not self._use_cudnn):
            self._op_type = "depthwise_conv2d_transpose"

        if not isinstance(input, Variable):
            raise TypeError("Input of conv2d_transpose must be Variable")

        self._padding = utils.convert_to_list(self._padding, 2, "padding")
        self._stride = utils.convert_to_list(self._stride, 2, "stride")
        self._dilation = utils.convert_to_list(self._dilation, 2, "dilation")

        if not isinstance(self._use_cudnn, bool):
            raise ValueError("use_cudnn should be True or False")

        if self._filter_size is None:
            if self._output_size is None:
                raise ValueError("output_size must be set when filter_size is None")
            if isinstance(self._output_size, int):
                self._output_size = [self._output_size, self._output_size]

            h_in = input.shape[2]
            w_in = input.shape[3]

            filter_size_h = (self._output_size[0] -
                             (h_in - 1) * self._stride[0] +
                             2 * self._padding[0] - 1) // self._dilation[0] + 1
            filter_size_w = (self._output_size[1] -
                             (w_in - 1) * self._stride[1] +
                             2 * self._padding[1] - 1) // self._dilation[1] + 1
            self._filter_size = [filter_size_h, filter_size_w]
        else:
            self._filter_size = utils.convert_to_list(
                self._filter_size, 2, "conv2d_transpose.filter_size")

        if self._output_size is None:
            self._output_size = []
        elif isinstance(self._output_size, list) or isinstance(self._output_size, int):
            self._output_size = utils.convert_to_list(self._output_size, 2, "output_size")
        else:
            raise ValueError("output_size should be list or int")

        self._padding = utils.convert_to_list(self._padding, 2, "padding")
        self._groups = 1 if self._groups is None else self._groups
        filter_shape = [
            input_channel,
            self._num_filters // self._groups,
        ] + self._filter_size

        # img filter v (direction)
        self._img_filter_v = self.create_parameter(
            dtype=input.dtype, shape=filter_shape, attr=self._param_attr)

        # img filter g (magnitude)
        img_filter_magnitude = _norm(
            self._img_filter_v.numpy(), dim=0)  # CAUTION: hard-code
        self._img_filter_g = self.create_parameter(
            dtype=input.dtype,
            shape=img_filter_magnitude.shape,
            attr=fluid.ParamAttr(
                initializer=NumpyArrayInitializer(img_filter_magnitude)))

        self._img_bias = self.create_parameter(
            attr=self._bias_attr,
            shape=[self._num_filters],
            dtype=self._dtype,
            is_bias=True)
    def forward(self, input):
        matrix = self._helper.create_variable_for_type_inference(self._dtype)
        tmp = self._helper.create_variable_for_type_inference(self._dtype)
        new_shape = [
            self._img_filter_v.shape[0],
            reduce(lambda x, y: x * y, self._img_filter_v.shape[1:], 1),
        ]

        self._helper.append_op(
            type="reshape2",
            inputs={"X": self._img_filter_v},
            attrs={"shape": new_shape},
            outputs={"Out": matrix,
                     "XShape": tmp})

        m_norm = self._helper.create_variable_for_type_inference(self._dtype)
        m_normalized = self._helper.create_variable_for_type_inference(self._dtype)
        self._helper.append_op(
            type="norm",
            inputs={"X": matrix},
            outputs={"Out": m_normalized,
                     "Norm": m_norm},
            attrs={"axis": 1,
                   "epsilon": self._epsilon})

        v_normalized = self._helper.create_variable_for_type_inference(self._dtype)
        tmp2 = self._helper.create_variable_for_type_inference(self._dtype)
        self._helper.append_op(
            type="reshape2",
            inputs={"X": m_normalized},
            attrs={"shape": self._img_filter_v.shape},
            outputs={"Out": v_normalized,
                     "XShape": tmp2})

        img_filter = self._helper.create_variable_for_type_inference(self._dtype)
        self._helper.append_op(
            type="elementwise_mul",
            inputs={"X": [v_normalized],
                    "Y": [self._img_filter_g]},
            outputs={"Out": [img_filter]},
            attrs={"axis": 0},  # CAUTION: hard-code
        )

        pre_bias = self._helper.create_variable_for_type_inference(dtype=input.dtype)
        self._helper.append_op(
            type=self._op_type,
            inputs={"Input": [input],
                    "Filter": [img_filter]},
            outputs={"Output": pre_bias},
            attrs={
                "output_size": self._output_size,
                "strides": self._stride,
                "paddings": self._padding,
                "dilations": self._dilation,
                "groups": self._groups,
                "use_cudnn": self._use_cudnn,
            })

        if self._img_bias is not None:
            pre_act = self._helper.create_variable_for_type_inference(dtype=self._dtype)
            self._helper.append_op(
                type="elementwise_add",
                inputs={"X": [pre_bias],
                        "Y": [self._img_bias]},
                outputs={"Out": [pre_act]},
                attrs={"axis": 1})
        else:
            pre_act = pre_bias

        out = self._helper.append_activation(pre_act)
        return out
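The forward pass above is weight normalization written out with raw reshape2/norm/elementwise_mul ops: the stored direction parameter _img_filter_v is L2-normalized over every axis except the first, then scaled by the magnitude parameter _img_filter_g to form the effective filter fed to the conv2d_transpose op. A minimal NumPy sketch of the same decomposition, with purely illustrative names and shapes (it is not part of this module):

import numpy as np

def effective_filter(v, g, epsilon=1e-30):
    # w = g * v / ||v||, with the norm taken over all axes except axis 0,
    # matching the hard-coded dim=0 / axis=0 in the class above.
    flat = v.reshape(v.shape[0], -1)
    norm = np.sqrt(np.sum(flat * flat, axis=1, keepdims=True)) + epsilon
    v_normalized = (flat / norm).reshape(v.shape)
    return g.reshape(-1, *([1] * (v.ndim - 1))) * v_normalized

# Illustrative shapes only: [C_in, C_out // groups, kh, kw].
v = np.random.randn(80, 32, 16, 3).astype("float32")
g = np.linalg.norm(v.reshape(v.shape[0], -1), axis=1)
w = effective_filter(v, g)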
parakeet/modules/modules.py
Browse file @ 98841ee4
...
...
@@ -26,6 +26,7 @@ def FC(name_scope,
       in_features,
       size,
       num_flatten_dims=1,
       relu=False,
       dropout=0.0,
       epsilon=1e-30,
       act=None,
...
...
@@ -39,7 +40,11 @@ def FC(name_scope,
    # stds
    if isinstance(in_features, int):
        in_features = [in_features]

    stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features]
    if relu:
        stds = [std * np.sqrt(2.0) for std in stds]
    weight_inits = [
        fluid.initializer.NormalInitializer(scale=std) for std in stds
    ]
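This hunk makes the FC initializer dropout-aware: each input slice gets a normal initializer with std = sqrt((1 - dropout) / fan_in), multiplied by sqrt(2) when a ReLU follows (a He-style gain). A quick, self-contained check of the numbers for hypothetical sizes (the values below are illustrative, not taken from the model):

import numpy as np

in_features, dropout, relu = [256, 80], 0.1, True
stds = [np.sqrt((1 - dropout) / n) for n in in_features]
if relu:
    stds = [s * np.sqrt(2.0) for s in stds]
print(stds)  # approximately [0.0839, 0.1500]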
...
...
@@ -456,3 +461,152 @@ class PositionEmbedding(dg.Layer):
            return out
        else:
            raise Exception("Then you can just use position rate at init")
class Conv1D_GU(dg.Layer):
    def __init__(self,
                 name_scope,
                 conditioner_dim,
                 in_channels,
                 num_filters,
                 filter_size,
                 dilation,
                 causal=False,
                 residual=True,
                 dtype="float32"):
        super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)

        self.conditioner_dim = conditioner_dim
        self.in_channels = in_channels
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.dilation = dilation
        self.causal = causal
        self.residual = residual

        if residual:
            assert in_channels == num_filters, \
                "this block uses a residual connection, so in_channels must equal num_filters"

        self.conv = Conv1D(
            self.full_name(),
            in_channels,
            2 * num_filters,
            filter_size,
            dilation,
            causal=causal,
            dtype=dtype)

        self.fc = Conv1D(
            self.full_name(),
            conditioner_dim,
            2 * num_filters,
            filter_size=1,
            dilation=1,
            causal=False,
            dtype=dtype)
    def forward(self, x, skip=None, conditioner=None):
        """
        Args:
            x (Variable): Shape(B, C_in, 1, T), the input of the Conv1D_GU
                layer, where B is the batch size, C_in is the number of input
                channels, and T is the number of input time steps.
            skip (Variable): Shape(B, C_in, 1, T), skip connection.
            conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
                conditioner, where C_con is the conditioner hidden dim, which
                equals the number of mel bands. Note that when using a residual
                connection, the Conv1D_GU does not change the number of
                channels, so the output channels equal the input channels.
        Returns:
            x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where
                C_out is the number of output channels of Conv1D_GU.
            skip (Variable): Shape(B, C_out, 1, T), skip connection.
        """
        residual = x
        x = self.conv(x)

        if conditioner is not None:
            cond_bias = self.fc(conditioner)
            x += cond_bias

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        # Gated unit.
        x = fluid.layers.elementwise_mul(
            fluid.layers.sigmoid(gate), fluid.layers.tanh(content))

        if skip is None:
            skip = x
        else:
            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))

        if self.residual:
            x = fluid.layers.scale(residual + x, np.sqrt(0.5))

        return x, skip
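forward implements the WaveNet-style gated activation: the dilated convolution produces 2 * num_filters channels, which are split into a content half and a gate half, and the block output is sigmoid(gate) * tanh(content). A minimal NumPy sketch of just the gating math, with illustrative shapes only:

import numpy as np

def gated_unit(conv_out):
    # conv_out: (B, 2 * C, 1, T) -> (B, C, 1, T)
    content, gate = np.split(conv_out, 2, axis=1)
    return 1.0 / (1.0 + np.exp(-gate)) * np.tanh(content)

y = gated_unit(np.random.randn(2, 2 * 128, 1, 40).astype("float32"))
print(y.shape)  # (2, 128, 1, 40)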
    def add_input(self, x, skip=None, conditioner=None):
        """
        Inputs:
            x: shape(B, num_filters, 1, time_steps)
            skip: shape(B, num_filters, 1, time_steps), skip connection
            conditioner: shape(B, conditioner_dim, 1, time_steps)
        Outputs:
            x: shape(B, num_filters, 1, time_steps), where time_steps = 1
            skip: skip connection, same shape as x
        """
        residual = x

        # Add step input and produce step output.
        x = self.conv.add_input(x)

        if conditioner is not None:
            cond_bias = self.fc(conditioner)
            x += cond_bias

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        # Gated unit.
        x = fluid.layers.elementwise_mul(
            fluid.layers.sigmoid(gate), fluid.layers.tanh(content))

        if skip is None:
            skip = x
        else:
            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))

        if self.residual:
            x = fluid.layers.scale(residual + x, np.sqrt(0.5))

        return x, skip
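Both forward and add_input rescale the skip and residual sums by sqrt(0.5). For roughly independent, unit-variance activations this keeps the variance of the sum near 1 as blocks are stacked, which a quick self-contained check illustrates:

import numpy as np

a = np.random.randn(1_000_000)
b = np.random.randn(1_000_000)
print(np.var(np.sqrt(0.5) * (a + b)))  # close to 1.0 for independent unit-variance inputs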
def Conv2DTranspose(name_scope,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    use_cudnn=True,
                    act=None,
                    dtype="float32"):
    val = 1.0 / (filter_size[0] * filter_size[1])
    weight_init = fluid.initializer.ConstantInitializer(val)
    weight_attr = fluid.ParamAttr(initializer=weight_init)
    layer = weight_norm.Conv2DTranspose(
        name_scope,
        num_filters,
        filter_size=filter_size,
        padding=padding,
        stride=stride,
        dilation=dilation,
        param_attr=weight_attr,
        use_cudnn=use_cudnn,
        act=act,
        dtype=dtype)
    return layer
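This helper wraps the weight-normalized Conv2DTranspose defined above with a constant filter initializer of 1 / (kh * kw), so the layer starts out as a simple averaging upsampler. A hedged construction sketch for a conditioner-upsampling layer follows; the filter/padding/stride values below are one plausible choice for a 16x stretch along time (matching one entry of the config's upsample_factors) and are illustrative rather than taken from the model code:

import paddle.fluid.dygraph as dg
from parakeet.modules.modules import Conv2DTranspose

with dg.guard():
    # Illustrative only: (B, 1, n_mels, T) -> (B, 1, n_mels, 16 * T),
    # assuming the last axis is time; (T-1)*16 - 2*8 + 32 = 16*T.
    upsample = Conv2DTranspose(
        "upsample_0",
        num_filters=1,
        filter_size=[3, 32],
        padding=[1, 8],
        stride=[1, 16])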