提交 c2cf81c3 编写于 作者: L liuyibing01

Merge branch 'master' into 'master'

add docstring to transformer_tts and fastspeech

See merge request !36
...@@ -6,6 +6,7 @@ python -u synthesis.py \ ...@@ -6,6 +6,7 @@ python -u synthesis.py \
--checkpoint_path='checkpoint/' \ --checkpoint_path='checkpoint/' \
--fastspeech_step=71000 \ --fastspeech_step=71000 \
--log_dir='./log' \ --log_dir='./log' \
--config_path='configs/synthesis.yaml' \
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in synthesis!" echo "Failed in synthesis!"
......
...@@ -13,7 +13,7 @@ python -u train.py \ ...@@ -13,7 +13,7 @@ python -u train.py \
--transformer_step=160000 \ --transformer_step=160000 \
--save_path='./checkpoint' \ --save_path='./checkpoint' \
--log_dir='./log' \ --log_dir='./log' \
--config_path='config/fastspeech.yaml' \ --config_path='configs/fastspeech.yaml' \
#--checkpoint_path='./checkpoint' \ #--checkpoint_path='./checkpoint' \
#--fastspeech_step=97000 \ #--fastspeech_step=97000 \
......
...@@ -84,7 +84,7 @@ def synthesis(text_input, args): ...@@ -84,7 +84,7 @@ def synthesis(text_input, args):
dec_slf_mask = get_triu_tensor( dec_slf_mask = get_triu_tensor(
mel_input.numpy(), mel_input.numpy()).astype(np.float32) mel_input.numpy(), mel_input.numpy()).astype(np.float32)
dec_slf_mask = fluid.layers.cast( dec_slf_mask = fluid.layers.cast(
dg.to_variable(dec_slf_mask == 0), np.float32) dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
pos_mel = np.arange(1, mel_input.shape[1] + 1) pos_mel = np.arange(1, mel_input.shape[1] + 1)
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0]) pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
...@@ -157,6 +157,5 @@ if __name__ == '__main__': ...@@ -157,6 +157,5 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Synthesis model") parser = argparse.ArgumentParser(description="Synthesis model")
add_config_options_to_parser(parser) add_config_options_to_parser(parser)
args = parser.parse_args() args = parser.parse_args()
synthesis( synthesis("Parakeet stands for Paddle PARAllel text-to-speech toolkit.",
"They emphasized the necessity that the information now being furnished be handled with judgment and care.", args)
args)
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
# train model # train model
CUDA_VISIBLE_DEVICES=0 \ CUDA_VISIBLE_DEVICES=0 \
python -u synthesis.py \ python -u synthesis.py \
--max_len=600 \ --max_len=300 \
--transformer_step=160000 \ --transformer_step=120000 \
--vocoder_step=90000 \ --vocoder_step=100000 \
--use_gpu=1 \ --use_gpu=1 \
--checkpoint_path='./checkpoint' \ --checkpoint_path='./checkpoint' \
--log_dir='./log' \ --log_dir='./log' \
--sample_path='./sample' \ --sample_path='./sample' \
--config_path='config/synthesis.yaml' \ --config_path='configs/synthesis.yaml' \
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in training!" echo "Failed in training!"
......
...@@ -14,7 +14,7 @@ python -u train_transformer.py \ ...@@ -14,7 +14,7 @@ python -u train_transformer.py \
--data_path='../../dataset/LJSpeech-1.1' \ --data_path='../../dataset/LJSpeech-1.1' \
--save_path='./checkpoint' \ --save_path='./checkpoint' \
--log_dir='./log' \ --log_dir='./log' \
--config_path='config/train_transformer.yaml' \ --config_path='configs/train_transformer.yaml' \
#--checkpoint_path='./checkpoint' \ #--checkpoint_path='./checkpoint' \
#--transformer_step=160000 \ #--transformer_step=160000 \
......
...@@ -12,7 +12,7 @@ python -u train_vocoder.py \ ...@@ -12,7 +12,7 @@ python -u train_vocoder.py \
--data_path='../../dataset/LJSpeech-1.1' \ --data_path='../../dataset/LJSpeech-1.1' \
--save_path='./checkpoint' \ --save_path='./checkpoint' \
--log_dir='./log' \ --log_dir='./log' \
--config_path='config/train_vocoder.yaml' \ --config_path='configs/train_vocoder.yaml' \
#--checkpoint_path='./checkpoint' \ #--checkpoint_path='./checkpoint' \
#--vocoder_step=27000 \ #--vocoder_step=27000 \
......
...@@ -23,12 +23,26 @@ class Decoder(dg.Layer): ...@@ -23,12 +23,26 @@ class Decoder(dg.Layer):
n_layers, n_layers,
n_head, n_head,
d_k, d_k,
d_v, d_q,
d_model, d_model,
d_inner, d_inner,
fft_conv1d_kernel, fft_conv1d_kernel,
fft_conv1d_padding, fft_conv1d_padding,
dropout=0.1): dropout=0.1):
"""Decoder layer of FastSpeech.
Args:
len_max_seq (int): the max mel len of sequence.
n_layers (int): the layers number of FFTBlock.
n_head (int): the head number of multihead attention.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
d_model (int): the dim of hidden layer in multihead attention.
d_inner (int): the dim of hidden layer in ffn.
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
fft_conv1d_padding (int): the conv padding size in FFTBlock.
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
"""
super(Decoder, self).__init__() super(Decoder, self).__init__()
n_position = len_max_seq + 1 n_position = len_max_seq + 1
...@@ -48,7 +62,7 @@ class Decoder(dg.Layer): ...@@ -48,7 +62,7 @@ class Decoder(dg.Layer):
d_inner, d_inner,
n_head, n_head,
d_k, d_k,
d_v, d_q,
fft_conv1d_kernel, fft_conv1d_kernel,
fft_conv1d_padding, fft_conv1d_padding,
dropout=dropout) for _ in range(n_layers) dropout=dropout) for _ in range(n_layers)
...@@ -58,16 +72,20 @@ class Decoder(dg.Layer): ...@@ -58,16 +72,20 @@ class Decoder(dg.Layer):
def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None): def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
""" """
Decoder layer of FastSpeech. Compute decoder outputs.
Args: Args:
enc_seq (Variable), Shape(B, text_T, C), dtype: float32. enc_seq (Variable): shape(B, T_text, C), dtype float32,
The output of length regulator. the output of length regulator, where T_text means the timesteps of input text,
enc_pos (Variable, optional): Shape(B, T_mel), enc_pos (Variable): shape(B, T_mel), dtype int64,
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum. the spectrum position, where T_mel means the timesteps of input spectrum,
non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
the mask of mel spectrum. Defaults to None.
Returns: Returns:
dec_output (Variable), Shape(B, mel_T, C), the decoder output. dec_output (Variable): shape(B, T_mel, C), the decoder output.
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list. dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
""" """
dec_slf_attn_list = [] dec_slf_attn_list = []
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
......
...@@ -24,12 +24,27 @@ class Encoder(dg.Layer): ...@@ -24,12 +24,27 @@ class Encoder(dg.Layer):
n_layers, n_layers,
n_head, n_head,
d_k, d_k,
d_v, d_q,
d_model, d_model,
d_inner, d_inner,
fft_conv1d_kernel, fft_conv1d_kernel,
fft_conv1d_padding, fft_conv1d_padding,
dropout=0.1): dropout=0.1):
"""Encoder layer of FastSpeech.
Args:
n_src_vocab (int): the number of source vocabulary.
len_max_seq (int): the max mel len of sequence.
n_layers (int): the layers number of FFTBlock.
n_head (int): the head number of multihead attention.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
d_model (int): the dim of hidden layer in multihead attention.
d_inner (int): the dim of hidden layer in ffn.
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
fft_conv1d_padding (int): the conv padding size in FFTBlock.
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
"""
super(Encoder, self).__init__() super(Encoder, self).__init__()
n_position = len_max_seq + 1 n_position = len_max_seq + 1
self.n_head = n_head self.n_head = n_head
...@@ -53,7 +68,7 @@ class Encoder(dg.Layer): ...@@ -53,7 +68,7 @@ class Encoder(dg.Layer):
d_inner, d_inner,
n_head, n_head,
d_k, d_k,
d_v, d_q,
fft_conv1d_kernel, fft_conv1d_kernel,
fft_conv1d_padding, fft_conv1d_padding,
dropout=dropout) for _ in range(n_layers) dropout=dropout) for _ in range(n_layers)
...@@ -63,18 +78,20 @@ class Encoder(dg.Layer): ...@@ -63,18 +78,20 @@ class Encoder(dg.Layer):
def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None): def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
""" """
Encoder layer of FastSpeech. Encode text sequence.
Args:
character (Variable): Shape(B, T_text), dtype: float32. The input text
characters. T_text means the timesteps of input characters.
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
position. T_text means the timesteps of input characters.
Args:
character (Variable): shape(B, T_text), dtype float32, the input text characters,
where T_text means the timesteps of input characters,
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
the mask of input characters. Defaults to None.
Returns: Returns:
enc_output (Variable), Shape(B, text_T, C), the encoder output. enc_output (Variable): shape(B, T_text, C), the encoder output.
non_pad_mask (Variable), Shape(B, T_text, 1), the mask with non pad. non_pad_mask (Variable): shape(B, T_text, 1), the mask with non pad.
enc_slf_attn_list (list<Variable>), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list. enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list.
""" """
enc_slf_attn_list = [] enc_slf_attn_list = []
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
......
...@@ -25,7 +25,11 @@ from parakeet.models.fastspeech.decoder import Decoder ...@@ -25,7 +25,11 @@ from parakeet.models.fastspeech.decoder import Decoder
class FastSpeech(dg.Layer): class FastSpeech(dg.Layer):
def __init__(self, cfg): def __init__(self, cfg):
" FastSpeech" """FastSpeech model.
Args:
cfg: the yaml configs used in FastSpeech model.
"""
super(FastSpeech, self).__init__() super(FastSpeech, self).__init__()
self.encoder = Encoder( self.encoder = Encoder(
...@@ -34,7 +38,7 @@ class FastSpeech(dg.Layer): ...@@ -34,7 +38,7 @@ class FastSpeech(dg.Layer):
n_layers=cfg['encoder_n_layer'], n_layers=cfg['encoder_n_layer'],
n_head=cfg['encoder_head'], n_head=cfg['encoder_head'],
d_k=cfg['fs_hidden_size'] // cfg['encoder_head'], d_k=cfg['fs_hidden_size'] // cfg['encoder_head'],
d_v=cfg['fs_hidden_size'] // cfg['encoder_head'], d_q=cfg['fs_hidden_size'] // cfg['encoder_head'],
d_model=cfg['fs_hidden_size'], d_model=cfg['fs_hidden_size'],
d_inner=cfg['encoder_conv1d_filter_size'], d_inner=cfg['encoder_conv1d_filter_size'],
fft_conv1d_kernel=cfg['fft_conv1d_filter'], fft_conv1d_kernel=cfg['fft_conv1d_filter'],
...@@ -50,7 +54,7 @@ class FastSpeech(dg.Layer): ...@@ -50,7 +54,7 @@ class FastSpeech(dg.Layer):
n_layers=cfg['decoder_n_layer'], n_layers=cfg['decoder_n_layer'],
n_head=cfg['decoder_head'], n_head=cfg['decoder_head'],
d_k=cfg['fs_hidden_size'] // cfg['decoder_head'], d_k=cfg['fs_hidden_size'] // cfg['decoder_head'],
d_v=cfg['fs_hidden_size'] // cfg['decoder_head'], d_q=cfg['fs_hidden_size'] // cfg['decoder_head'],
d_model=cfg['fs_hidden_size'], d_model=cfg['fs_hidden_size'],
d_inner=cfg['decoder_conv1d_filter_size'], d_inner=cfg['decoder_conv1d_filter_size'],
fft_conv1d_kernel=cfg['fft_conv1d_filter'], fft_conv1d_kernel=cfg['fft_conv1d_filter'],
...@@ -82,34 +86,37 @@ class FastSpeech(dg.Layer): ...@@ -82,34 +86,37 @@ class FastSpeech(dg.Layer):
text_pos, text_pos,
enc_non_pad_mask, enc_non_pad_mask,
dec_non_pad_mask, dec_non_pad_mask,
mel_pos=None,
enc_slf_attn_mask=None, enc_slf_attn_mask=None,
dec_slf_attn_mask=None, dec_slf_attn_mask=None,
mel_pos=None,
length_target=None, length_target=None,
alpha=1.0): alpha=1.0):
""" """
FastSpeech model. Compute mel output from text character.
Args: Args:
character (Variable): Shape(B, T_text), dtype: float32. The input text character (Variable): shape(B, T_text), dtype float32, the input text characters,
characters. T_text means the timesteps of input characters. where T_text means the timesteps of input characters,
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
position. T_text means the timesteps of input characters. mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position,
mel_pos (Variable, optional): Shape(B, T_mel), where T_mel means the timesteps of input spectrum,
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum. enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
length_target (Variable, optional): Shape(B, T_text), dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
dtype: int64. The duration of phoneme compute from pretrained transformerTTS. enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
alpha (Constant): the mask of input characters. Defaults to None.
dtype: float32. The hyperparameter to determine the length of the expanded sequence slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
mel, thereby controlling the voice speed. the mask of mel spectrum. Defaults to None.
length_target (Variable, optional): shape(B, T_text), dtype int64,
the duration of phoneme compute from pretrained transformerTTS. Defaults to None.
alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
mel, thereby controlling the voice speed. Defaults to 1.0.
Returns: Returns:
mel_output (Variable), Shape(B, mel_T, C), the mel output before postnet. mel_output (Variable): shape(B, T_mel, C), the mel output before postnet.
mel_output_postnet (Variable), Shape(B, mel_T, C), the mel output after postnet. mel_output_postnet (Variable): shape(B, T_mel, C), the mel output after postnet.
duration_predictor_output (Variable), Shape(B, text_T), the duration of phoneme compute duration_predictor_output (Variable): shape(B, T_text), the duration of phoneme compute with duration predictor.
with duration predictor. enc_slf_attn_list (List[Variable]): len(enc_n_layers), the encoder self attention list.
enc_slf_attn_list (Variable), Shape(B, text_T, text_T), the encoder self attention list. dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list.
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
""" """
encoder_output, enc_slf_attn_list = self.encoder( encoder_output, enc_slf_attn_list = self.encoder(
...@@ -118,7 +125,6 @@ class FastSpeech(dg.Layer): ...@@ -118,7 +125,6 @@ class FastSpeech(dg.Layer):
enc_non_pad_mask, enc_non_pad_mask,
slf_attn_mask=enc_slf_attn_mask) slf_attn_mask=enc_slf_attn_mask)
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
length_regulator_output, duration_predictor_output = self.length_regulator( length_regulator_output, duration_predictor_output = self.length_regulator(
encoder_output, target=length_target, alpha=alpha) encoder_output, target=length_target, alpha=alpha)
decoder_output, dec_slf_attn_list = self.decoder( decoder_output, dec_slf_attn_list = self.decoder(
......
...@@ -26,15 +26,27 @@ class FFTBlock(dg.Layer): ...@@ -26,15 +26,27 @@ class FFTBlock(dg.Layer):
d_inner, d_inner,
n_head, n_head,
d_k, d_k,
d_v, d_q,
filter_size, filter_size,
padding, padding,
dropout=0.2): dropout=0.2):
"""Feed forward structure based on self-attention.
Args:
d_model (int): the dim of hidden layer in multihead attention.
d_inner (int): the dim of hidden layer in ffn.
n_head (int): the head number of multihead attention.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
filter_size (int): the conv kernel size.
padding (int): the conv padding size.
dropout (float, optional): dropout probability. Defaults to 0.2.
"""
super(FFTBlock, self).__init__() super(FFTBlock, self).__init__()
self.slf_attn = MultiheadAttention( self.slf_attn = MultiheadAttention(
d_model, d_model,
d_k, d_k,
d_v, d_q,
num_head=n_head, num_head=n_head,
is_bias=True, is_bias=True,
dropout=dropout, dropout=dropout,
...@@ -48,18 +60,18 @@ class FFTBlock(dg.Layer): ...@@ -48,18 +60,18 @@ class FFTBlock(dg.Layer):
def forward(self, enc_input, non_pad_mask, slf_attn_mask=None): def forward(self, enc_input, non_pad_mask, slf_attn_mask=None):
""" """
Feed Forward Transformer block in FastSpeech. Feed forward block of FastSpeech
Args: Args:
enc_input (Variable): Shape(B, T, C), dtype: float32. The embedding characters input. enc_input (Variable): shape(B, T, C), dtype float32, the embedding characters input,
T means the timesteps of input. where T means the timesteps of input.
non_pad_mask (Variable): Shape(B, T, 1), dtype: int64. The mask of sequence. non_pad_mask (Variable): shape(B, T, 1), dtype int64, the mask of sequence.
slf_attn_mask (Variable): Shape(B, len_q, len_k), dtype: int64. The mask of self attention. slf_attn_mask (Variable, optional): shape(B, len_q, len_k), dtype int64, the mask of self attention,
len_q means the sequence length of query, len_k means the sequence length of key. where len_q means the sequence length of query and len_k means the sequence length of key. Defaults to None.
Returns: Returns:
output (Variable), Shape(B, T, C), the output after self-attention & ffn. output (Variable): shape(B, T, C), the output after self-attention & ffn.
slf_attn (Variable), Shape(B * n_head, T, T), the self attention. slf_attn (Variable): shape(B * n_head, T, T), the self attention.
""" """
output, slf_attn = self.slf_attn( output, slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask) enc_input, enc_input, enc_input, mask=slf_attn_mask)
......
...@@ -22,6 +22,14 @@ from parakeet.modules.customized import Conv1D ...@@ -22,6 +22,14 @@ from parakeet.modules.customized import Conv1D
class LengthRegulator(dg.Layer): class LengthRegulator(dg.Layer):
def __init__(self, input_size, out_channels, filter_size, dropout=0.1): def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
"""Length Regulator block in FastSpeech.
Args:
input_size (int): the channel number of input.
out_channels (int): the output channel number.
filter_size (int): the filter size of duration predictor.
dropout (float, optional): dropout probability. Defaults to 0.1.
"""
super(LengthRegulator, self).__init__() super(LengthRegulator, self).__init__()
self.duration_predictor = DurationPredictor( self.duration_predictor = DurationPredictor(
input_size=input_size, input_size=input_size,
...@@ -66,18 +74,18 @@ class LengthRegulator(dg.Layer): ...@@ -66,18 +74,18 @@ class LengthRegulator(dg.Layer):
def forward(self, x, alpha=1.0, target=None): def forward(self, x, alpha=1.0, target=None):
""" """
Length Regulator block in FastSpeech. Compute length of mel from encoder output use TransformerTTS attention
Args: Args:
x (Variable): Shape(B, T, C), dtype: float32. The encoder output. x (Variable): shape(B, T, C), dtype float32, the encoder output.
alpha (Constant): dtype: float32. The hyperparameter to determine the length of alpha (float32, optional): the hyperparameter to determine the length of
the expanded sequence mel, thereby controlling the voice speed. the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0.
target (Variable): (Variable, optional): Shape(B, T_text), target (Variable, optional): shape(B, T_text), dtype int64, the duration of phoneme compute from pretrained transformerTTS.
dtype: int64. The duration of phoneme compute from pretrained transformerTTS. Defaults to None.
Returns: Returns:
output (Variable), Shape(B, T, C), the output after exppand. output (Variable): shape(B, T, C), the output after exppand.
duration_predictor_output (Variable), Shape(B, T, C), the output of duration predictor. duration_predictor_output (Variable): shape(B, T, C), the output of duration predictor.
""" """
duration_predictor_output = self.duration_predictor(x) duration_predictor_output = self.duration_predictor(x)
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
...@@ -93,6 +101,14 @@ class LengthRegulator(dg.Layer): ...@@ -93,6 +101,14 @@ class LengthRegulator(dg.Layer):
class DurationPredictor(dg.Layer): class DurationPredictor(dg.Layer):
def __init__(self, input_size, out_channels, filter_size, dropout=0.1): def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
"""Duration Predictor block in FastSpeech.
Args:
input_size (int): the channel number of input.
out_channels (int): the output channel number.
filter_size (int): the filter size.
dropout (float, optional): dropout probability. Defaults to 0.1.
"""
super(DurationPredictor, self).__init__() super(DurationPredictor, self).__init__()
self.input_size = input_size self.input_size = input_size
self.out_channels = out_channels self.out_channels = out_channels
...@@ -135,12 +151,13 @@ class DurationPredictor(dg.Layer): ...@@ -135,12 +151,13 @@ class DurationPredictor(dg.Layer):
def forward(self, encoder_output): def forward(self, encoder_output):
""" """
Duration Predictor block in FastSpeech. Predict the duration of each character.
Args: Args:
encoder_output (Variable): Shape(B, T, C), dtype: float32. The encoder output. encoder_output (Variable): shape(B, T, C), dtype float32, the encoder output.
Returns: Returns:
out (Variable), Shape(B, T, C), the output of duration predictor. out (Variable): shape(B, T, C), the output of duration predictor.
""" """
# encoder_output.shape(N, T, C) # encoder_output.shape(N, T, C)
out = layers.transpose(encoder_output, [0, 2, 1]) out = layers.transpose(encoder_output, [0, 2, 1])
......
...@@ -30,16 +30,19 @@ class CBHG(dg.Layer): ...@@ -30,16 +30,19 @@ class CBHG(dg.Layer):
num_gru_layers=2, num_gru_layers=2,
max_pool_kernel_size=2, max_pool_kernel_size=2,
is_post=False): is_post=False):
super(CBHG, self).__init__() """CBHG Module
"""
:param hidden_size: dimension of hidden unit Args:
:param batch_size: batch size hidden_size (int): dimension of hidden unit.
:param K: # of convolution banks batch_size (int): batch size of input.
:param projection_size: dimension of projection unit K (int, optional): number of convolution banks. Defaults to 16.
:param num_gru_layers: # of layers of GRUcell projection_size (int, optional): dimension of projection unit. Defaults to 256.
:param max_pool_kernel_size: max pooling kernel size num_gru_layers (int, optional): number of layers of GRUcell. Defaults to 2.
:param is_post: whether post processing or not max_pool_kernel_size (int, optional): max pooling kernel size. Defaults to 2
is_post (bool, optional): whether post processing or not. Defaults to False.
""" """
super(CBHG, self).__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.projection_size = projection_size self.projection_size = projection_size
self.conv_list = [] self.conv_list = []
...@@ -176,7 +179,15 @@ class CBHG(dg.Layer): ...@@ -176,7 +179,15 @@ class CBHG(dg.Layer):
return x return x
def forward(self, input_): def forward(self, input_):
# input_.shape = [N, C, T] """
Convert linear spectrum to Mel spectrum.
Args:
input_ (Variable): shape(B, C, T), dtype float32, the sequentially input.
Returns:
out (Variable): shape(B, C, T), the CBHG output.
"""
conv_list = [] conv_list = []
conv_input = input_ conv_input = input_
...@@ -217,6 +228,12 @@ class CBHG(dg.Layer): ...@@ -217,6 +228,12 @@ class CBHG(dg.Layer):
class Highwaynet(dg.Layer): class Highwaynet(dg.Layer):
def __init__(self, num_units, num_layers=4): def __init__(self, num_units, num_layers=4):
"""Highway network
Args:
num_units (int): dimension of hidden unit.
num_layers (int, optional): number of highway layers. Defaults to 4.
"""
super(Highwaynet, self).__init__() super(Highwaynet, self).__init__()
self.num_units = num_units self.num_units = num_units
self.num_layers = num_layers self.num_layers = num_layers
...@@ -249,6 +266,15 @@ class Highwaynet(dg.Layer): ...@@ -249,6 +266,15 @@ class Highwaynet(dg.Layer):
self.add_sublayer("gates_{}".format(i), gate) self.add_sublayer("gates_{}".format(i), gate)
def forward(self, input_): def forward(self, input_):
"""
Compute result of Highway network.
Args:
input_(Variable): shape(B, T, C), dtype float32, the sequentially input.
Returns:
out(Variable): the Highway output.
"""
out = input_ out = input_
for linear, gate in zip(self.linears, self.gates): for linear, gate in zip(self.linears, self.gates):
......
...@@ -22,7 +22,15 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet ...@@ -22,7 +22,15 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
class Decoder(dg.Layer): class Decoder(dg.Layer):
def __init__(self, num_hidden, config, num_head=4): def __init__(self, num_hidden, config, num_head=4, n_layers=3):
"""Decoder layer of TransformerTTS.
Args:
num_hidden (int): the number of source vocabulary.
config: the yaml configs used in decoder.
n_layers (int, optional): the layers number of multihead attention. Defaults to 4.
num_head (int, optional): the head number of multihead attention. Defaults to 3.
"""
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.num_hidden = num_hidden self.num_hidden = num_hidden
self.num_head = num_head self.num_head = num_head
...@@ -58,20 +66,20 @@ class Decoder(dg.Layer): ...@@ -58,20 +66,20 @@ class Decoder(dg.Layer):
self.selfattn_layers = [ self.selfattn_layers = [
MultiheadAttention(num_hidden, num_hidden // num_head, MultiheadAttention(num_hidden, num_hidden // num_head,
num_hidden // num_head) for _ in range(3) num_hidden // num_head) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.selfattn_layers): for i, layer in enumerate(self.selfattn_layers):
self.add_sublayer("self_attn_{}".format(i), layer) self.add_sublayer("self_attn_{}".format(i), layer)
self.attn_layers = [ self.attn_layers = [
MultiheadAttention(num_hidden, num_hidden // num_head, MultiheadAttention(num_hidden, num_hidden // num_head,
num_hidden // num_head) for _ in range(3) num_hidden // num_head) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.attn_layers): for i, layer in enumerate(self.attn_layers):
self.add_sublayer("attn_{}".format(i), layer) self.add_sublayer("attn_{}".format(i), layer)
self.ffns = [ self.ffns = [
PositionwiseFeedForward( PositionwiseFeedForward(
num_hidden, num_hidden * num_head, filter_size=1) num_hidden, num_hidden * num_head, filter_size=1)
for _ in range(3) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.ffns): for i, layer in enumerate(self.ffns):
self.add_sublayer("ffns_{}".format(i), layer) self.add_sublayer("ffns_{}".format(i), layer)
...@@ -108,6 +116,28 @@ class Decoder(dg.Layer): ...@@ -108,6 +116,28 @@ class Decoder(dg.Layer):
m_mask=None, m_mask=None,
m_self_mask=None, m_self_mask=None,
zero_mask=None): zero_mask=None):
"""
Compute decoder outputs.
Args:
key (Variable): shape(B, T_text, C), dtype float32, the input key of decoder,
where T_text means the timesteps of input text,
value (Variable): shape(B, T_text, C), dtype float32, the input value of decoder.
query (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
where T_mel means the timesteps of input spectrum,
positional (Variable): shape(B, T_mel), dtype int64, the spectrum position.
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
Returns:
mel_out (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
out (Variable): shape(B, T_mel, C), the decoder output after post mel network.
stop_tokens (Variable): shape(B, T_mel, 1), the stop tokens of output.
attn_list (list[Variable]): len(n_layers), the encoder-decoder attention list.
selfattn_list (list[Variable]): len(n_layers), the decoder self attention list.
"""
# get decoder mask with triangular matrix # get decoder mask with triangular matrix
...@@ -121,7 +151,7 @@ class Decoder(dg.Layer): ...@@ -121,7 +151,7 @@ class Decoder(dg.Layer):
else: else:
m_mask, m_self_mask, zero_mask = None, None, None m_mask, m_self_mask, zero_mask = None, None, None
# Decoder pre-network # Decoder pre-network
query = self.decoder_prenet(query) query = self.decoder_prenet(query)
# Centered position # Centered position
......
...@@ -20,7 +20,15 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet ...@@ -20,7 +20,15 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet
class Encoder(dg.Layer): class Encoder(dg.Layer):
def __init__(self, embedding_size, num_hidden, num_head=4): def __init__(self, embedding_size, num_hidden, num_head=4, n_layers=3):
"""Encoder layer of TransformerTTS.
Args:
embedding_size (int): the size of position embedding.
num_hidden (int): the size of hidden layer in network.
n_layers (int, optional): the layers number of multihead attention. Defaults to 4.
num_head (int, optional): the head number of multihead attention. Defaults to 3.
"""
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.num_hidden = num_hidden self.num_hidden = num_hidden
self.num_head = num_head self.num_head = num_head
...@@ -42,7 +50,7 @@ class Encoder(dg.Layer): ...@@ -42,7 +50,7 @@ class Encoder(dg.Layer):
use_cudnn=True) use_cudnn=True)
self.layers = [ self.layers = [
MultiheadAttention(num_hidden, num_hidden // num_head, MultiheadAttention(num_hidden, num_hidden // num_head,
num_hidden // num_head) for _ in range(3) num_hidden // num_head) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
self.add_sublayer("self_attn_{}".format(i), layer) self.add_sublayer("self_attn_{}".format(i), layer)
...@@ -51,12 +59,26 @@ class Encoder(dg.Layer): ...@@ -51,12 +59,26 @@ class Encoder(dg.Layer):
num_hidden, num_hidden,
num_hidden * num_head, num_hidden * num_head,
filter_size=1, filter_size=1,
use_cudnn=True) for _ in range(3) use_cudnn=True) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.ffns): for i, layer in enumerate(self.ffns):
self.add_sublayer("ffns_{}".format(i), layer) self.add_sublayer("ffns_{}".format(i), layer)
def forward(self, x, positional, mask=None, query_mask=None): def forward(self, x, positional, mask=None, query_mask=None):
"""
Encode text sequence.
Args:
x (Variable): shape(B, T_text), dtype float32, the input character,
where T_text means the timesteps of input text,
positional (Variable): shape(B, T_text), dtype int64, the characters position.
mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
Returns:
x (Variable): shape(B, T_text, C), the encoder output.
attentions (list[Variable]): len(n_layers), the encoder self attention list.
"""
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
seq_len_key = x.shape[1] seq_len_key = x.shape[1]
...@@ -66,12 +88,12 @@ class Encoder(dg.Layer): ...@@ -66,12 +88,12 @@ class Encoder(dg.Layer):
else: else:
query_mask, mask = None, None query_mask, mask = None, None
# Encoder pre_network # Encoder pre_network
x = self.encoder_prenet(x) #(N,T,C) x = self.encoder_prenet(x)
# Get positional encoding # Get positional encoding
positional = self.pos_emb(positional) positional = self.pos_emb(positional)
x = positional * self.alpha + x #(N, T, C) x = positional * self.alpha + x
# Positional dropout # Positional dropout
x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train') x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train')
......
...@@ -22,6 +22,13 @@ import numpy as np ...@@ -22,6 +22,13 @@ import numpy as np
class EncoderPrenet(dg.Layer): class EncoderPrenet(dg.Layer):
def __init__(self, embedding_size, num_hidden, use_cudnn=True): def __init__(self, embedding_size, num_hidden, use_cudnn=True):
""" Encoder prenet layer of TransformerTTS.
Args:
embedding_size (int): the size of embedding.
num_hidden (int): the size of hidden layer in network.
use_cudnn (bool, optional): use cudnn or not. Defaults to True.
"""
super(EncoderPrenet, self).__init__() super(EncoderPrenet, self).__init__()
self.embedding_size = embedding_size self.embedding_size = embedding_size
self.num_hidden = num_hidden self.num_hidden = num_hidden
...@@ -81,8 +88,17 @@ class EncoderPrenet(dg.Layer): ...@@ -81,8 +88,17 @@ class EncoderPrenet(dg.Layer):
low=-k, high=k))) low=-k, high=k)))
def forward(self, x): def forward(self, x):
"""
Prepare encoder input.
Args:
x (Variable): shape(B, T_text), dtype float32, the input character, where T_text means the timesteps of input text.
Returns:
(Variable): shape(B, T_text, C), the encoder prenet output.
"""
x = self.embedding(x) #(batch_size, seq_len, embending_size) x = self.embedding(x)
x = layers.transpose(x, [0, 2, 1]) x = layers.transpose(x, [0, 2, 1])
for batch_norm, conv in zip(self.batch_norm_list, self.conv_list): for batch_norm, conv in zip(self.batch_norm_list, self.conv_list):
x = layers.dropout( x = layers.dropout(
......
...@@ -29,6 +29,19 @@ class PostConvNet(dg.Layer): ...@@ -29,6 +29,19 @@ class PostConvNet(dg.Layer):
use_cudnn=True, use_cudnn=True,
dropout=0.1, dropout=0.1,
batchnorm_last=False): batchnorm_last=False):
"""Decocder post conv net of TransformerTTS.
Args:
n_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
num_hidden (int, optional): the size of hidden layer in network. Defaults to 512.
filter_size (int, optional): the filter size of Conv. Defaults to 5.
padding (int, optional): the padding size of Conv. Defaults to 0.
num_conv (int, optional): the num of Conv layers in network. Defaults to 5.
outputs_per_step (int, optional): the num of output frames per step . Defaults to 1.
use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True.
dropout (float, optional): dropout probability. Defaults to 0.1.
batchnorm_last (bool, optional): if batchnorm at last layer or not. Defaults to False.
"""
super(PostConvNet, self).__init__() super(PostConvNet, self).__init__()
self.dropout = dropout self.dropout = dropout
...@@ -93,12 +106,13 @@ class PostConvNet(dg.Layer): ...@@ -93,12 +106,13 @@ class PostConvNet(dg.Layer):
def forward(self, input): def forward(self, input):
""" """
Post Conv Net. Compute the mel spectrum.
Args: Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value. input (Variable): shape(B, T, C), dtype float32, the result of mel linear projection.
Returns: Returns:
output (Variable), Shape(B, T, C), the result after postconvnet. output (Variable): shape(B, T, C), the result after postconvnet.
""" """
input = layers.transpose(input, [0, 2, 1]) input = layers.transpose(input, [0, 2, 1])
......
...@@ -19,10 +19,13 @@ import paddle.fluid.layers as layers ...@@ -19,10 +19,13 @@ import paddle.fluid.layers as layers
class PreNet(dg.Layer): class PreNet(dg.Layer):
def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2): def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2):
""" """Prenet before passing through the network.
:param input_size: dimension of input
:param hidden_size: dimension of hidden unit Args:
:param output_size: dimension of output input_size (int): the input channel size.
hidden_size (int): the size of hidden layer in network.
output_size (int): the output channel size.
dropout_rate (float, optional): dropout probability. Defaults to 0.2.
""" """
super(PreNet, self).__init__() super(PreNet, self).__init__()
self.input_size = input_size self.input_size = input_size
...@@ -49,19 +52,20 @@ class PreNet(dg.Layer): ...@@ -49,19 +52,20 @@ class PreNet(dg.Layer):
def forward(self, x): def forward(self, x):
""" """
Pre Net before passing through the network. Prepare network input.
Args: Args:
x (Variable): Shape(B, T, C), dtype: float32. The input value. x (Variable): shape(B, T, C), dtype float32, the input value.
Returns: Returns:
x (Variable), Shape(B, T, C), the result after pernet. output (Variable): shape(B, T, C), the result after pernet.
""" """
x = layers.dropout( x = layers.dropout(
layers.relu(self.linear1(x)), layers.relu(self.linear1(x)),
self.dropout_rate, self.dropout_rate,
dropout_implementation='upscale_in_train') dropout_implementation='upscale_in_train')
x = layers.dropout( output = layers.dropout(
layers.relu(self.linear2(x)), layers.relu(self.linear2(x)),
self.dropout_rate, self.dropout_rate,
dropout_implementation='upscale_in_train') dropout_implementation='upscale_in_train')
return x return output
...@@ -19,6 +19,11 @@ from parakeet.models.transformer_tts.decoder import Decoder ...@@ -19,6 +19,11 @@ from parakeet.models.transformer_tts.decoder import Decoder
class TransformerTTS(dg.Layer): class TransformerTTS(dg.Layer):
def __init__(self, config): def __init__(self, config):
"""TransformerTTS model.
Args:
config: the yaml configs used in TransformerTTS model.
"""
super(TransformerTTS, self).__init__() super(TransformerTTS, self).__init__()
self.encoder = Encoder(config['embedding_size'], config['hidden_size']) self.encoder = Encoder(config['embedding_size'], config['hidden_size'])
self.decoder = Decoder(config['hidden_size'], config) self.decoder = Decoder(config['hidden_size'], config)
...@@ -35,6 +40,31 @@ class TransformerTTS(dg.Layer): ...@@ -35,6 +40,31 @@ class TransformerTTS(dg.Layer):
enc_dec_mask=None, enc_dec_mask=None,
dec_query_slf_mask=None, dec_query_slf_mask=None,
dec_query_mask=None): dec_query_mask=None):
"""
TransformerTTS network.
Args:
characters (Variable): shape(B, T_text), dtype float32, the input character,
where T_text means the timesteps of input text,
mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
where T_mel means the timesteps of input spectrum,
pos_text (Variable): shape(B, T_text), dtype int64, the characters position.
dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position.
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
Returns:
mel_output (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
postnet_output (Variable): shape(B, T_mel, C), the decoder output after post mel network.
stop_preds (Variable): shape(B, T_mel, 1), the stop tokens of output.
attn_probs (list[Variable]): len(n_layers), the encoder-decoder attention list.
attns_enc (list[Variable]): len(n_layers), the encoder self attention list.
attns_dec (list[Variable]): len(n_layers), the decoder self attention list.
"""
key, attns_enc = self.encoder( key, attns_enc = self.encoder(
characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask) characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
...@@ -48,5 +78,3 @@ class TransformerTTS(dg.Layer): ...@@ -48,5 +78,3 @@ class TransformerTTS(dg.Layer):
m_self_mask=dec_query_slf_mask, m_self_mask=dec_query_slf_mask,
m_mask=dec_query_mask) m_mask=dec_query_mask)
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
...@@ -19,11 +19,13 @@ from parakeet.models.transformer_tts.cbhg import CBHG ...@@ -19,11 +19,13 @@ from parakeet.models.transformer_tts.cbhg import CBHG
class Vocoder(dg.Layer): class Vocoder(dg.Layer):
"""
CBHG Network (mel -> linear)
"""
def __init__(self, config, batch_size): def __init__(self, config, batch_size):
"""CBHG Network (mel -> linear)
Args:
config: the yaml configs used in Vocoder model.
batch_size (int): the batch size of input.
"""
super(Vocoder, self).__init__() super(Vocoder, self).__init__()
self.pre_proj = Conv1D( self.pre_proj = Conv1D(
num_channels=config['audio']['num_mels'], num_channels=config['audio']['num_mels'],
...@@ -36,6 +38,15 @@ class Vocoder(dg.Layer): ...@@ -36,6 +38,15 @@ class Vocoder(dg.Layer):
filter_size=1) filter_size=1)
def forward(self, mel): def forward(self, mel):
"""
Compute mel spectrum to linear spectrum.
Args:
mel (Variable): shape(B, C, T), dtype float32, the input mel spectrum.
Returns:
mag_pred (Variable): shape(B, T, C), the linear output.
"""
mel = layers.transpose(mel, [0, 2, 1]) mel = layers.transpose(mel, [0, 2, 1])
mel = self.pre_proj(mel) mel = self.pre_proj(mel)
mel = self.cbhg(mel) mel = self.cbhg(mel)
......
...@@ -43,9 +43,10 @@ class DynamicGRU(dg.Layer): ...@@ -43,9 +43,10 @@ class DynamicGRU(dg.Layer):
Dynamic GRU block. Dynamic GRU block.
Args: Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value. input (Variable): shape(B, T, C), dtype float32, the input value.
Returns: Returns:
output (Variable), Shape(B, T, C), the result compute by GRU. output (Variable): shape(B, T, C), the result compute by GRU.
""" """
hidden = self.h_0 hidden = self.h_0
res = [] res = []
......
...@@ -19,8 +19,6 @@ from parakeet.modules.customized import Conv1D ...@@ -19,8 +19,6 @@ from parakeet.modules.customized import Conv1D
class PositionwiseFeedForward(dg.Layer): class PositionwiseFeedForward(dg.Layer):
''' A two-feed-forward-layer module '''
def __init__(self, def __init__(self,
d_in, d_in,
num_hidden, num_hidden,
...@@ -28,6 +26,16 @@ class PositionwiseFeedForward(dg.Layer): ...@@ -28,6 +26,16 @@ class PositionwiseFeedForward(dg.Layer):
padding=0, padding=0,
use_cudnn=True, use_cudnn=True,
dropout=0.1): dropout=0.1):
"""A two-feed-forward-layer module.
Args:
d_in (int): the size of input channel.
num_hidden (int): the size of hidden layer in network.
filter_size (int): the filter size of Conv
padding (int, optional): the padding size of Conv. Defaults to 0.
use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True.
dropout (float, optional): dropout probability. Defaults to 0.1.
"""
super(PositionwiseFeedForward, self).__init__() super(PositionwiseFeedForward, self).__init__()
self.num_hidden = num_hidden self.num_hidden = num_hidden
self.use_cudnn = use_cudnn self.use_cudnn = use_cudnn
...@@ -59,12 +67,13 @@ class PositionwiseFeedForward(dg.Layer): ...@@ -59,12 +67,13 @@ class PositionwiseFeedForward(dg.Layer):
def forward(self, input): def forward(self, input):
""" """
Feed Forward Network. Compute feed forward network result.
Args: Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value. input (Variable): shape(B, T, C), dtype float32, the input value.
Returns: Returns:
output (Variable), Shape(B, T, C), the result after FFN. output (Variable): shape(B, T, C), the result after FFN.
""" """
x = layers.transpose(input, [0, 2, 1]) x = layers.transpose(input, [0, 2, 1])
#FFN Networt #FFN Networt
......
此差异已折叠。
...@@ -50,6 +50,11 @@ class Linear(dg.Layer): ...@@ -50,6 +50,11 @@ class Linear(dg.Layer):
class ScaledDotProductAttention(dg.Layer): class ScaledDotProductAttention(dg.Layer):
def __init__(self, d_key): def __init__(self, d_key):
"""Scaled dot product attention module.
Args:
d_key (int): the dim of key in multihead attention.
"""
super(ScaledDotProductAttention, self).__init__() super(ScaledDotProductAttention, self).__init__()
self.d_key = d_key self.d_key = d_key
...@@ -63,18 +68,18 @@ class ScaledDotProductAttention(dg.Layer): ...@@ -63,18 +68,18 @@ class ScaledDotProductAttention(dg.Layer):
query_mask=None, query_mask=None,
dropout=0.1): dropout=0.1):
""" """
Scaled Dot Product Attention. Compute scaled dot product attention.
Args: Args:
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention. key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot product attention.
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention. value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot product attention.
query (Variable): Shape(B, T, C), dtype: float32. The input query of attention. query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot product attention.
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key. mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of key. Defaults to None.
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query. query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of query. Defaults to None.
dropout (Constant): dtype: float32. The probability of dropout. dropout (float32, optional): the probability of dropout. Defaults to 0.1.
Returns: Returns:
result (Variable), Shape(B, T, C), the result of mutihead attention. result (Variable): shape(B, T, C), the result of mutihead attention.
attention (Variable), Shape(n_head * B, T, C), the attention of key. attention (Variable): shape(n_head * B, T, C), the attention of key.
""" """
# Compute attention score # Compute attention score
attention = layers.matmul( attention = layers.matmul(
...@@ -105,6 +110,17 @@ class MultiheadAttention(dg.Layer): ...@@ -105,6 +110,17 @@ class MultiheadAttention(dg.Layer):
is_bias=False, is_bias=False,
dropout=0.1, dropout=0.1,
is_concat=True): is_concat=True):
"""Multihead Attention.
Args:
num_hidden (int): the number of hidden layer in network.
d_k (int): the dim of key in multihead attention.
d_q (int): the dim of query in multihead attention.
num_head (int, optional): the head number of multihead attention. Defaults to 4.
is_bias (bool, optional): whether have bias in linear layers. Default to False.
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
is_concat (bool, optional): whether concat query and result. Default to True.
"""
super(MultiheadAttention, self).__init__() super(MultiheadAttention, self).__init__()
self.num_hidden = num_hidden self.num_hidden = num_hidden
self.num_head = num_head self.num_head = num_head
...@@ -128,17 +144,18 @@ class MultiheadAttention(dg.Layer): ...@@ -128,17 +144,18 @@ class MultiheadAttention(dg.Layer):
def forward(self, key, value, query_input, mask=None, query_mask=None): def forward(self, key, value, query_input, mask=None, query_mask=None):
""" """
Multihead Attention. Compute attention.
Args: Args:
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention. key (Variable): shape(B, T, C), dtype float32, the input key of attention.
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention. value (Variable): shape(B, T, C), dtype float32, the input value of attention.
query_input (Variable): Shape(B, T, C), dtype: float32. The input query of attention. query_input (Variable): shape(B, T, C), dtype float32, the input query of attention.
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key. mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of key. Defaults to None.
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query. query_mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of query. Defaults to None.
Returns: Returns:
result (Variable), Shape(B, T, C), the result of mutihead attention. result (Variable): shape(B, T, C), the result of mutihead attention.
attention (Variable), Shape(n_head * B, T, C), the attention of key. attention (Variable): shape(num_head * B, T, C), the attention of key and query.
""" """
batch_size = key.shape[0] batch_size = key.shape[0]
...@@ -146,7 +163,6 @@ class MultiheadAttention(dg.Layer): ...@@ -146,7 +163,6 @@ class MultiheadAttention(dg.Layer):
seq_len_query = query_input.shape[1] seq_len_query = query_input.shape[1]
# Make multihead attention # Make multihead attention
# key & value.shape = (batch_size, seq_len, feature)(feature = num_head * num_hidden_per_attn)
key = layers.reshape( key = layers.reshape(
self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k]) self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
value = layers.reshape( value = layers.reshape(
...@@ -168,18 +184,6 @@ class MultiheadAttention(dg.Layer): ...@@ -168,18 +184,6 @@ class MultiheadAttention(dg.Layer):
result, attention = self.scal_attn( result, attention = self.scal_attn(
key, value, query, mask=mask, query_mask=query_mask) key, value, query, mask=mask, query_mask=query_mask)
key = layers.reshape(
layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
value = layers.reshape(
layers.transpose(value, [2, 0, 1, 3]),
[-1, seq_len_key, self.d_k])
query = layers.reshape(
layers.transpose(query, [2, 0, 1, 3]),
[-1, seq_len_query, self.d_q])
result, attention = self.scal_attn(
key, value, query, mask=mask, query_mask=query_mask)
# concat all multihead result # concat all multihead result
result = layers.reshape( result = layers.reshape(
result, [self.num_head, batch_size, seq_len_query, self.d_q]) result, [self.num_head, batch_size, seq_len_query, self.d_q])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册