Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
429695d6
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
14
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
429695d6
编写于
3月 09, 2020
作者:
L
lifuchen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add docstring to transformer_tts and fastspeech
上级
3d1fda0c
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
326 addition
and
234 deletion
+326
-234
parakeet/models/fastspeech/decoder.py
parakeet/models/fastspeech/decoder.py
+27
-19
parakeet/models/fastspeech/encoder.py
parakeet/models/fastspeech/encoder.py
+29
-19
parakeet/models/fastspeech/fastspeech.py
parakeet/models/fastspeech/fastspeech.py
+26
-30
parakeet/models/fastspeech/fft_block.py
parakeet/models/fastspeech/fft_block.py
+22
-12
parakeet/models/fastspeech/length_regulator.py
parakeet/models/fastspeech/length_regulator.py
+27
-12
parakeet/models/transformer_tts/cbhg.py
parakeet/models/transformer_tts/cbhg.py
+26
-9
parakeet/models/transformer_tts/decoder.py
parakeet/models/transformer_tts/decoder.py
+25
-29
parakeet/models/transformer_tts/encoder.py
parakeet/models/transformer_tts/encoder.py
+17
-14
parakeet/models/transformer_tts/encoderprenet.py
parakeet/models/transformer_tts/encoderprenet.py
+11
-5
parakeet/models/transformer_tts/post_convnet.py
parakeet/models/transformer_tts/post_convnet.py
+17
-4
parakeet/models/transformer_tts/prenet.py
parakeet/models/transformer_tts/prenet.py
+14
-6
parakeet/models/transformer_tts/transformer_tts.py
parakeet/models/transformer_tts/transformer_tts.py
+24
-34
parakeet/models/transformer_tts/vocoder.py
parakeet/models/transformer_tts/vocoder.py
+10
-5
parakeet/modules/dynamic_gru.py
parakeet/modules/dynamic_gru.py
+3
-3
parakeet/modules/ffn.py
parakeet/modules/ffn.py
+14
-6
parakeet/modules/multihead_attention.py
parakeet/modules/multihead_attention.py
+34
-27
未找到文件。
parakeet/models/fastspeech/decoder.py
浏览文件 @
429695d6
...
@@ -23,12 +23,26 @@ class Decoder(dg.Layer):
...
@@ -23,12 +23,26 @@ class Decoder(dg.Layer):
n_layers
,
n_layers
,
n_head
,
n_head
,
d_k
,
d_k
,
d_
v
,
d_
q
,
d_model
,
d_model
,
d_inner
,
d_inner
,
fft_conv1d_kernel
,
fft_conv1d_kernel
,
fft_conv1d_padding
,
fft_conv1d_padding
,
dropout
=
0.1
):
dropout
=
0.1
):
"""Decoder layer of FastSpeech.
Args:
len_max_seq (int): the max mel len of sequence.
n_layers (int): the layers number of FFTBlock.
n_head (int): the head number of multihead attention.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
d_model (int): the dim of hidden layer in multihead attention.
d_inner (int): the dim of hidden layer in ffn.
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
fft_conv1d_padding (int): the conv padding size in FFTBlock.
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
"""
super
(
Decoder
,
self
).
__init__
()
super
(
Decoder
,
self
).
__init__
()
n_position
=
len_max_seq
+
1
n_position
=
len_max_seq
+
1
...
@@ -48,7 +62,7 @@ class Decoder(dg.Layer):
...
@@ -48,7 +62,7 @@ class Decoder(dg.Layer):
d_inner
,
d_inner
,
n_head
,
n_head
,
d_k
,
d_k
,
d_
v
,
d_
q
,
fft_conv1d_kernel
,
fft_conv1d_kernel
,
fft_conv1d_padding
,
fft_conv1d_padding
,
dropout
=
dropout
)
for
_
in
range
(
n_layers
)
dropout
=
dropout
)
for
_
in
range
(
n_layers
)
...
@@ -58,26 +72,20 @@ class Decoder(dg.Layer):
...
@@ -58,26 +72,20 @@ class Decoder(dg.Layer):
def
forward
(
self
,
enc_seq
,
enc_pos
,
non_pad_mask
,
slf_attn_mask
=
None
):
def
forward
(
self
,
enc_seq
,
enc_pos
,
non_pad_mask
,
slf_attn_mask
=
None
):
"""
"""
Decoder layer of FastSpeech.
Compute decoder outputs.
Args:
Args:
enc_seq (Variable): The output of length regulator.
enc_seq (Variable): shape(B, T_text, C), dtype float32,
Shape: (B, T_text, C), T_text means the timesteps of input text,
the output of length regulator, where T_text means the timesteps of input text,
dtype: float32.
enc_pos (Variable): shape(B, T_mel), dtype int64,
enc_pos (Variable): The spectrum position.
the spectrum position, where T_mel means the timesteps of input spectrum,
Shape: (B, T_mel), T_mel means the timesteps of input spectrum,
non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
dtype: int64.
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
non_pad_mask (Variable): the mask with non pad.
the mask of mel spectrum. Defaults to None.
Shape: (B, T_mel, 1),
dtype: int64.
slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
Shape: (B, T_mel, T_mel),
dtype: int64.
Returns:
Returns:
dec_output (Variable): the decoder output.
dec_output (Variable): shape(B, T_mel, C), the decoder output.
Shape: (B, T_mel, C).
dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
dec_slf_attn_list (list[Variable]): the decoder self attention list.
Len: n_layers.
"""
"""
dec_slf_attn_list
=
[]
dec_slf_attn_list
=
[]
slf_attn_mask
=
layers
.
expand
(
slf_attn_mask
,
[
self
.
n_head
,
1
,
1
])
slf_attn_mask
=
layers
.
expand
(
slf_attn_mask
,
[
self
.
n_head
,
1
,
1
])
...
...
parakeet/models/fastspeech/encoder.py
浏览文件 @
429695d6
...
@@ -24,12 +24,27 @@ class Encoder(dg.Layer):
...
@@ -24,12 +24,27 @@ class Encoder(dg.Layer):
n_layers
,
n_layers
,
n_head
,
n_head
,
d_k
,
d_k
,
d_
v
,
d_
q
,
d_model
,
d_model
,
d_inner
,
d_inner
,
fft_conv1d_kernel
,
fft_conv1d_kernel
,
fft_conv1d_padding
,
fft_conv1d_padding
,
dropout
=
0.1
):
dropout
=
0.1
):
"""Encoder layer of FastSpeech.
Args:
n_src_vocab (int): the number of source vocabulary.
len_max_seq (int): the max mel len of sequence.
n_layers (int): the layers number of FFTBlock.
n_head (int): the head number of multihead attention.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
d_model (int): the dim of hidden layer in multihead attention.
d_inner (int): the dim of hidden layer in ffn.
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
fft_conv1d_padding (int): the conv padding size in FFTBlock.
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
"""
super
(
Encoder
,
self
).
__init__
()
super
(
Encoder
,
self
).
__init__
()
n_position
=
len_max_seq
+
1
n_position
=
len_max_seq
+
1
self
.
n_head
=
n_head
self
.
n_head
=
n_head
...
@@ -53,7 +68,7 @@ class Encoder(dg.Layer):
...
@@ -53,7 +68,7 @@ class Encoder(dg.Layer):
d_inner
,
d_inner
,
n_head
,
n_head
,
d_k
,
d_k
,
d_
v
,
d_
q
,
fft_conv1d_kernel
,
fft_conv1d_kernel
,
fft_conv1d_padding
,
fft_conv1d_padding
,
dropout
=
dropout
)
for
_
in
range
(
n_layers
)
dropout
=
dropout
)
for
_
in
range
(
n_layers
)
...
@@ -63,25 +78,20 @@ class Encoder(dg.Layer):
...
@@ -63,25 +78,20 @@ class Encoder(dg.Layer):
def
forward
(
self
,
character
,
text_pos
,
non_pad_mask
,
slf_attn_mask
=
None
):
def
forward
(
self
,
character
,
text_pos
,
non_pad_mask
,
slf_attn_mask
=
None
):
"""
"""
Encoder layer of FastSpeech.
Encode text sequence.
Args:
Args:
character (Variable): The input text characters.
character (Variable): shape(B, T_text), dtype float32, the input text characters,
Shape: (B, T_text), T_text means the timesteps of input characters,
where T_text means the timesteps of input characters,
dtype: float32.
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
text_pos (Variable): The input text position.
non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
Shape: (B, T_text), dtype: int64.
slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
non_pad_mask (Variable): the mask with non pad.
the mask of input characters. Defaults to None.
Shape: (B, T_text, 1),
dtype: int64.
slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
Shape: (B, T_text, T_text),
dtype: int64.
Returns:
Returns:
enc_output (Variable), the encoder output. Shape(B, T_text, C)
enc_output (Variable): shape(B, T_text, C), the encoder output.
non_pad_mask (Variable), the mask with non pad. Shape(B, T_text, 1)
non_pad_mask (Variable): shape(B, T_text, 1), the mask with non pad.
enc_slf_attn_list (list[Variable]), the encoder self attention list.
enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list.
Len: n_layers.
"""
"""
enc_slf_attn_list
=
[]
enc_slf_attn_list
=
[]
slf_attn_mask
=
layers
.
expand
(
slf_attn_mask
,
[
self
.
n_head
,
1
,
1
])
slf_attn_mask
=
layers
.
expand
(
slf_attn_mask
,
[
self
.
n_head
,
1
,
1
])
...
...
parakeet/models/fastspeech/fastspeech.py
浏览文件 @
429695d6
...
@@ -25,7 +25,11 @@ from parakeet.models.fastspeech.decoder import Decoder
...
@@ -25,7 +25,11 @@ from parakeet.models.fastspeech.decoder import Decoder
class
FastSpeech
(
dg
.
Layer
):
class
FastSpeech
(
dg
.
Layer
):
def
__init__
(
self
,
cfg
):
def
__init__
(
self
,
cfg
):
" FastSpeech"
"""FastSpeech model.
Args:
cfg: the yaml configs used in FastSpeech model.
"""
super
(
FastSpeech
,
self
).
__init__
()
super
(
FastSpeech
,
self
).
__init__
()
self
.
encoder
=
Encoder
(
self
.
encoder
=
Encoder
(
...
@@ -34,7 +38,7 @@ class FastSpeech(dg.Layer):
...
@@ -34,7 +38,7 @@ class FastSpeech(dg.Layer):
n_layers
=
cfg
[
'encoder_n_layer'
],
n_layers
=
cfg
[
'encoder_n_layer'
],
n_head
=
cfg
[
'encoder_head'
],
n_head
=
cfg
[
'encoder_head'
],
d_k
=
cfg
[
'fs_hidden_size'
]
//
cfg
[
'encoder_head'
],
d_k
=
cfg
[
'fs_hidden_size'
]
//
cfg
[
'encoder_head'
],
d_
v
=
cfg
[
'fs_hidden_size'
]
//
cfg
[
'encoder_head'
],
d_
q
=
cfg
[
'fs_hidden_size'
]
//
cfg
[
'encoder_head'
],
d_model
=
cfg
[
'fs_hidden_size'
],
d_model
=
cfg
[
'fs_hidden_size'
],
d_inner
=
cfg
[
'encoder_conv1d_filter_size'
],
d_inner
=
cfg
[
'encoder_conv1d_filter_size'
],
fft_conv1d_kernel
=
cfg
[
'fft_conv1d_filter'
],
fft_conv1d_kernel
=
cfg
[
'fft_conv1d_filter'
],
...
@@ -50,7 +54,7 @@ class FastSpeech(dg.Layer):
...
@@ -50,7 +54,7 @@ class FastSpeech(dg.Layer):
n_layers
=
cfg
[
'decoder_n_layer'
],
n_layers
=
cfg
[
'decoder_n_layer'
],
n_head
=
cfg
[
'decoder_head'
],
n_head
=
cfg
[
'decoder_head'
],
d_k
=
cfg
[
'fs_hidden_size'
]
//
cfg
[
'decoder_head'
],
d_k
=
cfg
[
'fs_hidden_size'
]
//
cfg
[
'decoder_head'
],
d_
v
=
cfg
[
'fs_hidden_size'
]
//
cfg
[
'decoder_head'
],
d_
q
=
cfg
[
'fs_hidden_size'
]
//
cfg
[
'decoder_head'
],
d_model
=
cfg
[
'fs_hidden_size'
],
d_model
=
cfg
[
'fs_hidden_size'
],
d_inner
=
cfg
[
'decoder_conv1d_filter_size'
],
d_inner
=
cfg
[
'decoder_conv1d_filter_size'
],
fft_conv1d_kernel
=
cfg
[
'fft_conv1d_filter'
],
fft_conv1d_kernel
=
cfg
[
'fft_conv1d_filter'
],
...
@@ -88,39 +92,31 @@ class FastSpeech(dg.Layer):
...
@@ -88,39 +92,31 @@ class FastSpeech(dg.Layer):
length_target
=
None
,
length_target
=
None
,
alpha
=
1.0
):
alpha
=
1.0
):
"""
"""
FastSpeech model
.
Compute mel output from text character
.
Args:
Args:
character (Variable): The input text characters.
character (Variable): shape(B, T_text), dtype float32, the input text characters,
Shape: (B, T_text), T_text means the timesteps of input characters, dtype: float32.
where T_text means the timesteps of input characters,
text_pos (Variable): The input text position.
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
Shape: (B, T_text), dtype: int64.
mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position,
mel_pos (Variable, optional): The spectrum position.
where T_mel means the timesteps of input spectrum,
Shape: (B, T_mel), T_mel means the timesteps of input spectrum, dtype: int64.
enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
enc_non_pad_mask (Variable): the mask with non pad.
dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
Shape: (B, T_text, 1),
enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
dtype: int64.
the mask of input characters. Defaults to None.
dec_non_pad_mask (Variable): the mask with non pad.
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
Shape: (B, T_mel, 1),
the mask of mel spectrum. Defaults to None.
dtype: int64.
length_target (Variable, optional): shape(B, T_text), dtype int64,
enc_slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
the duration of phoneme compute from pretrained transformerTTS. Defaults to None.
Shape: (B, T_text, T_text),
dtype: int64.
slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
Shape: (B, T_mel, T_mel),
dtype: int64.
length_target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS.
Defaults to None. Shape: (B, T_text), dtype: int64.
alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
mel, thereby controlling the voice speed. Defaults to 1.0.
mel, thereby controlling the voice speed. Defaults to 1.0.
Returns:
Returns:
mel_output (Variable), the mel output before postnet. Shape: (B, T_mel, C),
mel_output (Variable): shape(B, T_mel, C), the mel output before postnet.
mel_output_postnet (Variable), the mel output after postnet. Shape: (B, T_mel, C).
mel_output_postnet (Variable): shape(B, T_mel, C), the mel output after postnet.
duration_predictor_output (Variable), the duration of phoneme compute with duration predictor.
duration_predictor_output (Variable): shape(B, T_text), the duration of phoneme compute with duration predictor.
Shape: (B, T_text).
enc_slf_attn_list (List[Variable]): len(enc_n_layers), the encoder self attention list.
enc_slf_attn_list (List[Variable]), the encoder self attention list. Len: enc_n_layers.
dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list.
dec_slf_attn_list (List[Variable]), the decoder self attention list. Len: dec_n_layers.
"""
"""
encoder_output
,
enc_slf_attn_list
=
self
.
encoder
(
encoder_output
,
enc_slf_attn_list
=
self
.
encoder
(
...
...
parakeet/models/fastspeech/fft_block.py
浏览文件 @
429695d6
...
@@ -26,15 +26,27 @@ class FFTBlock(dg.Layer):
...
@@ -26,15 +26,27 @@ class FFTBlock(dg.Layer):
d_inner
,
d_inner
,
n_head
,
n_head
,
d_k
,
d_k
,
d_
v
,
d_
q
,
filter_size
,
filter_size
,
padding
,
padding
,
dropout
=
0.2
):
dropout
=
0.2
):
"""Feed forward structure based on self-attention.
Args:
d_model (int): the dim of hidden layer in multihead attention.
d_inner (int): the dim of hidden layer in ffn.
n_head (int): the head number of multihead attention.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
filter_size (int): the conv kernel size.
padding (int): the conv padding size.
dropout (float, optional): dropout probability. Defaults to 0.2.
"""
super
(
FFTBlock
,
self
).
__init__
()
super
(
FFTBlock
,
self
).
__init__
()
self
.
slf_attn
=
MultiheadAttention
(
self
.
slf_attn
=
MultiheadAttention
(
d_model
,
d_model
,
d_k
,
d_k
,
d_
v
,
d_
q
,
num_head
=
n_head
,
num_head
=
n_head
,
is_bias
=
True
,
is_bias
=
True
,
dropout
=
dropout
,
dropout
=
dropout
,
...
@@ -48,20 +60,18 @@ class FFTBlock(dg.Layer):
...
@@ -48,20 +60,18 @@ class FFTBlock(dg.Layer):
def
forward
(
self
,
enc_input
,
non_pad_mask
,
slf_attn_mask
=
None
):
def
forward
(
self
,
enc_input
,
non_pad_mask
,
slf_attn_mask
=
None
):
"""
"""
Feed
Forward Transformer block in FastSpeech.
Feed
forward block of FastSpeech
Args:
Args:
enc_input (Variable): The embedding characters input.
enc_input (Variable): shape(B, T, C), dtype float32, the embedding characters input,
Shape: (B, T, C), T means the timesteps of input, dtype: float32.
where T means the timesteps of input.
non_pad_mask (Variable): The mask of sequence.
non_pad_mask (Variable): shape(B, T, 1), dtype int64, the mask of sequence.
Shape: (B, T, 1), dtype: int64.
slf_attn_mask (Variable, optional): shape(B, len_q, len_k), dtype int64, the mask of self attention,
slf_attn_mask (Variable, optional): The mask of self attention. Defaults to None.
where len_q means the sequence length of query and len_k means the sequence length of key. Defaults to None.
Shape(B, len_q, len_k), len_q means the sequence length of query,
len_k means the sequence length of key, dtype: int64.
Returns:
Returns:
output (Variable)
, the output after self-attention & ffn. Shape: (B, T, C).
output (Variable)
: shape(B, T, C), the output after self-attention & ffn.
slf_attn (Variable)
, the self attention. Shape: (B * n_head, T, T),
slf_attn (Variable)
: shape(B * n_head, T, T), the self attention.
"""
"""
output
,
slf_attn
=
self
.
slf_attn
(
output
,
slf_attn
=
self
.
slf_attn
(
enc_input
,
enc_input
,
enc_input
,
mask
=
slf_attn_mask
)
enc_input
,
enc_input
,
enc_input
,
mask
=
slf_attn_mask
)
...
...
parakeet/models/fastspeech/length_regulator.py
浏览文件 @
429695d6
...
@@ -22,6 +22,14 @@ from parakeet.modules.customized import Conv1D
...
@@ -22,6 +22,14 @@ from parakeet.modules.customized import Conv1D
class
LengthRegulator
(
dg
.
Layer
):
class
LengthRegulator
(
dg
.
Layer
):
def
__init__
(
self
,
input_size
,
out_channels
,
filter_size
,
dropout
=
0.1
):
def
__init__
(
self
,
input_size
,
out_channels
,
filter_size
,
dropout
=
0.1
):
"""Length Regulator block in FastSpeech.
Args:
input_size (int): the channel number of input.
out_channels (int): the output channel number.
filter_size (int): the filter size of duration predictor.
dropout (float, optional): dropout probability. Defaults to 0.1.
"""
super
(
LengthRegulator
,
self
).
__init__
()
super
(
LengthRegulator
,
self
).
__init__
()
self
.
duration_predictor
=
DurationPredictor
(
self
.
duration_predictor
=
DurationPredictor
(
input_size
=
input_size
,
input_size
=
input_size
,
...
@@ -66,20 +74,18 @@ class LengthRegulator(dg.Layer):
...
@@ -66,20 +74,18 @@ class LengthRegulator(dg.Layer):
def
forward
(
self
,
x
,
alpha
=
1.0
,
target
=
None
):
def
forward
(
self
,
x
,
alpha
=
1.0
,
target
=
None
):
"""
"""
Length Regulator block in FastSpeech.
Compute length of mel from encoder output use TransformerTTS attention
Args:
Args:
x (Variable): The encoder output.
x (Variable): shape(B, T, C), dtype float32, the encoder output.
Shape: (B, T, C), dtype: float32.
alpha (float32, optional): the hyperparameter to determine the length of
alpha (float32, optional): The hyperparameter to determine the length of
the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0.
the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0.
target (Variable, optional):
T
he duration of phoneme compute from pretrained transformerTTS.
target (Variable, optional):
shape(B, T_text), dtype int64, t
he duration of phoneme compute from pretrained transformerTTS.
Defaults to None.
Shape: (B, T_text), dtype: int64.
Defaults to None.
Returns:
Returns:
output (Variable), the output after exppand. Shape: (B, T, C),
output (Variable): shape(B, T, C), the output after exppand.
duration_predictor_output (Variable), the output of duration predictor.
duration_predictor_output (Variable): shape(B, T, C), the output of duration predictor.
Shape: (B, T, C).
"""
"""
duration_predictor_output
=
self
.
duration_predictor
(
x
)
duration_predictor_output
=
self
.
duration_predictor
(
x
)
if
fluid
.
framework
.
_dygraph_tracer
().
_train_mode
:
if
fluid
.
framework
.
_dygraph_tracer
().
_train_mode
:
...
@@ -95,6 +101,14 @@ class LengthRegulator(dg.Layer):
...
@@ -95,6 +101,14 @@ class LengthRegulator(dg.Layer):
class
DurationPredictor
(
dg
.
Layer
):
class
DurationPredictor
(
dg
.
Layer
):
def
__init__
(
self
,
input_size
,
out_channels
,
filter_size
,
dropout
=
0.1
):
def
__init__
(
self
,
input_size
,
out_channels
,
filter_size
,
dropout
=
0.1
):
"""Duration Predictor block in FastSpeech.
Args:
input_size (int): the channel number of input.
out_channels (int): the output channel number.
filter_size (int): the filter size.
dropout (float, optional): dropout probability. Defaults to 0.1.
"""
super
(
DurationPredictor
,
self
).
__init__
()
super
(
DurationPredictor
,
self
).
__init__
()
self
.
input_size
=
input_size
self
.
input_size
=
input_size
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
...
@@ -137,12 +151,13 @@ class DurationPredictor(dg.Layer):
...
@@ -137,12 +151,13 @@ class DurationPredictor(dg.Layer):
def
forward
(
self
,
encoder_output
):
def
forward
(
self
,
encoder_output
):
"""
"""
Duration Predictor block in FastSpeech
.
Predict the duration of each character
.
Args:
Args:
encoder_output (Variable): Shape(B, T, C), dtype: float32. The encoder output.
encoder_output (Variable): shape(B, T, C), dtype float32, the encoder output.
Returns:
Returns:
out (Variable)
, S
hape(B, T, C), the output of duration predictor.
out (Variable)
: s
hape(B, T, C), the output of duration predictor.
"""
"""
# encoder_output.shape(N, T, C)
# encoder_output.shape(N, T, C)
out
=
layers
.
transpose
(
encoder_output
,
[
0
,
2
,
1
])
out
=
layers
.
transpose
(
encoder_output
,
[
0
,
2
,
1
])
...
...
parakeet/models/transformer_tts/cbhg.py
浏览文件 @
429695d6
...
@@ -30,6 +30,17 @@ class CBHG(dg.Layer):
...
@@ -30,6 +30,17 @@ class CBHG(dg.Layer):
num_gru_layers
=
2
,
num_gru_layers
=
2
,
max_pool_kernel_size
=
2
,
max_pool_kernel_size
=
2
,
is_post
=
False
):
is_post
=
False
):
"""CBHG Module
Args:
hidden_size (int): dimension of hidden unit.
batch_size (int): batch size of input.
K (int, optional): number of convolution banks. Defaults to 16.
projection_size (int, optional): dimension of projection unit. Defaults to 256.
num_gru_layers (int, optional): number of layers of GRUcell. Defaults to 2.
max_pool_kernel_size (int, optional): max pooling kernel size. Defaults to 2
is_post (bool, optional): whether post processing or not. Defaults to False.
"""
super
(
CBHG
,
self
).
__init__
()
super
(
CBHG
,
self
).
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -169,13 +180,13 @@ class CBHG(dg.Layer):
...
@@ -169,13 +180,13 @@ class CBHG(dg.Layer):
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
"""
"""
CBHG Module
Convert linear spectrum to Mel spectrum.
Args:
Args:
input_(Variable): The sequentially input.
input_ (Variable): shape(B, C, T), dtype float32, the sequentially input.
Shape: (B, C, T), dtype: float32.
Returns:
Returns:
(Variable):
the CBHG output.
out (Variable): shape(B, C, T),
the CBHG output.
"""
"""
conv_list
=
[]
conv_list
=
[]
...
@@ -217,6 +228,12 @@ class CBHG(dg.Layer):
...
@@ -217,6 +228,12 @@ class CBHG(dg.Layer):
class
Highwaynet
(
dg
.
Layer
):
class
Highwaynet
(
dg
.
Layer
):
def
__init__
(
self
,
num_units
,
num_layers
=
4
):
def
__init__
(
self
,
num_units
,
num_layers
=
4
):
"""Highway network
Args:
num_units (int): dimension of hidden unit.
num_layers (int, optional): number of highway layers. Defaults to 4.
"""
super
(
Highwaynet
,
self
).
__init__
()
super
(
Highwaynet
,
self
).
__init__
()
self
.
num_units
=
num_units
self
.
num_units
=
num_units
self
.
num_layers
=
num_layers
self
.
num_layers
=
num_layers
...
@@ -250,13 +267,13 @@ class Highwaynet(dg.Layer):
...
@@ -250,13 +267,13 @@ class Highwaynet(dg.Layer):
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
"""
"""
Highway network
Compute result of Highway network.
Args:
input_(Variable): The sequentially input.
Shape: (B, T, C), dtype: float32.
Args:
input_(Variable): shape(B, T, C), dtype float32, the sequentially input.
Returns:
Returns:
(Variable): the Highway output.
out
(Variable): the Highway output.
"""
"""
out
=
input_
out
=
input_
...
...
parakeet/models/transformer_tts/decoder.py
浏览文件 @
429695d6
...
@@ -23,6 +23,14 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
...
@@ -23,6 +23,14 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
class
Decoder
(
dg
.
Layer
):
class
Decoder
(
dg
.
Layer
):
def
__init__
(
self
,
num_hidden
,
config
,
num_head
=
4
,
n_layers
=
3
):
def
__init__
(
self
,
num_hidden
,
config
,
num_head
=
4
,
n_layers
=
3
):
"""Decoder layer of TransformerTTS.
Args:
num_hidden (int): the number of source vocabulary.
config: the yaml configs used in decoder.
n_layers (int, optional): the layers number of multihead attention. Defaults to 4.
num_head (int, optional): the head number of multihead attention. Defaults to 3.
"""
super
(
Decoder
,
self
).
__init__
()
super
(
Decoder
,
self
).
__init__
()
self
.
num_hidden
=
num_hidden
self
.
num_hidden
=
num_hidden
self
.
num_head
=
num_head
self
.
num_head
=
num_head
...
@@ -109,38 +117,26 @@ class Decoder(dg.Layer):
...
@@ -109,38 +117,26 @@ class Decoder(dg.Layer):
m_self_mask
=
None
,
m_self_mask
=
None
,
zero_mask
=
None
):
zero_mask
=
None
):
"""
"""
Decoder layer of TransformerTTS.
Compute decoder outputs.
Args:
Args:
key (Variable): The input key of decoder.
key (Variable): shape(B, T_text, C), dtype float32, the input key of decoder,
Shape: (B, T_text, C), T_text means the timesteps of input text,
where T_text means the timesteps of input text,
dtype: float32.
value (Variable): shape(B, T_text, C), dtype float32, the input value of decoder.
value (Variable): The . input value of decoder.
query (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
Shape: (B, T_text, C), dtype: float32.
where T_mel means the timesteps of input spectrum,
query (Variable): The input query of decoder.
positional (Variable): shape(B, T_mel), dtype int64, the spectrum position.
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
dtype: float32.
m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
positional (Variable): The spectrum position.
m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
Shape: (B, T_mel), dtype: int64.
zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
mask (Variable): the mask of decoder self attention.
Shape: (B, T_mel, T_mel), dtype: int64.
m_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
m_self_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
zero_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, T_text), dtype: int64.
Returns:
Returns:
mel_out (Variable): the decoder output after mel linear projection.
mel_out (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
Shape: (B, T_mel, C).
out (Variable): shape(B, T_mel, C), the decoder output after post mel network.
out (Variable): the decoder output after post mel network.
stop_tokens (Variable): shape(B, T_mel, 1), the stop tokens of output.
Shape: (B, T_mel, C).
attn_list (list[Variable]): len(n_layers), the encoder-decoder attention list.
stop_tokens (Variable): the stop tokens of output.
selfattn_list (list[Variable]): len(n_layers), the decoder self attention list.
Shape: (B, T_mel, 1)
attn_list (list[Variable]): the encoder-decoder attention list.
Len: n_layers.
selfattn_list (list[Variable]): the decoder self attention list.
Len: n_layers.
"""
"""
# get decoder mask with triangular matrix
# get decoder mask with triangular matrix
...
...
parakeet/models/transformer_tts/encoder.py
浏览文件 @
429695d6
...
@@ -21,6 +21,14 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet
...
@@ -21,6 +21,14 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet
class
Encoder
(
dg
.
Layer
):
class
Encoder
(
dg
.
Layer
):
def
__init__
(
self
,
embedding_size
,
num_hidden
,
num_head
=
4
,
n_layers
=
3
):
def
__init__
(
self
,
embedding_size
,
num_hidden
,
num_head
=
4
,
n_layers
=
3
):
"""Encoder layer of TransformerTTS.
Args:
embedding_size (int): the size of position embedding.
num_hidden (int): the size of hidden layer in network.
n_layers (int, optional): the layers number of multihead attention. Defaults to 4.
num_head (int, optional): the head number of multihead attention. Defaults to 3.
"""
super
(
Encoder
,
self
).
__init__
()
super
(
Encoder
,
self
).
__init__
()
self
.
num_hidden
=
num_hidden
self
.
num_hidden
=
num_hidden
self
.
num_head
=
num_head
self
.
num_head
=
num_head
...
@@ -58,23 +66,18 @@ class Encoder(dg.Layer):
...
@@ -58,23 +66,18 @@ class Encoder(dg.Layer):
def
forward
(
self
,
x
,
positional
,
mask
=
None
,
query_mask
=
None
):
def
forward
(
self
,
x
,
positional
,
mask
=
None
,
query_mask
=
None
):
"""
"""
Encoder layer of TransformerTTS.
Encode text sequence.
Args:
Args:
x (Variable): The input character.
x (Variable): shape(B, T_text), dtype float32, the input character,
Shape: (B, T_text), T_text means the timesteps of input text,
where T_text means the timesteps of input text,
dtype: float32.
positional (Variable): shape(B, T_text), dtype int64, the characters position.
positional (Variable): The characters position.
mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
Shape: (B, T_text), dtype: int64.
query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
mask (Variable, optional): the mask of encoder self attention. Defaults to None.
Shape: (B, T_text, T_text), dtype: int64.
query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
Shape: (B, T_text, 1), dtype: int64.
Returns:
Returns:
x (Variable): the encoder output.
x (Variable): shape(B, T_text, C), the encoder output.
Shape: (B, T_text, C).
attentions (list[Variable]): len(n_layers), the encoder self attention list.
attentions (list[Variable]): the encoder self attention list.
Len: n_layers.
"""
"""
if
fluid
.
framework
.
_dygraph_tracer
().
_train_mode
:
if
fluid
.
framework
.
_dygraph_tracer
().
_train_mode
:
...
...
parakeet/models/transformer_tts/encoderprenet.py
浏览文件 @
429695d6
...
@@ -22,6 +22,13 @@ import numpy as np
...
@@ -22,6 +22,13 @@ import numpy as np
class
EncoderPrenet
(
dg
.
Layer
):
class
EncoderPrenet
(
dg
.
Layer
):
def
__init__
(
self
,
embedding_size
,
num_hidden
,
use_cudnn
=
True
):
def
__init__
(
self
,
embedding_size
,
num_hidden
,
use_cudnn
=
True
):
""" Encoder prenet layer of TransformerTTS.
Args:
embedding_size (int): the size of embedding.
num_hidden (int): the size of hidden layer in network.
use_cudnn (bool, optional): use cudnn or not. Defaults to True.
"""
super
(
EncoderPrenet
,
self
).
__init__
()
super
(
EncoderPrenet
,
self
).
__init__
()
self
.
embedding_size
=
embedding_size
self
.
embedding_size
=
embedding_size
self
.
num_hidden
=
num_hidden
self
.
num_hidden
=
num_hidden
...
@@ -82,14 +89,13 @@ class EncoderPrenet(dg.Layer):
...
@@ -82,14 +89,13 @@ class EncoderPrenet(dg.Layer):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
"""
Encoder prenet layer of TransformerTTS.
Prepare encoder input.
Args:
Args:
x (Variable): The input character.
x (Variable): shape(B, T_text), dtype float32, the input character, where T_text means the timesteps of input text.
Shape: (B, T_text), T_text means the timesteps of input text,
dtype: float32.
Returns:
Returns:
(Variable):
the encoder prenet output. Shape: (B, T_text, C)
.
(Variable):
shape(B, T_text, C), the encoder prenet output
.
"""
"""
x
=
self
.
embedding
(
x
)
x
=
self
.
embedding
(
x
)
...
...
parakeet/models/transformer_tts/post_convnet.py
浏览文件 @
429695d6
...
@@ -29,6 +29,19 @@ class PostConvNet(dg.Layer):
...
@@ -29,6 +29,19 @@ class PostConvNet(dg.Layer):
use_cudnn
=
True
,
use_cudnn
=
True
,
dropout
=
0.1
,
dropout
=
0.1
,
batchnorm_last
=
False
):
batchnorm_last
=
False
):
"""Decocder post conv net of TransformerTTS.
Args:
n_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
num_hidden (int, optional): the size of hidden layer in network. Defaults to 512.
filter_size (int, optional): the filter size of Conv. Defaults to 5.
padding (int, optional): the padding size of Conv. Defaults to 0.
num_conv (int, optional): the num of Conv layers in network. Defaults to 5.
outputs_per_step (int, optional): the num of output frames per step . Defaults to 1.
use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True.
dropout (float, optional): dropout probability. Defaults to 0.1.
batchnorm_last (bool, optional): if batchnorm at last layer or not. Defaults to False.
"""
super
(
PostConvNet
,
self
).
__init__
()
super
(
PostConvNet
,
self
).
__init__
()
self
.
dropout
=
dropout
self
.
dropout
=
dropout
...
@@ -93,13 +106,13 @@ class PostConvNet(dg.Layer):
...
@@ -93,13 +106,13 @@ class PostConvNet(dg.Layer):
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
"""
"""
Decocder Post Conv Net of TransformerTTS
.
Compute the mel spectrum
.
Args:
Args:
input (Variable):
T
he result of mel linear projection.
input (Variable):
shape(B, T, C), dtype float32, t
he result of mel linear projection.
Shape: (B, T, C), dtype: float32.
Returns:
Returns:
(Variable): the result after postconvnet. Shape: (B, T, C),
output (Variable): shape(B, T, C), the result after postconvnet.
"""
"""
input
=
layers
.
transpose
(
input
,
[
0
,
2
,
1
])
input
=
layers
.
transpose
(
input
,
[
0
,
2
,
1
])
...
...
parakeet/models/transformer_tts/prenet.py
浏览文件 @
429695d6
...
@@ -19,6 +19,14 @@ import paddle.fluid.layers as layers
...
@@ -19,6 +19,14 @@ import paddle.fluid.layers as layers
class
PreNet
(
dg
.
Layer
):
class
PreNet
(
dg
.
Layer
):
def
__init__
(
self
,
input_size
,
hidden_size
,
output_size
,
dropout_rate
=
0.2
):
def
__init__
(
self
,
input_size
,
hidden_size
,
output_size
,
dropout_rate
=
0.2
):
"""Prenet before passing through the network.
Args:
input_size (int): the input channel size.
hidden_size (int): the size of hidden layer in network.
output_size (int): the output channel size.
dropout_rate (float, optional): dropout probability. Defaults to 0.2.
"""
super
(
PreNet
,
self
).
__init__
()
super
(
PreNet
,
self
).
__init__
()
self
.
input_size
=
input_size
self
.
input_size
=
input_size
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -44,20 +52,20 @@ class PreNet(dg.Layer):
...
@@ -44,20 +52,20 @@ class PreNet(dg.Layer):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
"""
Pre
Net before passing through the network
.
Pre
pare network input
.
Args:
Args:
x (Variable):
T
he input value.
x (Variable):
shape(B, T, C), dtype float32, t
he input value.
Shape: (B, T, C), dtype: float32.
Returns:
Returns:
(Variable), the result after pernet. Shape: (B, T, C),
output (Variable): shape(B, T, C), the result after pernet.
"""
"""
x
=
layers
.
dropout
(
x
=
layers
.
dropout
(
layers
.
relu
(
self
.
linear1
(
x
)),
layers
.
relu
(
self
.
linear1
(
x
)),
self
.
dropout_rate
,
self
.
dropout_rate
,
dropout_implementation
=
'upscale_in_train'
)
dropout_implementation
=
'upscale_in_train'
)
x
=
layers
.
dropout
(
output
=
layers
.
dropout
(
layers
.
relu
(
self
.
linear2
(
x
)),
layers
.
relu
(
self
.
linear2
(
x
)),
self
.
dropout_rate
,
self
.
dropout_rate
,
dropout_implementation
=
'upscale_in_train'
)
dropout_implementation
=
'upscale_in_train'
)
return
x
return
output
parakeet/models/transformer_tts/transformer_tts.py
浏览文件 @
429695d6
...
@@ -19,6 +19,11 @@ from parakeet.models.transformer_tts.decoder import Decoder
...
@@ -19,6 +19,11 @@ from parakeet.models.transformer_tts.decoder import Decoder
class
TransformerTTS
(
dg
.
Layer
):
class
TransformerTTS
(
dg
.
Layer
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
"""TransformerTTS model.
Args:
config: the yaml configs used in TransformerTTS model.
"""
super
(
TransformerTTS
,
self
).
__init__
()
super
(
TransformerTTS
,
self
).
__init__
()
self
.
encoder
=
Encoder
(
config
[
'embedding_size'
],
config
[
'hidden_size'
])
self
.
encoder
=
Encoder
(
config
[
'embedding_size'
],
config
[
'hidden_size'
])
self
.
decoder
=
Decoder
(
config
[
'hidden_size'
],
config
)
self
.
decoder
=
Decoder
(
config
[
'hidden_size'
],
config
)
...
@@ -37,43 +42,28 @@ class TransformerTTS(dg.Layer):
...
@@ -37,43 +42,28 @@ class TransformerTTS(dg.Layer):
dec_query_mask
=
None
):
dec_query_mask
=
None
):
"""
"""
TransformerTTS network.
TransformerTTS network.
Args:
Args:
characters (Variable): The input character.
characters (Variable): shape(B, T_text), dtype float32, the input character,
Shape: (B, T_text), T_text means the timesteps of input text,
where T_text means the timesteps of input text,
dtype: float32.
mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
mel_input (Variable): The input query of decoder.
where T_mel means the timesteps of input spectrum,
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
pos_text (Variable): shape(B, T_text), dtype int64, the characters position.
dtype: float32.
dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position.
pos_text (Variable): The characters position.
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
Shape: (B, T_text), dtype: int64.
enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
dec_slf_mask (Variable): The spectrum position.
enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
Shape: (B, T_mel), dtype: int64.
dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
mask (Variable): the mask of decoder self attention.
dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
Shape: (B, T_mel, T_mel), dtype: int64.
enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
enc_slf_mask (Variable, optional): the mask of encoder self attention. Defaults to None.
Shape: (B, T_text, T_text), dtype: int64.
enc_query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
Shape: (B, T_text, 1), dtype: int64.
dec_query_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
dec_query_slf_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
enc_dec_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, T_text), dtype: int64.
Returns:
Returns:
mel_output (Variable): the decoder output after mel linear projection.
mel_output (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
Shape: (B, T_mel, C).
postnet_output (Variable): shape(B, T_mel, C), the decoder output after post mel network.
postnet_output (Variable): the decoder output after post mel network.
stop_preds (Variable): shape(B, T_mel, 1), the stop tokens of output.
Shape: (B, T_mel, C).
attn_probs (list[Variable]): len(n_layers), the encoder-decoder attention list.
stop_preds (Variable): the stop tokens of output.
attns_enc (list[Variable]): len(n_layers), the encoder self attention list.
Shape: (B, T_mel, 1)
attns_dec (list[Variable]): len(n_layers), the decoder self attention list.
attn_probs (list[Variable]): the encoder-decoder attention list.
Len: n_layers.
attns_enc (list[Variable]): the encoder self attention list.
Len: n_layers.
attns_dec (list[Variable]): the decoder self attention list.
Len: n_layers.
"""
"""
key
,
attns_enc
=
self
.
encoder
(
key
,
attns_enc
=
self
.
encoder
(
characters
,
pos_text
,
mask
=
enc_slf_mask
,
query_mask
=
enc_query_mask
)
characters
,
pos_text
,
mask
=
enc_slf_mask
,
query_mask
=
enc_query_mask
)
...
...
parakeet/models/transformer_tts/vocoder.py
浏览文件 @
429695d6
...
@@ -20,6 +20,12 @@ from parakeet.models.transformer_tts.cbhg import CBHG
...
@@ -20,6 +20,12 @@ from parakeet.models.transformer_tts.cbhg import CBHG
class
Vocoder
(
dg
.
Layer
):
class
Vocoder
(
dg
.
Layer
):
def
__init__
(
self
,
config
,
batch_size
):
def
__init__
(
self
,
config
,
batch_size
):
"""CBHG Network (mel -> linear)
Args:
config: the yaml configs used in Vocoder model.
batch_size (int): the batch size of input.
"""
super
(
Vocoder
,
self
).
__init__
()
super
(
Vocoder
,
self
).
__init__
()
self
.
pre_proj
=
Conv1D
(
self
.
pre_proj
=
Conv1D
(
num_channels
=
config
[
'audio'
][
'num_mels'
],
num_channels
=
config
[
'audio'
][
'num_mels'
],
...
@@ -33,14 +39,13 @@ class Vocoder(dg.Layer):
...
@@ -33,14 +39,13 @@ class Vocoder(dg.Layer):
def
forward
(
self
,
mel
):
def
forward
(
self
,
mel
):
"""
"""
CBHG Network (mel -> linear)
Compute mel spectrum to linear spectrum.
Args:
Args:
mel (Variable): The input mel spectrum.
mel (Variable): shape(B, C, T), dtype float32, the input mel spectrum.
Shape: (B, C, T), dtype: float32.
Returns:
Returns:
(Variable): the linear output.
mag_pred (Variable): shape(B, T, C), the linear output.
Shape: (B, T, C).
"""
"""
mel
=
layers
.
transpose
(
mel
,
[
0
,
2
,
1
])
mel
=
layers
.
transpose
(
mel
,
[
0
,
2
,
1
])
mel
=
self
.
pre_proj
(
mel
)
mel
=
self
.
pre_proj
(
mel
)
...
...
parakeet/modules/dynamic_gru.py
浏览文件 @
429695d6
...
@@ -43,10 +43,10 @@ class DynamicGRU(dg.Layer):
...
@@ -43,10 +43,10 @@ class DynamicGRU(dg.Layer):
Dynamic GRU block.
Dynamic GRU block.
Args:
Args:
input (Variable):
T
he input value.
input (Variable):
shape(B, T, C), dtype float32, t
he input value.
Shape: (B, T, C), dtype: float32.
Returns:
Returns:
output (Variable)
, the result compute by GRU. Shape: (B, T, C)
.
output (Variable)
: shape(B, T, C), the result compute by GRU
.
"""
"""
hidden
=
self
.
h_0
hidden
=
self
.
h_0
res
=
[]
res
=
[]
...
...
parakeet/modules/ffn.py
浏览文件 @
429695d6
...
@@ -19,8 +19,6 @@ from parakeet.modules.customized import Conv1D
...
@@ -19,8 +19,6 @@ from parakeet.modules.customized import Conv1D
class
PositionwiseFeedForward
(
dg
.
Layer
):
class
PositionwiseFeedForward
(
dg
.
Layer
):
''' A two-feed-forward-layer module '''
def
__init__
(
self
,
def
__init__
(
self
,
d_in
,
d_in
,
num_hidden
,
num_hidden
,
...
@@ -28,6 +26,16 @@ class PositionwiseFeedForward(dg.Layer):
...
@@ -28,6 +26,16 @@ class PositionwiseFeedForward(dg.Layer):
padding
=
0
,
padding
=
0
,
use_cudnn
=
True
,
use_cudnn
=
True
,
dropout
=
0.1
):
dropout
=
0.1
):
"""A two-feed-forward-layer module.
Args:
d_in (int): the size of input channel.
num_hidden (int): the size of hidden layer in network.
filter_size (int): the filter size of Conv
padding (int, optional): the padding size of Conv. Defaults to 0.
use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True.
dropout (float, optional): dropout probability. Defaults to 0.1.
"""
super
(
PositionwiseFeedForward
,
self
).
__init__
()
super
(
PositionwiseFeedForward
,
self
).
__init__
()
self
.
num_hidden
=
num_hidden
self
.
num_hidden
=
num_hidden
self
.
use_cudnn
=
use_cudnn
self
.
use_cudnn
=
use_cudnn
...
@@ -59,13 +67,13 @@ class PositionwiseFeedForward(dg.Layer):
...
@@ -59,13 +67,13 @@ class PositionwiseFeedForward(dg.Layer):
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
"""
"""
Feed Forward Network
.
Compute feed forward network result
.
Args:
Args:
input (Variable):
T
he input value.
input (Variable):
shape(B, T, C), dtype float32, t
he input value.
Shape: (B, T, C), dtype: float32.
Returns:
Returns:
output (Variable)
, the result after FFN. Shape: (B, T, C).
output (Variable)
: shape(B, T, C), the result after FFN.
"""
"""
x
=
layers
.
transpose
(
input
,
[
0
,
2
,
1
])
x
=
layers
.
transpose
(
input
,
[
0
,
2
,
1
])
#FFN Networt
#FFN Networt
...
...
parakeet/modules/multihead_attention.py
浏览文件 @
429695d6
...
@@ -50,6 +50,11 @@ class Linear(dg.Layer):
...
@@ -50,6 +50,11 @@ class Linear(dg.Layer):
class
ScaledDotProductAttention
(
dg
.
Layer
):
class
ScaledDotProductAttention
(
dg
.
Layer
):
def
__init__
(
self
,
d_key
):
def
__init__
(
self
,
d_key
):
"""Scaled dot product attention module.
Args:
d_key (int): the dim of key in multihead attention.
"""
super
(
ScaledDotProductAttention
,
self
).
__init__
()
super
(
ScaledDotProductAttention
,
self
).
__init__
()
self
.
d_key
=
d_key
self
.
d_key
=
d_key
...
@@ -63,23 +68,18 @@ class ScaledDotProductAttention(dg.Layer):
...
@@ -63,23 +68,18 @@ class ScaledDotProductAttention(dg.Layer):
query_mask
=
None
,
query_mask
=
None
,
dropout
=
0.1
):
dropout
=
0.1
):
"""
"""
Scaled Dot Product A
ttention.
Compute scaled dot product a
ttention.
Args:
Args:
key (Variable): The input key of scaled dot product attention.
key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot product attention.
Shape: (B, T, C), dtype: float32.
value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot product attention.
value (Variable): The input value of scaled dot product attention.
query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot product attention.
Shape: (B, T, C), dtype: float32.
mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of key. Defaults to None.
query (Variable): The input query of scaled dot product attention.
query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of query. Defaults to None.
Shape: (B, T, C), dtype: float32.
dropout (float32, optional): the probability of dropout. Defaults to 0.1.
mask (Variable, optional): The mask of key. Defaults to None.
Shape(B, T_q, T_k), dtype: float32.
query_mask (Variable, optional): The mask of query. Defaults to None.
Shape(B, T_q, T_q), dtype: float32.
dropout (float32, optional): The probability of dropout. Defaults to 0.1.
Returns:
Returns:
result (Variable)
, S
hape(B, T, C), the result of mutihead attention.
result (Variable)
: s
hape(B, T, C), the result of mutihead attention.
attention (Variable)
, S
hape(n_head * B, T, C), the attention of key.
attention (Variable)
: s
hape(n_head * B, T, C), the attention of key.
"""
"""
# Compute attention score
# Compute attention score
attention
=
layers
.
matmul
(
attention
=
layers
.
matmul
(
...
@@ -110,6 +110,17 @@ class MultiheadAttention(dg.Layer):
...
@@ -110,6 +110,17 @@ class MultiheadAttention(dg.Layer):
is_bias
=
False
,
is_bias
=
False
,
dropout
=
0.1
,
dropout
=
0.1
,
is_concat
=
True
):
is_concat
=
True
):
"""Multihead Attention.
Args:
num_hidden (int): the number of hidden layer in network.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
num_head (int, optional): the head number of multihead attention. Defaults to 4.
is_bias (bool, optional): whether have bias in linear layers. Default to False.
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
is_concat (bool, optional): whether concat query and result. Default to True.
"""
super
(
MultiheadAttention
,
self
).
__init__
()
super
(
MultiheadAttention
,
self
).
__init__
()
self
.
num_hidden
=
num_hidden
self
.
num_hidden
=
num_hidden
self
.
num_head
=
num_head
self
.
num_head
=
num_head
...
@@ -133,22 +144,18 @@ class MultiheadAttention(dg.Layer):
...
@@ -133,22 +144,18 @@ class MultiheadAttention(dg.Layer):
def
forward
(
self
,
key
,
value
,
query_input
,
mask
=
None
,
query_mask
=
None
):
def
forward
(
self
,
key
,
value
,
query_input
,
mask
=
None
,
query_mask
=
None
):
"""
"""
Multihead A
ttention.
Compute a
ttention.
Args:
Args:
key (Variable): The input key of attention.
key (Variable): shape(B, T, C), dtype float32, the input key of attention.
Shape: (B, T, C), dtype: float32.
value (Variable): shape(B, T, C), dtype float32, the input value of attention.
value (Variable): The input value of attention.
query_input (Variable): shape(B, T, C), dtype float32, the input query of attention.
Shape: (B, T, C), dtype: float32.
mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of key. Defaults to None.
query_input (Variable): The input query of attention.
query_mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of query. Defaults to None.
Shape: (B, T, C), dtype: float32.
mask (Variable, optional): The mask of key. Defaults to None.
Shape: (B, T_query, T_key), dtype: float32.
query_mask (Variable, optional): The mask of query. Defaults to None.
Shape: (B, T_query, T_key), dtype: float32.
Returns:
Returns:
result (Variable)
, the result of mutihead attention. Shape: (B, T, C).
result (Variable)
: shape(B, T, C), the result of mutihead attention.
attention (Variable)
, the attention of key and query. Shape: (num_head * B, T, C)
attention (Variable)
: shape(num_head * B, T, C), the attention of key and query.
"""
"""
batch_size
=
key
.
shape
[
0
]
batch_size
=
key
.
shape
[
0
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录