提交 ad4b248a 编写于 作者: L lifuchen

fix some bug of mask in fastspeech

上级 75d46422
......@@ -18,6 +18,7 @@ import argparse
from parse import add_config_options_to_parser
from pprint import pprint
from ruamel import yaml
from matplotlib import cm
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
......@@ -64,8 +65,7 @@ def synthesis(text_input, args):
pos_text = np.arange(1, text.shape[1] + 1)
pos_text = np.expand_dims(pos_text, axis=0)
enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32)
enc_slf_attn_mask = get_attn_key_pad_mask(pos_text,
text).astype(np.float32)
enc_slf_attn_mask = get_attn_key_pad_mask(pos_text).astype(np.float32)
text = dg.to_variable(text)
pos_text = dg.to_variable(pos_text)
......@@ -101,8 +101,17 @@ def synthesis(text_input, args):
do_trim_silence=False,
sound_norm=False)
np.save('mel_output', mel_output_postnet.numpy())
mel_output_postnet = fluid.layers.transpose(
fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0])
x = np.uint8(cm.viridis(mel_output_postnet.numpy()) * 255)
writer.add_image('mel_0_0', x, 0, dataformats="HWC")
ground_truth = _ljspeech_processor.load_wav(
str('/paddle/Parakeet/dataset/LJSpeech-1.1/wavs/LJ001-0175.wav'))
ground_truth = _ljspeech_processor.melspectrogram(ground_truth).astype(
np.float32)
x = np.uint8(cm.viridis(ground_truth) * 255)
writer.add_image('mel_gt_0', x, 0, dataformats="HWC")
wav = _ljspeech_processor.inv_melspectrogram(mel_output_postnet.numpy(
))
writer.add_audio(text_input, wav, 0, cfg['audio']['sr'])
......@@ -114,4 +123,5 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Train Fastspeech model")
add_config_options_to_parser(parser)
args = parser.parse_args()
synthesis("Transformer model is so fast!", args)
synthesis("Simple as this proposition is, it is necessary to be stated,",
args)
......@@ -4,7 +4,7 @@ python -u synthesis.py \
--use_gpu=1 \
--alpha=1.0 \
--checkpoint_path='checkpoint/' \
--fastspeech_step=71000 \
--fastspeech_step=89000 \
--log_dir='./log' \
--config_path='configs/synthesis.yaml' \
......
......@@ -88,7 +88,8 @@ class Decoder(dg.Layer):
dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
"""
dec_slf_attn_list = []
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
if slf_attn_mask:
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
# -- Forward
dec_output = enc_seq + self.position_enc(enc_pos)
......
......@@ -142,6 +142,7 @@ class FastSpeech(dg.Layer):
encoder_output, alpha=alpha)
slf_attn_mask = get_triu_tensor(
decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
slf_attn_mask = np.expand_dims(slf_attn_mask, axis=0)
slf_attn_mask = fluid.layers.cast(
dg.to_variable(slf_attn_mask == 0), np.float32)
slf_attn_mask = dg.to_variable(slf_attn_mask)
......
......@@ -149,6 +149,7 @@ class Decoder(dg.Layer):
zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1])
else:
mask = layers.expand(mask, [self.num_head, 1, 1])
m_mask, m_self_mask, zero_mask = None, None, None
# Decoder pre-network
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册