提交 ad4b248a 编写于 作者: L lifuchen

Fix some mask bugs in FastSpeech

上级 75d46422
...@@ -18,6 +18,7 @@ import argparse ...@@ -18,6 +18,7 @@ import argparse
from parse import add_config_options_to_parser from parse import add_config_options_to_parser
from pprint import pprint from pprint import pprint
from ruamel import yaml from ruamel import yaml
from matplotlib import cm
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
...@@ -64,8 +65,7 @@ def synthesis(text_input, args): ...@@ -64,8 +65,7 @@ def synthesis(text_input, args):
pos_text = np.arange(1, text.shape[1] + 1) pos_text = np.arange(1, text.shape[1] + 1)
pos_text = np.expand_dims(pos_text, axis=0) pos_text = np.expand_dims(pos_text, axis=0)
enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32) enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32)
enc_slf_attn_mask = get_attn_key_pad_mask(pos_text, enc_slf_attn_mask = get_attn_key_pad_mask(pos_text).astype(np.float32)
text).astype(np.float32)
text = dg.to_variable(text) text = dg.to_variable(text)
pos_text = dg.to_variable(pos_text) pos_text = dg.to_variable(pos_text)
...@@ -101,8 +101,17 @@ def synthesis(text_input, args): ...@@ -101,8 +101,17 @@ def synthesis(text_input, args):
do_trim_silence=False, do_trim_silence=False,
sound_norm=False) sound_norm=False)
np.save('mel_output', mel_output_postnet.numpy())
mel_output_postnet = fluid.layers.transpose( mel_output_postnet = fluid.layers.transpose(
fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0]) fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0])
x = np.uint8(cm.viridis(mel_output_postnet.numpy()) * 255)
writer.add_image('mel_0_0', x, 0, dataformats="HWC")
ground_truth = _ljspeech_processor.load_wav(
str('/paddle/Parakeet/dataset/LJSpeech-1.1/wavs/LJ001-0175.wav'))
ground_truth = _ljspeech_processor.melspectrogram(ground_truth).astype(
np.float32)
x = np.uint8(cm.viridis(ground_truth) * 255)
writer.add_image('mel_gt_0', x, 0, dataformats="HWC")
wav = _ljspeech_processor.inv_melspectrogram(mel_output_postnet.numpy( wav = _ljspeech_processor.inv_melspectrogram(mel_output_postnet.numpy(
)) ))
writer.add_audio(text_input, wav, 0, cfg['audio']['sr']) writer.add_audio(text_input, wav, 0, cfg['audio']['sr'])
...@@ -114,4 +123,5 @@ if __name__ == '__main__': ...@@ -114,4 +123,5 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Train Fastspeech model") parser = argparse.ArgumentParser(description="Train Fastspeech model")
add_config_options_to_parser(parser) add_config_options_to_parser(parser)
args = parser.parse_args() args = parser.parse_args()
synthesis("Transformer model is so fast!", args) synthesis("Simple as this proposition is, it is necessary to be stated,",
args)
...@@ -4,7 +4,7 @@ python -u synthesis.py \ ...@@ -4,7 +4,7 @@ python -u synthesis.py \
--use_gpu=1 \ --use_gpu=1 \
--alpha=1.0 \ --alpha=1.0 \
--checkpoint_path='checkpoint/' \ --checkpoint_path='checkpoint/' \
--fastspeech_step=71000 \ --fastspeech_step=89000 \
--log_dir='./log' \ --log_dir='./log' \
--config_path='configs/synthesis.yaml' \ --config_path='configs/synthesis.yaml' \
......
...@@ -88,7 +88,8 @@ class Decoder(dg.Layer): ...@@ -88,7 +88,8 @@ class Decoder(dg.Layer):
dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list. dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
""" """
dec_slf_attn_list = [] dec_slf_attn_list = []
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) if slf_attn_mask:
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
# -- Forward # -- Forward
dec_output = enc_seq + self.position_enc(enc_pos) dec_output = enc_seq + self.position_enc(enc_pos)
......
...@@ -142,6 +142,7 @@ class FastSpeech(dg.Layer): ...@@ -142,6 +142,7 @@ class FastSpeech(dg.Layer):
encoder_output, alpha=alpha) encoder_output, alpha=alpha)
slf_attn_mask = get_triu_tensor( slf_attn_mask = get_triu_tensor(
decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32) decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
slf_attn_mask = np.expand_dims(slf_attn_mask, axis=0)
slf_attn_mask = fluid.layers.cast( slf_attn_mask = fluid.layers.cast(
dg.to_variable(slf_attn_mask == 0), np.float32) dg.to_variable(slf_attn_mask == 0), np.float32)
slf_attn_mask = dg.to_variable(slf_attn_mask) slf_attn_mask = dg.to_variable(slf_attn_mask)
......
...@@ -149,6 +149,7 @@ class Decoder(dg.Layer): ...@@ -149,6 +149,7 @@ class Decoder(dg.Layer):
zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1]) zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1])
else: else:
mask = layers.expand(mask, [self.num_head, 1, 1])
m_mask, m_self_mask, zero_mask = None, None, None m_mask, m_self_mask, zero_mask = None, None, None
# Decoder pre-network # Decoder pre-network
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册