Commit 429695d6 authored by lifuchen

add docstring to transformer_tts and fastspeech

Parent 3d1fda0c
@@ -23,12 +23,26 @@ class Decoder(dg.Layer):
n_layers,
n_head,
d_k,
d_q,
d_model,
d_inner,
fft_conv1d_kernel,
fft_conv1d_padding,
dropout=0.1):
"""Decoder layer of FastSpeech.
Args:
len_max_seq (int): the max mel len of sequence.
n_layers (int): the layers number of FFTBlock.
n_head (int): the head number of multihead attention.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
d_model (int): the dim of hidden layer in multihead attention.
d_inner (int): the dim of hidden layer in ffn.
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
fft_conv1d_padding (int): the conv padding size in FFTBlock.
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
"""
super(Decoder, self).__init__()
n_position = len_max_seq + 1
@@ -48,7 +62,7 @@ class Decoder(dg.Layer):
d_inner,
n_head,
d_k,
d_q,
fft_conv1d_kernel,
fft_conv1d_padding,
dropout=dropout) for _ in range(n_layers)
@@ -58,26 +72,20 @@ class Decoder(dg.Layer):
def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
"""
Compute decoder outputs.
Args:
enc_seq (Variable): shape(B, T_text, C), dtype float32, the output of the length regulator, where T_text is the number of text timesteps.
enc_pos (Variable): shape(B, T_mel), dtype int64, the spectrum position, where T_mel is the number of spectrum timesteps.
non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the non-padding mask.
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64, the mask of the mel spectrum. Defaults to None.
Returns:
dec_output (Variable): shape(B, T_mel, C), the decoder output.
dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self-attention list.
"""
dec_slf_attn_list = []
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
......
@@ -24,12 +24,27 @@ class Encoder(dg.Layer):
n_layers,
n_head,
d_k,
d_q,
d_model,
d_inner,
fft_conv1d_kernel,
fft_conv1d_padding,
dropout=0.1):
"""Encoder layer of FastSpeech.
Args:
n_src_vocab (int): the number of source vocabulary.
len_max_seq (int): the max mel len of sequence.
n_layers (int): the layers number of FFTBlock.
n_head (int): the head number of multihead attention.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
d_model (int): the dim of hidden layer in multihead attention.
d_inner (int): the dim of hidden layer in ffn.
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
fft_conv1d_padding (int): the conv padding size in FFTBlock.
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
"""
super(Encoder, self).__init__()
n_position = len_max_seq + 1
self.n_head = n_head
@@ -53,7 +68,7 @@ class Encoder(dg.Layer):
d_inner,
n_head,
d_k,
d_q,
fft_conv1d_kernel,
fft_conv1d_padding,
dropout=dropout) for _ in range(n_layers)
@@ -63,25 +78,20 @@ class Encoder(dg.Layer):
def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
"""
Encode text sequence.
Args:
character (Variable): shape(B, T_text), dtype float32, the input text characters, where T_text is the number of character timesteps.
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the non-padding mask.
slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of the input characters. Defaults to None.
Returns:
enc_output (Variable): shape(B, T_text, C), the encoder output.
non_pad_mask (Variable): shape(B, T_text, 1), the non-padding mask.
enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self-attention list.
"""
enc_slf_attn_list = []
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
......
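The encoder and decoder above both consume a `non_pad_mask` of shape (B, T, 1) and a `slf_attn_mask` of shape (B, T, T). As a rough numpy sketch of how masks like these can be derived from padded position indices; treating position 0 as padding is an assumption for illustration, not necessarily Parakeet's data-pipeline convention:

```python
import numpy as np

def get_non_pad_mask(pos):
    """pos: (B, T) int64 positions, with 0 marking padding (assumed convention).
    Returns a (B, T, 1) mask: 1 on real steps, 0 on padding."""
    return (pos != 0).astype(np.int64)[:, :, np.newaxis]

def get_attn_pad_mask(pos):
    """Returns a (B, T, T) mask flagging key positions that are padding,
    so self-attention can ignore them."""
    key_pad = (pos == 0)[:, np.newaxis, :]                        # (B, 1, T)
    return np.repeat(key_pad, pos.shape[1], axis=1).astype(np.int64)

pos = np.array([[1, 2, 3, 0, 0]], dtype=np.int64)  # 3 real steps, 2 padded
print(get_non_pad_mask(pos)[0, :, 0])  # [1 1 1 0 0]
print(get_attn_pad_mask(pos)[0, 0])    # [0 0 0 1 1]
```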
@@ -25,7 +25,11 @@ from parakeet.models.fastspeech.decoder import Decoder
class FastSpeech(dg.Layer):
def __init__(self, cfg):
"""FastSpeech model.
Args:
cfg: the yaml configs used in the FastSpeech model.
"""
super(FastSpeech, self).__init__()
self.encoder = Encoder(
@@ -34,7 +38,7 @@ class FastSpeech(dg.Layer):
n_layers=cfg['encoder_n_layer'],
n_head=cfg['encoder_head'],
d_k=cfg['fs_hidden_size'] // cfg['encoder_head'],
d_q=cfg['fs_hidden_size'] // cfg['encoder_head'],
d_model=cfg['fs_hidden_size'],
d_inner=cfg['encoder_conv1d_filter_size'],
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
@@ -50,7 +54,7 @@ class FastSpeech(dg.Layer):
n_layers=cfg['decoder_n_layer'],
n_head=cfg['decoder_head'],
d_k=cfg['fs_hidden_size'] // cfg['decoder_head'],
d_q=cfg['fs_hidden_size'] // cfg['decoder_head'],
d_model=cfg['fs_hidden_size'],
d_inner=cfg['decoder_conv1d_filter_size'],
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
@@ -88,39 +92,31 @@ class FastSpeech(dg.Layer):
length_target=None,
alpha=1.0):
"""
Compute the mel output from the text characters.
Args:
character (Variable): shape(B, T_text), dtype float32, the input text characters, where T_text is the number of character timesteps.
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position, where T_mel is the number of spectrum timesteps.
enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the non-padding mask of the encoder input.
dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the non-padding mask of the decoder input.
enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of the input characters. Defaults to None.
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64, the mask of the mel spectrum. Defaults to None.
length_target (Variable, optional): shape(B, T_text), dtype int64, the phoneme durations computed by a pretrained TransformerTTS. Defaults to None.
alpha (float32, optional): the hyperparameter that determines the length of the expanded mel sequence, thereby controlling the voice speed. Defaults to 1.0.
Returns:
mel_output (Variable): shape(B, T_mel, C), the mel output before the postnet.
mel_output_postnet (Variable): shape(B, T_mel, C), the mel output after the postnet.
duration_predictor_output (Variable): shape(B, T_text), the phoneme durations computed by the duration predictor.
enc_slf_attn_list (list[Variable]): len(enc_n_layers), the encoder self-attention list.
dec_slf_attn_list (list[Variable]): len(dec_n_layers), the decoder self-attention list.
"""
encoder_output, enc_slf_attn_list = self.encoder(
......
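The `alpha` argument above rescales the predicted phoneme durations before expansion, which is how FastSpeech controls voice speed. A toy numpy illustration (the duration values here are made up):

```python
import numpy as np

durations = np.array([3.0, 5.0, 2.0])  # predicted frames per phoneme (toy values)

for alpha in (0.5, 1.0, 2.0):
    scaled = np.round(durations * alpha).astype(int)
    print(alpha, scaled, "-> total mel frames:", scaled.sum())
# alpha < 1.0 yields fewer frames (faster speech); alpha > 1.0 yields more (slower).
```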
@@ -26,15 +26,27 @@ class FFTBlock(dg.Layer):
d_inner,
n_head,
d_k,
d_q,
filter_size,
padding,
dropout=0.2):
"""Feed forward structure based on self-attention.
Args:
d_model (int): the dim of hidden layer in multihead attention.
d_inner (int): the dim of hidden layer in ffn.
n_head (int): the head number of multihead attention.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
filter_size (int): the conv kernel size.
padding (int): the conv padding size.
dropout (float, optional): dropout probability. Defaults to 0.2.
"""
super(FFTBlock, self).__init__()
self.slf_attn = MultiheadAttention(
d_model,
d_k,
d_q,
num_head=n_head,
is_bias=True,
dropout=dropout,
@@ -48,20 +60,18 @@ class FFTBlock(dg.Layer):
def forward(self, enc_input, non_pad_mask, slf_attn_mask=None):
"""
Feed-forward block of FastSpeech.
Args:
enc_input (Variable): shape(B, T, C), dtype float32, the embedded characters input, where T is the number of input timesteps.
non_pad_mask (Variable): shape(B, T, 1), dtype int64, the non-padding mask of the sequence.
slf_attn_mask (Variable, optional): shape(B, len_q, len_k), dtype int64, the self-attention mask, where len_q is the sequence length of the query and len_k is the sequence length of the key. Defaults to None.
Returns:
output (Variable): shape(B, T, C), the output after self-attention & FFN.
slf_attn (Variable): shape(B * n_head, T, T), the self-attention weights.
"""
output, slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask)
......
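Putting the FFTBlock docstring together, the data flow is multihead self-attention followed by a conv1d feed-forward network, with the tensor transposed around the convolutions. A shape-only numpy walkthrough (the attention and convolutions are stubbed out as identities; this only traces the tensor layout, not the arithmetic):

```python
import numpy as np

B, T, C = 2, 50, 256                  # batch, timesteps, d_model
x = np.random.randn(B, T, C).astype("float32")

# 1) multihead self-attention with residual connection (stubbed as identity)
attn_out = x                           # stands in for slf_attn(x, x, x, mask=...)
x = x + attn_out                       # (B, T, C)

# 2) position-wise conv FFN: transpose to (B, C, T) for Conv1D, then back
x_t = np.transpose(x, (0, 2, 1))       # (B, C, T)
# ... conv1d(C -> d_inner) -> relu -> conv1d(d_inner -> C) happens here ...
x = np.transpose(x_t, (0, 2, 1))       # (B, T, C) again for the next block
```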
@@ -22,6 +22,14 @@ from parakeet.modules.customized import Conv1D
class LengthRegulator(dg.Layer):
def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
"""Length regulator block in FastSpeech.
Args:
input_size (int): the number of input channels.
out_channels (int): the number of output channels.
filter_size (int): the filter size of the duration predictor.
dropout (float, optional): dropout probability. Defaults to 0.1.
"""
super(LengthRegulator, self).__init__()
self.duration_predictor = DurationPredictor(
input_size=input_size,
@@ -66,20 +74,18 @@ class LengthRegulator(dg.Layer):
def forward(self, x, alpha=1.0, target=None):
"""
Expand the encoder output to the mel length, using durations from a pretrained TransformerTTS.
Args:
x (Variable): shape(B, T, C), dtype float32, the encoder output.
alpha (float32, optional): the hyperparameter that determines the length of the expanded mel sequence, thereby controlling the voice speed. Defaults to 1.0.
target (Variable, optional): shape(B, T_text), dtype int64, the phoneme durations computed by a pretrained TransformerTTS. Defaults to None.
Returns:
output (Variable): shape(B, T, C), the output after expansion.
duration_predictor_output (Variable): shape(B, T, C), the output of the duration predictor.
"""
duration_predictor_output = self.duration_predictor(x)
if fluid.framework._dygraph_tracer()._train_mode:
@@ -95,6 +101,14 @@ class LengthRegulator(dg.Layer):
class DurationPredictor(dg.Layer):
def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
"""Duration predictor block in FastSpeech.
Args:
input_size (int): the number of input channels.
out_channels (int): the number of output channels.
filter_size (int): the filter size.
dropout (float, optional): dropout probability. Defaults to 0.1.
"""
super(DurationPredictor, self).__init__()
self.input_size = input_size
self.out_channels = out_channels
@@ -137,12 +151,13 @@ class DurationPredictor(dg.Layer):
def forward(self, encoder_output):
"""
Predict the duration of each character.
Args:
encoder_output (Variable): shape(B, T, C), dtype float32, the encoder output.
Returns:
out (Variable): shape(B, T, C), the output of the duration predictor.
"""
# encoder_output.shape(N, T, C)
out = layers.transpose(encoder_output, [0, 2, 1])
......
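The expansion that the length regulator performs with these predicted durations is essentially a per-phoneme repeat. A minimal numpy sketch of the unbatched case; the Paddle implementation does the same thing batched and padded:

```python
import numpy as np

def expand_by_duration(enc_out, durations):
    """enc_out: (T_text, C) encoder frames; durations: (T_text,) integer
    frame counts per phoneme. Returns (T_mel, C) with T_mel = durations.sum()."""
    return np.repeat(enc_out, durations, axis=0)

enc_out = np.arange(6, dtype="float32").reshape(3, 2)  # 3 phonemes, C=2
durations = np.array([2, 1, 3])                        # frames per phoneme
expanded = expand_by_duration(enc_out, durations)
print(expanded.shape)  # (6, 2): phoneme 0 twice, phoneme 1 once, phoneme 2 thrice
```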
@@ -30,6 +30,17 @@ class CBHG(dg.Layer):
num_gru_layers=2,
max_pool_kernel_size=2,
is_post=False):
"""CBHG module.
Args:
hidden_size (int): dimension of the hidden units.
batch_size (int): batch size of the input.
K (int, optional): number of convolution banks. Defaults to 16.
projection_size (int, optional): dimension of the projection units. Defaults to 256.
num_gru_layers (int, optional): number of GRU layers. Defaults to 2.
max_pool_kernel_size (int, optional): max pooling kernel size. Defaults to 2.
is_post (bool, optional): whether to apply post processing. Defaults to False.
"""
super(CBHG, self).__init__()
self.hidden_size = hidden_size
@@ -169,13 +180,13 @@ class CBHG(dg.Layer):
def forward(self, input_):
"""
Convert the linear spectrum to a mel spectrum.
Args:
input_ (Variable): shape(B, C, T), dtype float32, the sequential input.
Returns:
out (Variable): shape(B, C, T), the CBHG output.
"""
conv_list = []
@@ -217,6 +228,12 @@ class CBHG(dg.Layer):
class Highwaynet(dg.Layer):
def __init__(self, num_units, num_layers=4):
"""Highway network.
Args:
num_units (int): dimension of the hidden units.
num_layers (int, optional): number of highway layers. Defaults to 4.
"""
super(Highwaynet, self).__init__()
self.num_units = num_units
self.num_layers = num_layers
@@ -250,13 +267,13 @@ class Highwaynet(dg.Layer):
def forward(self, input_):
"""
Compute the result of the highway network.
Args:
input_ (Variable): shape(B, T, C), dtype float32, the sequential input.
Returns:
out (Variable): shape(B, T, C), the highway network output.
"""
out = input_
......
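The highway layers inside CBHG follow the standard gating formula out = t * h + (1 - t) * x, where t is a learned sigmoid gate. A standalone numpy sketch with random (untrained) stand-in weights:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def highway_layer(x, w_h, b_h, w_t, b_t):
    h = np.maximum(x @ w_h + b_h, 0.0)  # candidate transform (relu)
    t = sigmoid(x @ w_t + b_t)          # transform gate in (0, 1)
    return t * h + (1.0 - t) * x        # gated mix of transform and input

C = 8
x = np.random.randn(2, 5, C).astype("float32")
out = x
for _ in range(4):  # num_layers=4; the real module learns separate weights per layer
    w_h, b_h = np.random.randn(C, C) * 0.1, np.zeros(C)
    w_t, b_t = np.random.randn(C, C) * 0.1, np.zeros(C)
    out = highway_layer(out, w_h, b_h, w_t, b_t)
print(out.shape)  # (2, 5, 8): the layer preserves the input shape
```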
@@ -23,6 +23,14 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
class Decoder(dg.Layer):
def __init__(self, num_hidden, config, num_head=4, n_layers=3):
"""Decoder layer of TransformerTTS.
Args:
num_hidden (int): the size of the hidden layers in the network.
config: the yaml configs used in the decoder.
num_head (int, optional): the number of heads in multihead attention. Defaults to 4.
n_layers (int, optional): the number of multihead attention layers. Defaults to 3.
"""
super(Decoder, self).__init__()
self.num_hidden = num_hidden
self.num_head = num_head
@@ -109,38 +117,26 @@ class Decoder(dg.Layer):
m_self_mask=None,
zero_mask=None):
"""
Compute decoder outputs.
Args:
key (Variable): shape(B, T_text, C), dtype float32, the input key of the decoder, where T_text is the number of text timesteps.
value (Variable): shape(B, T_text, C), dtype float32, the input value of the decoder.
query (Variable): shape(B, T_mel, C), dtype float32, the input query of the decoder, where T_mel is the number of spectrum timesteps.
positional (Variable): shape(B, T_mel), dtype int64, the spectrum position.
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self-attention.
m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self-attention. Defaults to None.
zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, the mask of encoder-decoder attention. Defaults to None.
Returns:
mel_out (Variable): shape(B, T_mel, C), the decoder output after the mel linear projection.
out (Variable): shape(B, T_mel, C), the decoder output after the post mel network.
stop_tokens (Variable): shape(B, T_mel, 1), the stop tokens of the output.
attn_list (list[Variable]): len(n_layers), the encoder-decoder attention list.
selfattn_list (list[Variable]): len(n_layers), the decoder self-attention list.
"""
# get decoder mask with triangular matrix
......
@@ -21,6 +21,14 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet
class Encoder(dg.Layer):
def __init__(self, embedding_size, num_hidden, num_head=4, n_layers=3):
"""Encoder layer of TransformerTTS.
Args:
embedding_size (int): the size of the position embedding.
num_hidden (int): the size of the hidden layers in the network.
num_head (int, optional): the number of heads in multihead attention. Defaults to 4.
n_layers (int, optional): the number of multihead attention layers. Defaults to 3.
"""
super(Encoder, self).__init__()
self.num_hidden = num_hidden
self.num_head = num_head
@@ -58,23 +66,18 @@ class Encoder(dg.Layer):
def forward(self, x, positional, mask=None, query_mask=None):
"""
Encode text sequence.
Args:
x (Variable): shape(B, T_text), dtype float32, the input characters, where T_text is the number of text timesteps.
positional (Variable): shape(B, T_text), dtype int64, the position of the characters.
mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self-attention. Defaults to None.
query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self-attention. Defaults to None.
Returns:
x (Variable): shape(B, T_text, C), the encoder output.
attentions (list[Variable]): len(n_layers), the encoder self-attention list.
"""
if fluid.framework._dygraph_tracer()._train_mode:
......
@@ -22,6 +22,13 @@ import numpy as np
class EncoderPrenet(dg.Layer):
def __init__(self, embedding_size, num_hidden, use_cudnn=True):
"""Encoder prenet layer of TransformerTTS.
Args:
embedding_size (int): the size of the embedding.
num_hidden (int): the size of the hidden layers in the network.
use_cudnn (bool, optional): whether to use cudnn. Defaults to True.
"""
super(EncoderPrenet, self).__init__()
self.embedding_size = embedding_size
self.num_hidden = num_hidden
@@ -82,14 +89,13 @@ class EncoderPrenet(dg.Layer):
def forward(self, x):
"""
Prepare the encoder input.
Args:
x (Variable): shape(B, T_text), dtype float32, the input characters, where T_text is the number of text timesteps.
Returns:
(Variable): shape(B, T_text, C), the encoder prenet output.
"""
x = self.embedding(x)
......
@@ -29,6 +29,19 @@ class PostConvNet(dg.Layer):
use_cudnn=True,
dropout=0.1,
batchnorm_last=False):
"""Decoder post conv net of TransformerTTS.
Args:
n_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
num_hidden (int, optional): the size of the hidden layers in the network. Defaults to 512.
filter_size (int, optional): the filter size of the Conv layers. Defaults to 5.
padding (int, optional): the padding size of the Conv layers. Defaults to 0.
num_conv (int, optional): the number of Conv layers in the network. Defaults to 5.
outputs_per_step (int, optional): the number of output frames per step. Defaults to 1.
use_cudnn (bool, optional): whether to use cudnn in Conv. Defaults to True.
dropout (float, optional): dropout probability. Defaults to 0.1.
batchnorm_last (bool, optional): whether to apply batch normalization at the last layer. Defaults to False.
"""
super(PostConvNet, self).__init__()
self.dropout = dropout
@@ -93,13 +106,13 @@ class PostConvNet(dg.Layer):
def forward(self, input):
"""
Compute the mel spectrum.
Args:
input (Variable): shape(B, T, C), dtype float32, the result of the mel linear projection.
Returns:
output (Variable): shape(B, T, C), the result after the postconvnet.
"""
input = layers.transpose(input, [0, 2, 1])
......
@@ -19,6 +19,14 @@ import paddle.fluid.layers as layers
class PreNet(dg.Layer):
def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2):
"""Prenet before passing through the network.
Args:
input_size (int): the number of input channels.
hidden_size (int): the size of the hidden layers in the network.
output_size (int): the number of output channels.
dropout_rate (float, optional): dropout probability. Defaults to 0.2.
"""
super(PreNet, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
@@ -44,20 +52,20 @@ class PreNet(dg.Layer):
def forward(self, x):
"""
Prepare the network input.
Args:
x (Variable): shape(B, T, C), dtype float32, the input value.
Returns:
output (Variable): shape(B, T, C), the result after the prenet.
"""
x = layers.dropout(
layers.relu(self.linear1(x)),
self.dropout_rate,
dropout_implementation='upscale_in_train')
output = layers.dropout(
layers.relu(self.linear2(x)),
self.dropout_rate,
dropout_implementation='upscale_in_train')
return output
@@ -19,6 +19,11 @@ from parakeet.models.transformer_tts.decoder import Decoder
class TransformerTTS(dg.Layer):
def __init__(self, config):
"""TransformerTTS model.
Args:
config: the yaml configs used in the TransformerTTS model.
"""
super(TransformerTTS, self).__init__()
self.encoder = Encoder(config['embedding_size'], config['hidden_size'])
self.decoder = Decoder(config['hidden_size'], config)
@@ -37,43 +42,28 @@ class TransformerTTS(dg.Layer):
dec_query_mask=None):
"""
TransformerTTS network.
Args:
characters (Variable): shape(B, T_text), dtype float32, the input characters, where T_text is the number of text timesteps.
mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of the decoder, where T_mel is the number of spectrum timesteps.
pos_text (Variable): shape(B, T_text), dtype int64, the position of the characters.
dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position.
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self-attention.
enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self-attention. Defaults to None.
enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self-attention. Defaults to None.
dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self-attention. Defaults to None.
enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, the mask of encoder-decoder attention. Defaults to None.
Returns:
mel_output (Variable): shape(B, T_mel, C), the decoder output after the mel linear projection.
postnet_output (Variable): shape(B, T_mel, C), the decoder output after the post mel network.
stop_preds (Variable): shape(B, T_mel, 1), the stop tokens of the output.
attn_probs (list[Variable]): len(n_layers), the encoder-decoder attention list.
attns_enc (list[Variable]): len(n_layers), the encoder self-attention list.
attns_dec (list[Variable]): len(n_layers), the decoder self-attention list.
"""
key, attns_enc = self.encoder(
characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
......
@@ -20,6 +20,12 @@ from parakeet.models.transformer_tts.cbhg import CBHG
class Vocoder(dg.Layer):
def __init__(self, config, batch_size):
"""CBHG vocoder network (mel -> linear).
Args:
config: the yaml configs used in the Vocoder model.
batch_size (int): the batch size of the input.
"""
super(Vocoder, self).__init__()
self.pre_proj = Conv1D(
num_channels=config['audio']['num_mels'],
@@ -33,14 +39,13 @@ class Vocoder(dg.Layer):
def forward(self, mel):
"""
Convert a mel spectrum to a linear spectrum.
Args:
mel (Variable): shape(B, C, T), dtype float32, the input mel spectrum.
Returns:
mag_pred (Variable): shape(B, T, C), the linear spectrum output.
"""
mel = layers.transpose(mel, [0, 2, 1])
mel = self.pre_proj(mel)
......
@@ -43,10 +43,10 @@ class DynamicGRU(dg.Layer):
Dynamic GRU block.
Args:
input (Variable): shape(B, T, C), dtype float32, the input value.
Returns:
output (Variable): shape(B, T, C), the result computed by the GRU.
"""
hidden = self.h_0
res = []
......
@@ -19,8 +19,6 @@ from parakeet.modules.customized import Conv1D
class PositionwiseFeedForward(dg.Layer):
def __init__(self,
d_in,
num_hidden,
@@ -28,6 +26,16 @@ class PositionwiseFeedForward(dg.Layer):
padding=0,
use_cudnn=True,
dropout=0.1):
"""A two-feed-forward-layer module.
Args:
d_in (int): the size of input channel.
num_hidden (int): the size of hidden layer in network.
filter_size (int): the filter size of Conv
padding (int, optional): the padding size of Conv. Defaults to 0.
use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True.
dropout (float, optional): dropout probability. Defaults to 0.1.
"""
super(PositionwiseFeedForward, self).__init__()
self.num_hidden = num_hidden
self.use_cudnn = use_cudnn
@@ -59,13 +67,13 @@ class PositionwiseFeedForward(dg.Layer):
def forward(self, input):
"""
Compute the feed-forward network result.
Args:
input (Variable): shape(B, T, C), dtype float32, the input value.
Returns:
output (Variable): shape(B, T, C), the result after the FFN.
"""
x = layers.transpose(input, [0, 2, 1])
# FFN network
......
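With filter_size=1 the conv-based feed-forward network above reduces to a per-timestep two-layer MLP. A numpy sketch of that special case; Parakeet uses Conv1D so filter_size can be larger, and the residual connection shown is the usual transformer convention assumed here:

```python
import numpy as np

B, T, d_in, d_hid = 2, 50, 256, 1024
x = np.random.randn(B, T, d_in).astype("float32")

w1 = (np.random.randn(d_in, d_hid) * 0.02).astype("float32")
w2 = (np.random.randn(d_hid, d_in) * 0.02).astype("float32")

h = np.maximum(x @ w1, 0.0)  # conv1d(d_in -> num_hidden) + relu, filter_size=1
out = x + h @ w2             # conv1d(num_hidden -> d_in) with residual connection
# dropout and normalization from the real module are omitted in this sketch
print(out.shape)             # (2, 50, 256)
```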
@@ -50,6 +50,11 @@ class Linear(dg.Layer):
class ScaledDotProductAttention(dg.Layer):
def __init__(self, d_key):
"""Scaled dot-product attention module.
Args:
d_key (int): the dimension of the keys in multihead attention.
"""
super(ScaledDotProductAttention, self).__init__()
self.d_key = d_key
@@ -63,23 +68,18 @@ class ScaledDotProductAttention(dg.Layer):
query_mask=None,
dropout=0.1):
"""
Compute scaled dot-product attention.
Args:
key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot-product attention.
value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot-product attention.
query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot-product attention.
mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of the key. Defaults to None.
query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of the query. Defaults to None.
dropout (float32, optional): the probability of dropout. Defaults to 0.1.
Returns:
result (Variable): shape(B, T, C), the result of multihead attention.
attention (Variable): shape(n_head * B, T, C), the attention of the key.
"""
# Compute attention score
attention = layers.matmul(
@@ -110,6 +110,17 @@ class MultiheadAttention(dg.Layer):
is_bias=False,
dropout=0.1,
is_concat=True):
"""Multihead attention.
Args:
num_hidden (int): the number of hidden units in the network.
d_k (int): the dimension of the keys in multihead attention.
d_q (int): the dimension of the queries in multihead attention.
num_head (int, optional): the number of heads in multihead attention. Defaults to 4.
is_bias (bool, optional): whether the linear layers have bias. Defaults to False.
dropout (float, optional): dropout probability. Defaults to 0.1.
is_concat (bool, optional): whether to concatenate the query and the result. Defaults to True.
"""
super(MultiheadAttention, self).__init__()
self.num_hidden = num_hidden
self.num_head = num_head
@@ -133,22 +144,18 @@ class MultiheadAttention(dg.Layer):
def forward(self, key, value, query_input, mask=None, query_mask=None):
"""
Compute attention.
Args:
key (Variable): shape(B, T, C), dtype float32, the input key of attention.
value (Variable): shape(B, T, C), dtype float32, the input value of attention.
query_input (Variable): shape(B, T, C), dtype float32, the input query of attention.
mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of the key. Defaults to None.
query_mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of the query. Defaults to None.
Returns:
result (Variable): shape(B, T, C), the result of multihead attention.
attention (Variable): shape(num_head * B, T, C), the attention of the key and query.
"""
batch_size = key.shape[0]
......
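For reference, the computation the ScaledDotProductAttention docstring describes, as a self-contained numpy sketch. The masking convention here (zeros marking positions to drop) is an assumption for illustration and differs from the float masks Parakeet multiplies in:

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, d_key, mask=None):
    scores = q @ np.transpose(k, (0, 2, 1)) / np.sqrt(d_key)  # (B, T_q, T_k)
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)  # drop masked key positions
    scores = scores - scores.max(axis=-1, keepdims=True)      # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

B, T, C = 2, 5, 8
q = k = v = np.random.randn(B, T, C).astype("float32")
result, attention = scaled_dot_product_attention(q, k, v, d_key=C)
assert result.shape == (B, T, C) and attention.shape == (B, T, T)
```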