提交 41590a59 编写于 作者: L liuyibing01

Merge branch 'master' into 'master'

modified the process of generating masks to speed up batching

See merge request !46
...@@ -18,7 +18,7 @@ def add_config_options_to_parser(parser): ...@@ -18,7 +18,7 @@ def add_config_options_to_parser(parser):
parser.add_argument( parser.add_argument(
'--config_path', '--config_path',
type=str, type=str,
default='config/fastspeech.yaml', default='configs/fastspeech.yaml',
help="the yaml config file path.") help="the yaml config file path.")
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default=32, help="batch size for training.") '--batch_size', type=int, default=32, help="batch size for training.")
...@@ -87,7 +87,7 @@ def add_config_options_to_parser(parser): ...@@ -87,7 +87,7 @@ def add_config_options_to_parser(parser):
parser.add_argument( parser.add_argument(
'--transtts_path', '--transtts_path',
type=str, type=str,
default='./log', default='../transformer_tts/checkpoint',
help="the directory to load pretrain transformerTTS model.") help="the directory to load pretrain transformerTTS model.")
parser.add_argument( parser.add_argument(
'--transformer_step', '--transformer_step',
......
...@@ -10,7 +10,7 @@ python -u train.py \ ...@@ -10,7 +10,7 @@ python -u train.py \
--use_data_parallel=0 \ --use_data_parallel=0 \
--data_path='../../dataset/LJSpeech-1.1' \ --data_path='../../dataset/LJSpeech-1.1' \
--transtts_path='../transformer_tts/checkpoint' \ --transtts_path='../transformer_tts/checkpoint' \
--transformer_step=160000 \ --transformer_step=120000 \
--save_path='./checkpoint' \ --save_path='./checkpoint' \
--log_dir='./log' \ --log_dir='./log' \
--config_path='configs/fastspeech.yaml' \ --config_path='configs/fastspeech.yaml' \
......
...@@ -55,6 +55,8 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr ...@@ -55,6 +55,8 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr
If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--transformer_step``. If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--transformer_step``.
**Note: To ensure good training results, we recommend using multi-GPU training to enlarge the batch size, with at least 16 samples per batch on each GPU.**
For more help on arguments: For more help on arguments:
``python train_transformer.py --help``. ``python train_transformer.py --help``.
......
...@@ -23,7 +23,7 @@ from parakeet import audio ...@@ -23,7 +23,7 @@ from parakeet import audio
from parakeet.data.sampler import * from parakeet.data.sampler import *
from parakeet.data.datacargo import DataCargo from parakeet.data.datacargo import DataCargo
from parakeet.data.batch import TextIDBatcher, SpecBatcher from parakeet.data.batch import TextIDBatcher, SpecBatcher
from parakeet.data.dataset import DatasetMixin, TransformDataset, CacheDataset from parakeet.data.dataset import DatasetMixin, TransformDataset, CacheDataset, SliceDataset
from parakeet.models.transformer_tts.utils import * from parakeet.models.transformer_tts.utils import *
...@@ -44,7 +44,7 @@ class LJSpeechLoader: ...@@ -44,7 +44,7 @@ class LJSpeechLoader:
dataset = CacheDataset(dataset) dataset = CacheDataset(dataset)
sampler = DistributedSampler( sampler = DistributedSampler(
len(metadata), nranks, rank, shuffle=shuffle) len(dataset), nranks, rank, shuffle=shuffle)
assert args.batch_size % nranks == 0 assert args.batch_size % nranks == 0
each_bs = args.batch_size // nranks each_bs = args.batch_size // nranks
...@@ -64,7 +64,6 @@ class LJSpeechLoader: ...@@ -64,7 +64,6 @@ class LJSpeechLoader:
shuffle=shuffle, shuffle=shuffle,
batch_fn=batch_examples, batch_fn=batch_examples,
drop_last=True) drop_last=True)
self.reader = fluid.io.DataLoader.from_generator( self.reader = fluid.io.DataLoader.from_generator(
capacity=32, capacity=32,
iterable=True, iterable=True,
...@@ -199,12 +198,13 @@ def batch_examples(batch): ...@@ -199,12 +198,13 @@ def batch_examples(batch):
SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1)) #(B,T,num_mels) SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1)) #(B,T,num_mels)
mel_inputs = np.transpose( mel_inputs = np.transpose(
SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1)) #(B,T,num_mels) SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1)) #(B,T,num_mels)
enc_slf_mask = get_attn_key_pad_mask(pos_texts, texts).astype(np.float32)
enc_slf_mask = get_attn_key_pad_mask(pos_texts).astype(np.float32)
enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32) enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels, dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels,
mel_inputs).astype(np.float32) mel_inputs).astype(np.float32)
enc_dec_mask = get_attn_key_pad_mask(enc_query_mask[:, :, 0], enc_dec_mask = get_attn_key_pad_mask(enc_query_mask[:, :, 0]).astype(
mel_inputs).astype(np.float32) np.float32)
dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32) dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
dec_query_mask = get_non_pad_mask(pos_mels).astype(np.float32) dec_query_mask = get_non_pad_mask(pos_mels).astype(np.float32)
......
...@@ -18,7 +18,7 @@ def add_config_options_to_parser(parser): ...@@ -18,7 +18,7 @@ def add_config_options_to_parser(parser):
parser.add_argument( parser.add_argument(
'--config_path', '--config_path',
type=str, type=str,
default='config/train_transformer.yaml', default='configs/train_transformer.yaml',
help="the yaml config file path.") help="the yaml config file path.")
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default=32, help="batch size for training.") '--batch_size', type=int, default=32, help="batch size for training.")
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import os import os
from tqdm import tqdm from tqdm import tqdm
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
#from pathlib import Path
from collections import OrderedDict from collections import OrderedDict
import argparse import argparse
from parse import add_config_options_to_parser from parse import add_config_options_to_parser
...@@ -69,9 +68,6 @@ def main(args): ...@@ -69,9 +68,6 @@ def main(args):
cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']), cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
parameter_list=model.parameters()) parameter_list=model.parameters())
reader = LJSpeechLoader(
cfg, args, nranks, local_rank, shuffle=True).reader()
if args.checkpoint_path is not None: if args.checkpoint_path is not None:
model_dict, opti_dict = load_checkpoint( model_dict, opti_dict = load_checkpoint(
str(args.transformer_step), str(args.transformer_step),
...@@ -85,6 +81,9 @@ def main(args): ...@@ -85,6 +81,9 @@ def main(args):
strategy = dg.parallel.prepare_context() strategy = dg.parallel.prepare_context()
model = fluid.dygraph.parallel.DataParallel(model, strategy) model = fluid.dygraph.parallel.DataParallel(model, strategy)
reader = LJSpeechLoader(
cfg, args, nranks, local_rank, shuffle=True).reader()
for epoch in range(args.epochs): for epoch in range(args.epochs):
pbar = tqdm(reader) pbar = tqdm(reader)
for i, data in enumerate(pbar): for i, data in enumerate(pbar):
...@@ -148,7 +147,8 @@ def main(args): ...@@ -148,7 +147,8 @@ def main(args):
for i, prob in enumerate(attn_probs): for i, prob in enumerate(attn_probs):
for j in range(4): for j in range(4):
x = np.uint8( x = np.uint8(
cm.viridis(prob.numpy()[j * 16]) * 255) cm.viridis(prob.numpy()[j * args.batch_size
// 2]) * 255)
writer.add_image( writer.add_image(
'Attention_%d_0' % global_step, 'Attention_%d_0' % global_step,
x, x,
...@@ -158,7 +158,8 @@ def main(args): ...@@ -158,7 +158,8 @@ def main(args):
for i, prob in enumerate(attn_enc): for i, prob in enumerate(attn_enc):
for j in range(4): for j in range(4):
x = np.uint8( x = np.uint8(
cm.viridis(prob.numpy()[j * 16]) * 255) cm.viridis(prob.numpy()[j * args.batch_size
// 2]) * 255)
writer.add_image( writer.add_image(
'Attention_enc_%d_0' % global_step, 'Attention_enc_%d_0' % global_step,
x, x,
...@@ -168,7 +169,8 @@ def main(args): ...@@ -168,7 +169,8 @@ def main(args):
for i, prob in enumerate(attn_dec): for i, prob in enumerate(attn_dec):
for j in range(4): for j in range(4):
x = np.uint8( x = np.uint8(
cm.viridis(prob.numpy()[j * 16]) * 255) cm.viridis(prob.numpy()[j * args.batch_size
// 2]) * 255)
writer.add_image( writer.add_image(
'Attention_dec_%d_0' % global_step, 'Attention_dec_%d_0' % global_step,
x, x,
......
...@@ -56,15 +56,13 @@ def get_non_pad_mask(seq): ...@@ -56,15 +56,13 @@ def get_non_pad_mask(seq):
return mask return mask
def get_attn_key_pad_mask(seq_k, seq_q): def get_attn_key_pad_mask(seq_k):
''' For masking out the padding part of key sequence. ''' ''' For masking out the padding part of key sequence. '''
# Expand to fit the shape of key query attention matrix. # Expand to fit the shape of key query attention matrix.
len_q = seq_q.shape[1]
padding_mask = (seq_k != 0).astype(np.float32) padding_mask = (seq_k != 0).astype(np.float32)
padding_mask = np.expand_dims(padding_mask, axis=1) padding_mask = np.expand_dims(padding_mask, axis=1)
padding_mask = padding_mask.repeat([len_q], axis=1) padding_mask = (
padding_mask = (padding_mask == 0).astype(np.float32) * (-2**32 + 1) padding_mask == 0).astype(np.float32) * -1e30 #* (-2**32 + 1)
return padding_mask return padding_mask
...@@ -72,12 +70,12 @@ def get_dec_attn_key_pad_mask(seq_k, seq_q): ...@@ -72,12 +70,12 @@ def get_dec_attn_key_pad_mask(seq_k, seq_q):
''' For masking out the padding part of key sequence. ''' ''' For masking out the padding part of key sequence. '''
# Expand to fit the shape of key query attention matrix. # Expand to fit the shape of key query attention matrix.
len_q = seq_q.shape[1]
padding_mask = (seq_k == 0).astype(np.float32) padding_mask = (seq_k == 0).astype(np.float32)
padding_mask = np.expand_dims(padding_mask, axis=1) padding_mask = np.expand_dims(padding_mask, axis=1)
triu_tensor = get_triu_tensor(seq_q, seq_q) triu_tensor = get_triu_tensor(seq_q, seq_q)
padding_mask = padding_mask.repeat([len_q], axis=1) + triu_tensor padding_mask = padding_mask + triu_tensor
padding_mask = (padding_mask != 0).astype(np.float32) * (-2**32 + 1) padding_mask = (
padding_mask != 0).astype(np.float32) * -1e30 #* (-2**32 + 1)
return padding_mask return padding_mask
...@@ -85,12 +83,7 @@ def get_triu_tensor(seq_k, seq_q): ...@@ -85,12 +83,7 @@ def get_triu_tensor(seq_k, seq_q):
''' For make a triu tensor ''' ''' For make a triu tensor '''
len_k = seq_k.shape[1] len_k = seq_k.shape[1]
len_q = seq_q.shape[1] len_q = seq_q.shape[1]
batch_size = seq_k.shape[0]
triu_tensor = np.triu(np.ones([len_k, len_q]), 1) triu_tensor = np.triu(np.ones([len_k, len_q]), 1)
triu_tensor = np.repeat(
np.expand_dims(
triu_tensor, axis=0), batch_size, axis=0)
return triu_tensor return triu_tensor
......
...@@ -89,7 +89,7 @@ class ScaledDotProductAttention(dg.Layer): ...@@ -89,7 +89,7 @@ class ScaledDotProductAttention(dg.Layer):
# Mask key to ignore padding # Mask key to ignore padding
if mask is not None: if mask is not None:
attention = attention + mask attention = attention + mask
attention = layers.softmax(attention) attention = layers.softmax(attention, use_cudnn=True)
attention = layers.dropout( attention = layers.dropout(
attention, dropout, dropout_implementation='upscale_in_train') attention, dropout, dropout_implementation='upscale_in_train')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册