Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Parakeet
提交
f7ec215b
P
Parakeet
项目概览
PaddlePaddle
/
Parakeet
通知
11
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Parakeet
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f7ec215b
编写于
3月 09, 2020
作者:
L
lifuchen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add docstring for transformer_tts and fastspeech
上级
a302bf21
变更
23
展开全部
隐藏空白更改
内联
并排
Showing
23 changed file
with
271 addition
and
738 deletion
+271
-738
examples/fastspeech/synthesis.sh
examples/fastspeech/synthesis.sh
+1
-0
examples/fastspeech/train.sh
examples/fastspeech/train.sh
+1
-1
examples/transformer_tts/synthesis.py
examples/transformer_tts/synthesis.py
+3
-4
examples/transformer_tts/synthesis.sh
examples/transformer_tts/synthesis.sh
+4
-4
examples/transformer_tts/train_transformer.sh
examples/transformer_tts/train_transformer.sh
+1
-1
examples/transformer_tts/train_vocoder.sh
examples/transformer_tts/train_vocoder.sh
+1
-1
parakeet/models/fastspeech/decoder.py
parakeet/models/fastspeech/decoder.py
+17
-7
parakeet/models/fastspeech/encoder.py
parakeet/models/fastspeech/encoder.py
+16
-9
parakeet/models/fastspeech/fastspeech.py
parakeet/models/fastspeech/fastspeech.py
+29
-19
parakeet/models/fastspeech/fft_block.py
parakeet/models/fastspeech/fft_block.py
+10
-8
parakeet/models/fastspeech/length_regulator.py
parakeet/models/fastspeech/length_regulator.py
+9
-7
parakeet/models/transformer_tts/cbhg.py
parakeet/models/transformer_tts/cbhg.py
+19
-10
parakeet/models/transformer_tts/decoder.py
parakeet/models/transformer_tts/decoder.py
+39
-5
parakeet/models/transformer_tts/encoder.py
parakeet/models/transformer_tts/encoder.py
+24
-5
parakeet/models/transformer_tts/encoderprenet.py
parakeet/models/transformer_tts/encoderprenet.py
+11
-1
parakeet/models/transformer_tts/post_convnet.py
parakeet/models/transformer_tts/post_convnet.py
+4
-3
parakeet/models/transformer_tts/prenet.py
parakeet/models/transformer_tts/prenet.py
+3
-7
parakeet/models/transformer_tts/transformer_tts.py
parakeet/models/transformer_tts/transformer_tts.py
+40
-2
parakeet/models/transformer_tts/vocoder.py
parakeet/models/transformer_tts/vocoder.py
+10
-4
parakeet/modules/dynamic_gru.py
parakeet/modules/dynamic_gru.py
+3
-2
parakeet/modules/ffn.py
parakeet/modules/ffn.py
+3
-2
parakeet/modules/modules.py
parakeet/modules/modules.py
+0
-610
parakeet/modules/multihead_attention.py
parakeet/modules/multihead_attention.py
+23
-26
未找到文件。
examples/fastspeech/synthesis.sh
浏览文件 @
f7ec215b
...
...
@@ -6,6 +6,7 @@ python -u synthesis.py \
--checkpoint_path
=
'checkpoint/'
\
--fastspeech_step
=
71000
\
--log_dir
=
'./log'
\
--config_path
=
'configs/synthesis.yaml'
\
if
[
$?
-ne
0
]
;
then
echo
"Failed in synthesis!"
...
...
examples/fastspeech/train.sh
浏览文件 @
f7ec215b
...
...
@@ -13,7 +13,7 @@ python -u train.py \
--transformer_step
=
160000
\
--save_path
=
'./checkpoint'
\
--log_dir
=
'./log'
\
--config_path
=
'config/fastspeech.yaml'
\
--config_path
=
'config
s
/fastspeech.yaml'
\
#--checkpoint_path='./checkpoint' \
#--fastspeech_step=97000 \
...
...
examples/transformer_tts/synthesis.py
浏览文件 @
f7ec215b
...
...
@@ -84,7 +84,7 @@ def synthesis(text_input, args):
dec_slf_mask
=
get_triu_tensor
(
mel_input
.
numpy
(),
mel_input
.
numpy
()).
astype
(
np
.
float32
)
dec_slf_mask
=
fluid
.
layers
.
cast
(
dg
.
to_variable
(
dec_slf_mask
==
0
),
np
.
float32
)
dg
.
to_variable
(
dec_slf_mask
!=
0
),
np
.
float32
)
*
(
-
2
**
32
+
1
)
pos_mel
=
np
.
arange
(
1
,
mel_input
.
shape
[
1
]
+
1
)
pos_mel
=
fluid
.
layers
.
unsqueeze
(
dg
.
to_variable
(
pos_mel
),
[
0
])
mel_pred
,
postnet_pred
,
attn_probs
,
stop_preds
,
attn_enc
,
attn_dec
=
model
(
...
...
@@ -157,6 +157,5 @@ if __name__ == '__main__':
parser
=
argparse
.
ArgumentParser
(
description
=
"Synthesis model"
)
add_config_options_to_parser
(
parser
)
args
=
parser
.
parse_args
()
synthesis
(
"They emphasized the necessity that the information now being furnished be handled with judgment and care."
,
args
)
synthesis
(
"Parakeet stands for Paddle PARAllel text-to-speech toolkit."
,
args
)
examples/transformer_tts/synthesis.sh
浏览文件 @
f7ec215b
...
...
@@ -2,14 +2,14 @@
# train model
CUDA_VISIBLE_DEVICES
=
0
\
python
-u
synthesis.py
\
--max_len
=
6
00
\
--transformer_step
=
1
6
0000
\
--vocoder_step
=
9
0000
\
--max_len
=
3
00
\
--transformer_step
=
1
2
0000
\
--vocoder_step
=
10
0000
\
--use_gpu
=
1
\
--checkpoint_path
=
'./checkpoint'
\
--log_dir
=
'./log'
\
--sample_path
=
'./sample'
\
--config_path
=
'config/synthesis.yaml'
\
--config_path
=
'config
s
/synthesis.yaml'
\
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
examples/transformer_tts/train_transformer.sh
浏览文件 @
f7ec215b
...
...
@@ -14,7 +14,7 @@ python -u train_transformer.py \
--data_path
=
'../../dataset/LJSpeech-1.1'
\
--save_path
=
'./checkpoint'
\
--log_dir
=
'./log'
\
--config_path
=
'config/train_transformer.yaml'
\
--config_path
=
'config
s
/train_transformer.yaml'
\
#--checkpoint_path='./checkpoint' \
#--transformer_step=160000 \
...
...
examples/transformer_tts/train_vocoder.sh
浏览文件 @
f7ec215b
...
...
@@ -12,7 +12,7 @@ python -u train_vocoder.py \
--data_path
=
'../../dataset/LJSpeech-1.1'
\
--save_path
=
'./checkpoint'
\
--log_dir
=
'./log'
\
--config_path
=
'config/train_vocoder.yaml'
\
--config_path
=
'config
s
/train_vocoder.yaml'
\
#--checkpoint_path='./checkpoint' \
#--vocoder_step=27000 \
...
...
parakeet/models/fastspeech/decoder.py
浏览文件 @
f7ec215b
...
...
@@ -59,15 +59,25 @@ class Decoder(dg.Layer):
def
forward
(
self
,
enc_seq
,
enc_pos
,
non_pad_mask
,
slf_attn_mask
=
None
):
"""
Decoder layer of FastSpeech.
Args:
enc_seq (Variable), Shape(B, text_T, C), dtype: float32.
The output of length regulator.
enc_pos (Variable, optional): Shape(B, T_mel),
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum.
enc_seq (Variable): The output of length regulator.
Shape: (B, T_text, C), T_text means the timesteps of input text,
dtype: float32.
enc_pos (Variable): The spectrum position.
Shape: (B, T_mel), T_mel means the timesteps of input spectrum,
dtype: int64.
non_pad_mask (Variable): the mask with non pad.
Shape: (B, T_mel, 1),
dtype: int64.
slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
Shape: (B, T_mel, T_mel),
dtype: int64.
Returns:
dec_output (Variable), Shape(B, mel_T, C), the decoder output.
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
dec_output (Variable): the decoder output.
Shape: (B, T_mel, C).
dec_slf_attn_list (list[Variable]): the decoder self attention list.
Len: n_layers.
"""
dec_slf_attn_list
=
[]
slf_attn_mask
=
layers
.
expand
(
slf_attn_mask
,
[
self
.
n_head
,
1
,
1
])
...
...
parakeet/models/fastspeech/encoder.py
浏览文件 @
f7ec215b
...
...
@@ -64,17 +64,24 @@ class Encoder(dg.Layer):
def
forward
(
self
,
character
,
text_pos
,
non_pad_mask
,
slf_attn_mask
=
None
):
"""
Encoder layer of FastSpeech.
Args:
character (Variable): Shape(B, T_text), dtype: float32. The input text
characters. T_text means the timesteps of input characters.
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
position. T_text means the timesteps of input characters.
character (Variable): The input text characters.
Shape: (B, T_text), T_text means the timesteps of input characters,
dtype: float32.
text_pos (Variable): The input text position.
Shape: (B, T_text), dtype: int64.
non_pad_mask (Variable): the mask with non pad.
Shape: (B, T_text, 1),
dtype: int64.
slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
Shape: (B, T_text, T_text),
dtype: int64.
Returns:
enc_output (Variable), Shape(B, text_T, C), the encoder output.
non_pad_mask (Variable), Shape(B, T_text, 1), the mask with non pad.
enc_slf_attn_list (list<Variable>), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list.
enc_output (Variable), the encoder output. Shape(B, T_text, C)
non_pad_mask (Variable), the mask with non pad. Shape(B, T_text, 1)
enc_slf_attn_list (list[Variable]), the encoder self attention list.
Len: n_layers.
"""
enc_slf_attn_list
=
[]
slf_attn_mask
=
layers
.
expand
(
slf_attn_mask
,
[
self
.
n_head
,
1
,
1
])
...
...
parakeet/models/fastspeech/fastspeech.py
浏览文件 @
f7ec215b
...
...
@@ -82,34 +82,45 @@ class FastSpeech(dg.Layer):
text_pos
,
enc_non_pad_mask
,
dec_non_pad_mask
,
mel_pos
=
None
,
enc_slf_attn_mask
=
None
,
dec_slf_attn_mask
=
None
,
mel_pos
=
None
,
length_target
=
None
,
alpha
=
1.0
):
"""
FastSpeech model.
Args:
character (Variable): Shape(B, T_text), dtype: float32. The input text
characters. T_text means the timesteps of input characters.
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
position. T_text means the timesteps of input characters.
mel_pos (Variable, optional): Shape(B, T_mel),
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum.
length_target (Variable, optional): Shape(B, T_text),
dtype: int64. The duration of phoneme compute from pretrained transformerTTS.
alpha (Constant):
dtype: float32. The hyperparameter to determine the length of the expanded sequence
mel, thereby controlling the voice speed.
character (Variable): The input text characters.
Shape: (B, T_text), T_text means the timesteps of input characters, dtype: float32.
text_pos (Variable): The input text position.
Shape: (B, T_text), dtype: int64.
mel_pos (Variable, optional): The spectrum position.
Shape: (B, T_mel), T_mel means the timesteps of input spectrum, dtype: int64.
enc_non_pad_mask (Variable): the mask with non pad.
Shape: (B, T_text, 1),
dtype: int64.
dec_non_pad_mask (Variable): the mask with non pad.
Shape: (B, T_mel, 1),
dtype: int64.
enc_slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
Shape: (B, T_text, T_text),
dtype: int64.
dec_slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
Shape: (B, T_mel, T_mel),
dtype: int64.
length_target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS.
Defaults to None. Shape: (B, T_text), dtype: int64.
alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
mel, thereby controlling the voice speed. Defaults to 1.0.
Returns:
mel_output (Variable),
Shape(B, mel_T, C), the mel output before postnet.
mel_output_postnet (Variable),
Shape(B, mel_T, C), the mel output after postnet
.
duration_predictor_output (Variable),
Shape(B, text_T), the duration of phoneme compute
with duration predictor
.
enc_slf_attn_list (
Variable), Shape(B, text_T, text_T), the encoder self attention list
.
dec_slf_attn_list (
Variable), Shape(B, mel_T, mel_T), the decoder self attention list
.
mel_output (Variable),
the mel output before postnet. Shape: (B, T_mel, C),
mel_output_postnet (Variable),
the mel output after postnet. Shape: (B, T_mel, C)
.
duration_predictor_output (Variable),
the duration of phoneme compute with duration predictor.
Shape: (B, T_text)
.
enc_slf_attn_list (
List[Variable]), the encoder self attention list. Len: enc_n_layers
.
dec_slf_attn_list (
List[Variable]), the decoder self attention list. Len: dec_n_layers
.
"""
encoder_output
,
enc_slf_attn_list
=
self
.
encoder
(
...
...
@@ -118,7 +129,6 @@ class FastSpeech(dg.Layer):
enc_non_pad_mask
,
slf_attn_mask
=
enc_slf_attn_mask
)
if
fluid
.
framework
.
_dygraph_tracer
().
_train_mode
:
length_regulator_output
,
duration_predictor_output
=
self
.
length_regulator
(
encoder_output
,
target
=
length_target
,
alpha
=
alpha
)
decoder_output
,
dec_slf_attn_list
=
self
.
decoder
(
...
...
parakeet/models/fastspeech/fft_block.py
浏览文件 @
f7ec215b
...
...
@@ -51,15 +51,17 @@ class FFTBlock(dg.Layer):
Feed Forward Transformer block in FastSpeech.
Args:
enc_input (Variable): Shape(B, T, C), dtype: float32. The embedding characters input.
T means the timesteps of input.
non_pad_mask (Variable): Shape(B, T, 1), dtype: int64. The mask of sequence.
slf_attn_mask (Variable): Shape(B, len_q, len_k), dtype: int64. The mask of self attention.
len_q means the sequence length of query, len_k means the sequence length of key.
enc_input (Variable): The embedding characters input.
Shape: (B, T, C), T means the timesteps of input, dtype: float32.
non_pad_mask (Variable): The mask of sequence.
Shape: (B, T, 1), dtype: int64.
slf_attn_mask (Variable, optional): The mask of self attention. Defaults to None.
Shape(B, len_q, len_k), len_q means the sequence length of query,
len_k means the sequence length of key, dtype: int64.
Returns:
output (Variable),
Shape(B, T, C), the output after self-attention & ffn
.
slf_attn (Variable),
Shape(B * n_head, T, T), the self attention.
output (Variable),
the output after self-attention & ffn. Shape: (B, T, C)
.
slf_attn (Variable),
the self attention. Shape: (B * n_head, T, T),
"""
output
,
slf_attn
=
self
.
slf_attn
(
enc_input
,
enc_input
,
enc_input
,
mask
=
slf_attn_mask
)
...
...
parakeet/models/fastspeech/length_regulator.py
浏览文件 @
f7ec215b
...
...
@@ -69,15 +69,17 @@ class LengthRegulator(dg.Layer):
Length Regulator block in FastSpeech.
Args:
x (Variable): Shape(B, T, C), dtype: float32. The encoder output.
alpha (Constant): dtype: float32. The hyperparameter to determine the length of
the expanded sequence mel, thereby controlling the voice speed.
target (Variable): (Variable, optional): Shape(B, T_text),
dtype: int64. The duration of phoneme compute from pretrained transformerTTS.
x (Variable): The encoder output.
Shape: (B, T, C), dtype: float32.
alpha (float32, optional): The hyperparameter to determine the length of
the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0.
target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS.
Defaults to None. Shape: (B, T_text), dtype: int64.
Returns:
output (Variable), Shape(B, T, C), the output after exppand.
duration_predictor_output (Variable), Shape(B, T, C), the output of duration predictor.
output (Variable), the output after expand. Shape: (B, T, C),
duration_predictor_output (Variable), the output of duration predictor.
Shape: (B, T, C).
"""
duration_predictor_output
=
self
.
duration_predictor
(
x
)
if
fluid
.
framework
.
_dygraph_tracer
().
_train_mode
:
...
...
parakeet/models/transformer_tts/cbhg.py
浏览文件 @
f7ec215b
...
...
@@ -31,15 +31,7 @@ class CBHG(dg.Layer):
max_pool_kernel_size
=
2
,
is_post
=
False
):
super
(
CBHG
,
self
).
__init__
()
"""
:param hidden_size: dimension of hidden unit
:param batch_size: batch size
:param K: # of convolution banks
:param projection_size: dimension of projection unit
:param num_gru_layers: # of layers of GRUcell
:param max_pool_kernel_size: max pooling kernel size
:param is_post: whether post processing or not
"""
self
.
hidden_size
=
hidden_size
self
.
projection_size
=
projection_size
self
.
conv_list
=
[]
...
...
@@ -176,7 +168,15 @@ class CBHG(dg.Layer):
return
x
def
forward
(
self
,
input_
):
# input_.shape = [N, C, T]
"""
CBHG Module
Args:
input_(Variable): The sequentially input.
Shape: (B, C, T), dtype: float32.
Returns:
(Variable): the CBHG output.
"""
conv_list
=
[]
conv_input
=
input_
...
...
@@ -249,6 +249,15 @@ class Highwaynet(dg.Layer):
self
.
add_sublayer
(
"gates_{}"
.
format
(
i
),
gate
)
def
forward
(
self
,
input_
):
"""
Highway network
Args:
input_(Variable): The sequentially input.
Shape: (B, T, C), dtype: float32.
Returns:
(Variable): the Highway output.
"""
out
=
input_
for
linear
,
gate
in
zip
(
self
.
linears
,
self
.
gates
):
...
...
parakeet/models/transformer_tts/decoder.py
浏览文件 @
f7ec215b
...
...
@@ -22,7 +22,7 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
class
Decoder
(
dg
.
Layer
):
def
__init__
(
self
,
num_hidden
,
config
,
num_head
=
4
):
def
__init__
(
self
,
num_hidden
,
config
,
num_head
=
4
,
n_layers
=
3
):
super
(
Decoder
,
self
).
__init__
()
self
.
num_hidden
=
num_hidden
self
.
num_head
=
num_head
...
...
@@ -58,20 +58,20 @@ class Decoder(dg.Layer):
self
.
selfattn_layers
=
[
MultiheadAttention
(
num_hidden
,
num_hidden
//
num_head
,
num_hidden
//
num_head
)
for
_
in
range
(
3
)
num_hidden
//
num_head
)
for
_
in
range
(
n_layers
)
]
for
i
,
layer
in
enumerate
(
self
.
selfattn_layers
):
self
.
add_sublayer
(
"self_attn_{}"
.
format
(
i
),
layer
)
self
.
attn_layers
=
[
MultiheadAttention
(
num_hidden
,
num_hidden
//
num_head
,
num_hidden
//
num_head
)
for
_
in
range
(
3
)
num_hidden
//
num_head
)
for
_
in
range
(
n_layers
)
]
for
i
,
layer
in
enumerate
(
self
.
attn_layers
):
self
.
add_sublayer
(
"attn_{}"
.
format
(
i
),
layer
)
self
.
ffns
=
[
PositionwiseFeedForward
(
num_hidden
,
num_hidden
*
num_head
,
filter_size
=
1
)
for
_
in
range
(
3
)
for
_
in
range
(
n_layers
)
]
for
i
,
layer
in
enumerate
(
self
.
ffns
):
self
.
add_sublayer
(
"ffns_{}"
.
format
(
i
),
layer
)
...
...
@@ -108,6 +108,40 @@ class Decoder(dg.Layer):
m_mask
=
None
,
m_self_mask
=
None
,
zero_mask
=
None
):
"""
Decoder layer of TransformerTTS.
Args:
key (Variable): The input key of decoder.
Shape: (B, T_text, C), T_text means the timesteps of input text,
dtype: float32.
value (Variable): The input value of decoder.
Shape: (B, T_text, C), dtype: float32.
query (Variable): The input query of decoder.
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
dtype: float32.
positional (Variable): The spectrum position.
Shape: (B, T_mel), dtype: int64.
mask (Variable): the mask of decoder self attention.
Shape: (B, T_mel, T_mel), dtype: int64.
m_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
m_self_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
zero_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, T_text), dtype: int64.
Returns:
mel_out (Variable): the decoder output after mel linear projection.
Shape: (B, T_mel, C).
out (Variable): the decoder output after post mel network.
Shape: (B, T_mel, C).
stop_tokens (Variable): the stop tokens of output.
Shape: (B, T_mel, 1)
attn_list (list[Variable]): the encoder-decoder attention list.
Len: n_layers.
selfattn_list (list[Variable]): the decoder self attention list.
Len: n_layers.
"""
# get decoder mask with triangular matrix
...
...
@@ -121,7 +155,7 @@ class Decoder(dg.Layer):
else
:
m_mask
,
m_self_mask
,
zero_mask
=
None
,
None
,
None
# Decoder pre-network
# Decoder pre-network
query
=
self
.
decoder_prenet
(
query
)
# Centered position
...
...
parakeet/models/transformer_tts/encoder.py
浏览文件 @
f7ec215b
...
...
@@ -20,7 +20,7 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet
class
Encoder
(
dg
.
Layer
):
def
__init__
(
self
,
embedding_size
,
num_hidden
,
num_head
=
4
):
def
__init__
(
self
,
embedding_size
,
num_hidden
,
num_head
=
4
,
n_layers
=
3
):
super
(
Encoder
,
self
).
__init__
()
self
.
num_hidden
=
num_hidden
self
.
num_head
=
num_head
...
...
@@ -42,7 +42,7 @@ class Encoder(dg.Layer):
use_cudnn
=
True
)
self
.
layers
=
[
MultiheadAttention
(
num_hidden
,
num_hidden
//
num_head
,
num_hidden
//
num_head
)
for
_
in
range
(
3
)
num_hidden
//
num_head
)
for
_
in
range
(
n_layers
)
]
for
i
,
layer
in
enumerate
(
self
.
layers
):
self
.
add_sublayer
(
"self_attn_{}"
.
format
(
i
),
layer
)
...
...
@@ -51,12 +51,31 @@ class Encoder(dg.Layer):
num_hidden
,
num_hidden
*
num_head
,
filter_size
=
1
,
use_cudnn
=
True
)
for
_
in
range
(
3
)
use_cudnn
=
True
)
for
_
in
range
(
n_layers
)
]
for
i
,
layer
in
enumerate
(
self
.
ffns
):
self
.
add_sublayer
(
"ffns_{}"
.
format
(
i
),
layer
)
def
forward
(
self
,
x
,
positional
,
mask
=
None
,
query_mask
=
None
):
"""
Encoder layer of TransformerTTS.
Args:
x (Variable): The input character.
Shape: (B, T_text), T_text means the timesteps of input text,
dtype: float32.
positional (Variable): The characters position.
Shape: (B, T_text), dtype: int64.
mask (Variable, optional): the mask of encoder self attention. Defaults to None.
Shape: (B, T_text, T_text), dtype: int64.
query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
Shape: (B, T_text, 1), dtype: int64.
Returns:
x (Variable): the encoder output.
Shape: (B, T_text, C).
attentions (list[Variable]): the encoder self attention list.
Len: n_layers.
"""
if
fluid
.
framework
.
_dygraph_tracer
().
_train_mode
:
seq_len_key
=
x
.
shape
[
1
]
...
...
@@ -66,12 +85,12 @@ class Encoder(dg.Layer):
else
:
query_mask
,
mask
=
None
,
None
# Encoder pre_network
x
=
self
.
encoder_prenet
(
x
)
#(N,T,C)
x
=
self
.
encoder_prenet
(
x
)
# Get positional encoding
positional
=
self
.
pos_emb
(
positional
)
x
=
positional
*
self
.
alpha
+
x
#(N, T, C)
x
=
positional
*
self
.
alpha
+
x
# Positional dropout
x
=
layers
.
dropout
(
x
,
0.1
,
dropout_implementation
=
'upscale_in_train'
)
...
...
parakeet/models/transformer_tts/encoderprenet.py
浏览文件 @
f7ec215b
...
...
@@ -81,8 +81,18 @@ class EncoderPrenet(dg.Layer):
low
=-
k
,
high
=
k
)))
def
forward
(
self
,
x
):
"""
Encoder prenet layer of TransformerTTS.
Args:
x (Variable): The input character.
Shape: (B, T_text), T_text means the timesteps of input text,
dtype: float32.
Returns:
(Variable): the encoder prenet output. Shape: (B, T_text, C).
"""
x
=
self
.
embedding
(
x
)
#(batch_size, seq_len, embending_size)
x
=
self
.
embedding
(
x
)
x
=
layers
.
transpose
(
x
,
[
0
,
2
,
1
])
for
batch_norm
,
conv
in
zip
(
self
.
batch_norm_list
,
self
.
conv_list
):
x
=
layers
.
dropout
(
...
...
parakeet/models/transformer_tts/post_convnet.py
浏览文件 @
f7ec215b
...
...
@@ -93,12 +93,13 @@ class PostConvNet(dg.Layer):
def
forward
(
self
,
input
):
"""
Post Conv Net
.
Decoder Post Conv Net of TransformerTTS
.
Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value.
input (Variable): The result of mel linear projection.
Shape: (B, T, C), dtype: float32.
Returns:
output (Variable), Shape(B, T, C), the result after postconvnet.
(Variable): the result after postconvnet. Shape: (B, T, C),
"""
input
=
layers
.
transpose
(
input
,
[
0
,
2
,
1
])
...
...
parakeet/models/transformer_tts/prenet.py
浏览文件 @
f7ec215b
...
...
@@ -19,11 +19,6 @@ import paddle.fluid.layers as layers
class
PreNet
(
dg
.
Layer
):
def
__init__
(
self
,
input_size
,
hidden_size
,
output_size
,
dropout_rate
=
0.2
):
"""
:param input_size: dimension of input
:param hidden_size: dimension of hidden unit
:param output_size: dimension of output
"""
super
(
PreNet
,
self
).
__init__
()
self
.
input_size
=
input_size
self
.
hidden_size
=
hidden_size
...
...
@@ -52,9 +47,10 @@ class PreNet(dg.Layer):
Pre Net before passing through the network.
Args:
x (Variable): Shape(B, T, C), dtype: float32. The input value.
x (Variable): The input value.
Shape: (B, T, C), dtype: float32.
Returns:
x (Variable), Shape(B, T, C), the result after pernet.
(Variable), the result after prenet. Shape: (B, T, C),
"""
x
=
layers
.
dropout
(
layers
.
relu
(
self
.
linear1
(
x
)),
...
...
parakeet/models/transformer_tts/transformer_tts.py
浏览文件 @
f7ec215b
...
...
@@ -35,6 +35,46 @@ class TransformerTTS(dg.Layer):
enc_dec_mask
=
None
,
dec_query_slf_mask
=
None
,
dec_query_mask
=
None
):
"""
TransformerTTS network.
Args:
characters (Variable): The input character.
Shape: (B, T_text), T_text means the timesteps of input text,
dtype: float32.
mel_input (Variable): The input query of decoder.
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
dtype: float32.
pos_text (Variable): The characters position.
Shape: (B, T_text), dtype: int64.
dec_slf_mask (Variable): The spectrum position.
Shape: (B, T_mel), dtype: int64.
mask (Variable): the mask of decoder self attention.
Shape: (B, T_mel, T_mel), dtype: int64.
enc_slf_mask (Variable, optional): the mask of encoder self attention. Defaults to None.
Shape: (B, T_text, T_text), dtype: int64.
enc_query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
Shape: (B, T_text, 1), dtype: int64.
dec_query_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
dec_query_slf_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
enc_dec_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, T_text), dtype: int64.
Returns:
mel_output (Variable): the decoder output after mel linear projection.
Shape: (B, T_mel, C).
postnet_output (Variable): the decoder output after post mel network.
Shape: (B, T_mel, C).
stop_preds (Variable): the stop tokens of output.
Shape: (B, T_mel, 1)
attn_probs (list[Variable]): the encoder-decoder attention list.
Len: n_layers.
attns_enc (list[Variable]): the encoder self attention list.
Len: n_layers.
attns_dec (list[Variable]): the decoder self attention list.
Len: n_layers.
"""
key
,
attns_enc
=
self
.
encoder
(
characters
,
pos_text
,
mask
=
enc_slf_mask
,
query_mask
=
enc_query_mask
)
...
...
@@ -48,5 +88,3 @@ class TransformerTTS(dg.Layer):
m_self_mask
=
dec_query_slf_mask
,
m_mask
=
dec_query_mask
)
return
mel_output
,
postnet_output
,
attn_probs
,
stop_preds
,
attns_enc
,
attns_dec
return
mel_output
,
postnet_output
,
attn_probs
,
stop_preds
,
attns_enc
,
attns_dec
parakeet/models/transformer_tts/vocoder.py
浏览文件 @
f7ec215b
...
...
@@ -19,10 +19,6 @@ from parakeet.models.transformer_tts.cbhg import CBHG
class
Vocoder
(
dg
.
Layer
):
"""
CBHG Network (mel -> linear)
"""
def
__init__
(
self
,
config
,
batch_size
):
super
(
Vocoder
,
self
).
__init__
()
self
.
pre_proj
=
Conv1D
(
...
...
@@ -36,6 +32,16 @@ class Vocoder(dg.Layer):
filter_size
=
1
)
def
forward
(
self
,
mel
):
"""
CBHG Network (mel -> linear)
Args:
mel (Variable): The input mel spectrum.
Shape: (B, C, T), dtype: float32.
Returns:
(Variable): the linear output.
Shape: (B, T, C).
"""
mel
=
layers
.
transpose
(
mel
,
[
0
,
2
,
1
])
mel
=
self
.
pre_proj
(
mel
)
mel
=
self
.
cbhg
(
mel
)
...
...
parakeet/modules/dynamic_gru.py
浏览文件 @
f7ec215b
...
...
@@ -43,9 +43,10 @@ class DynamicGRU(dg.Layer):
Dynamic GRU block.
Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value.
input (Variable): The input value.
Shape: (B, T, C), dtype: float32.
Returns:
output (Variable),
Shape(B, T, C), the result compute by GRU
.
output (Variable),
the result compute by GRU. Shape: (B, T, C)
.
"""
hidden
=
self
.
h_0
res
=
[]
...
...
parakeet/modules/ffn.py
浏览文件 @
f7ec215b
...
...
@@ -62,9 +62,10 @@ class PositionwiseFeedForward(dg.Layer):
Feed Forward Network.
Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value.
input (Variable): The input value.
Shape: (B, T, C), dtype: float32.
Returns:
output (Variable),
Shape(B, T, C), the result after FFN
.
output (Variable),
the result after FFN. Shape: (B, T, C)
.
"""
x
=
layers
.
transpose
(
input
,
[
0
,
2
,
1
])
# FFN Network
...
...
parakeet/modules/modules.py
已删除
100644 → 0
浏览文件 @
a302bf21
此差异已折叠。
点击以展开。
parakeet/modules/multihead_attention.py
浏览文件 @
f7ec215b
...
...
@@ -66,12 +66,17 @@ class ScaledDotProductAttention(dg.Layer):
Scaled Dot Product Attention.
Args:
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
query (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
dropout (Constant): dtype: float32. The probability of dropout.
key (Variable): The input key of scaled dot product attention.
Shape: (B, T, C), dtype: float32.
value (Variable): The input value of scaled dot product attention.
Shape: (B, T, C), dtype: float32.
query (Variable): The input query of scaled dot product attention.
Shape: (B, T, C), dtype: float32.
mask (Variable, optional): The mask of key. Defaults to None.
Shape(B, T_q, T_k), dtype: float32.
query_mask (Variable, optional): The mask of query. Defaults to None.
Shape(B, T_q, T_q), dtype: float32.
dropout (float32, optional): The probability of dropout. Defaults to 0.1.
Returns:
result (Variable), Shape(B, T, C), the result of multihead attention.
attention (Variable), Shape(n_head * B, T, C), the attention of key.
...
...
@@ -131,14 +136,19 @@ class MultiheadAttention(dg.Layer):
Multihead Attention.
Args:
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
query_input (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
key (Variable): The input key of attention.
Shape: (B, T, C), dtype: float32.
value (Variable): The input value of attention.
Shape: (B, T, C), dtype: float32.
query_input (Variable): The input query of attention.
Shape: (B, T, C), dtype: float32.
mask (Variable, optional): The mask of key. Defaults to None.
Shape: (B, T_query, T_key), dtype: float32.
query_mask (Variable, optional): The mask of query. Defaults to None.
Shape: (B, T_query, T_key), dtype: float32.
Returns:
result (Variable),
Shape(B, T, C), the result of mutihead attention
.
attention (Variable),
Shape(n_head * B, T, C), the attention of key.
result (Variable),
the result of multihead attention. Shape: (B, T, C)
.
attention (Variable),
the attention of key and query. Shape: (num_head * B, T, C)
"""
batch_size
=
key
.
shape
[
0
]
...
...
@@ -146,7 +156,6 @@ class MultiheadAttention(dg.Layer):
seq_len_query
=
query_input
.
shape
[
1
]
# Make multihead attention
# key & value.shape = (batch_size, seq_len, feature)(feature = num_head * num_hidden_per_attn)
key
=
layers
.
reshape
(
self
.
key
(
key
),
[
batch_size
,
seq_len_key
,
self
.
num_head
,
self
.
d_k
])
value
=
layers
.
reshape
(
...
...
@@ -168,18 +177,6 @@ class MultiheadAttention(dg.Layer):
result
,
attention
=
self
.
scal_attn
(
key
,
value
,
query
,
mask
=
mask
,
query_mask
=
query_mask
)
key
=
layers
.
reshape
(
layers
.
transpose
(
key
,
[
2
,
0
,
1
,
3
]),
[
-
1
,
seq_len_key
,
self
.
d_k
])
value
=
layers
.
reshape
(
layers
.
transpose
(
value
,
[
2
,
0
,
1
,
3
]),
[
-
1
,
seq_len_key
,
self
.
d_k
])
query
=
layers
.
reshape
(
layers
.
transpose
(
query
,
[
2
,
0
,
1
,
3
]),
[
-
1
,
seq_len_query
,
self
.
d_q
])
result
,
attention
=
self
.
scal_attn
(
key
,
value
,
query
,
mask
=
mask
,
query_mask
=
query_mask
)
# concat all multihead result
result
=
layers
.
reshape
(
result
,
[
self
.
num_head
,
batch_size
,
seq_len_query
,
self
.
d_q
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录