"""Train a DeepVoice3 model on the LJSpeech dataset.

The script reads a YAML experiment config, builds the dataset pipeline, the
model, the loss, the learning-rate schedule and the optimizer from it, and
then runs the training loop.  While training it writes scalars, audio and
images to tensorboard and periodically saves training snapshots, evaluation
samples and checkpoints below the output directory.
"""
import os
import argparse

import ruamel.yaml
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
import tqdm
import librosa
from librosa import display
import soundfile as sf
from tensorboardX import SummaryWriter

from paddle import fluid
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg

from parakeet.g2p import en
from parakeet.data import FilterDataset, TransformDataset
from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
from parakeet.models.deepvoice3.loss import TTSLoss
from parakeet.utils.layer_tools import summary

from data import LJSpeechMetaData, DataCollector, Transform
from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment

# Sentences synthesized at every evaluation step. They never change, so they
# are defined once here instead of being rebuilt inside the training loop.
EVAL_SENTENCES = [
    "Scientists at the CERN laboratory say they have discovered a new particle.",
    "There's a way to measure the acute emotional intelligence that has never gone out of style.",
    "President Trump met with other leaders at the Group of 20 conference.",
    "Generative adversarial network or variational auto-encoder.",
    "Please call Stella.",
    "Some have accepted this as a miracle without any physical explanation.",
]


def main():
    """Parse the command line, build data/model/optimizer and run training."""
    parser = argparse.ArgumentParser(
        description="Train a deepvoice 3 model with LJSpeech dataset.")
    # The config is mandatory: everything below is read from it, so fail with
    # a clear argparse error instead of a TypeError from open(None).
    parser.add_argument("-c",
                        "--config",
                        type=str,
                        required=True,
                        help="experiment config")
    parser.add_argument("-s",
                        "--data",
                        type=str,
                        default="/workspace/datasets/LJSpeech-1.1/",
                        help="The path of the LJSpeech dataset.")
    parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
    parser.add_argument("-o",
                        "--output",
                        type=str,
                        default="result",
                        help="The directory to save result.")
    parser.add_argument("-g",
                        "--device",
                        type=int,
                        default=-1,
                        help="device to use")
    args, _ = parser.parse_known_args()

    with open(args.config, 'rt') as f:
        config = ruamel.yaml.safe_load(f)

    # =========================dataset=========================
    # construct meta data
    data_root = args.data
    meta = LJSpeechMetaData(data_root)

    # filter it: keep only examples whose field 2 is long enough
    min_text_length = config["meta_data"]["min_text_length"]
    meta = FilterDataset(meta, lambda x: len(x[2]) >= min_text_length)

    # build the transform that is applied on top of the filtered meta data
    transform_config = config["transform"]
    replace_pronounciation_prob = transform_config[
        "replace_pronunciation_prob"]
    sample_rate = transform_config["sample_rate"]
    preemphasis = transform_config["preemphasis"]
    n_fft = transform_config["n_fft"]
    win_length = transform_config["win_length"]
    hop_length = transform_config["hop_length"]
    fmin = transform_config["fmin"]
    fmax = transform_config["fmax"]
    n_mels = transform_config["n_mels"]
    min_level_db = transform_config["min_level_db"]
    ref_level_db = transform_config["ref_level_db"]
    max_norm = transform_config["max_norm"]
    clip_norm = transform_config["clip_norm"]
    transform = Transform(replace_pronounciation_prob, sample_rate,
                          preemphasis, n_fft, win_length, hop_length, fmin,
                          fmax, n_mels, min_level_db, ref_level_db, max_norm,
                          clip_norm)
    ljspeech = TransformDataset(meta, transform)

    # =========================dataiterator=========================
    # use meta data's text length as a sort key for the sampler
    train_config = config["train"]
    batch_size = train_config["batch_size"]
    text_lengths = [len(example[2]) for example in meta]
    sampler = PartialyRandomizedSimilarTimeLengthSampler(
        text_lengths, batch_size)

    # some hyperparameters affect how we process data, so create a data
    # collector that knows about them
    model_config = config["model"]
    downsample_factor = model_config["downsample_factor"]
    r = model_config["outputs_per_step"]
    collector = DataCollector(downsample_factor=downsample_factor, r=r)
    ljspeech_loader = DataCargo(ljspeech,
                                batch_fn=collector,
                                batch_size=batch_size,
                                sampler=sampler)

    # =========================device=========================
    # -1 selects the CPU, any other value is used as the CUDA device id
    if args.device == -1:
        place = fluid.CPUPlace()
    else:
        place = fluid.CUDAPlace(args.device)

    with dg.guard(place):
        # =========================model=========================
        n_speakers = model_config["n_speakers"]
        speaker_dim = model_config["speaker_embed_dim"]
        speaker_embed_std = model_config["speaker_embedding_weight_std"]
        n_vocab = en.n_vocab
        embed_dim = model_config["text_embed_dim"]
        linear_dim = 1 + n_fft // 2
        use_decoder_states = model_config[
            "use_decoder_state_for_postnet_input"]
        filter_size = model_config["kernel_size"]
        encoder_channels = model_config["encoder_channels"]
        decoder_channels = model_config["decoder_channels"]
        converter_channels = model_config["converter_channels"]
        dropout = model_config["dropout"]
        padding_idx = model_config["padding_idx"]
        embedding_std = model_config["embedding_weight_std"]
        max_positions = model_config["max_positions"]
        freeze_embedding = model_config["freeze_embedding"]
        trainable_positional_encodings = model_config[
            "trainable_positional_encodings"]
        use_memory_mask = model_config["use_memory_mask"]
        query_position_rate = model_config["query_position_rate"]
        key_position_rate = model_config["key_position_rate"]
        window_backward = model_config["window_backward"]
        window_ahead = model_config["window_ahead"]
        key_projection = model_config["key_projection"]
        value_projection = model_config["value_projection"]
        dv3 = make_model(n_speakers, speaker_dim, speaker_embed_std,
                         embed_dim, padding_idx, embedding_std, max_positions,
                         n_vocab, freeze_embedding, filter_size,
                         encoder_channels, n_mels, decoder_channels, r,
                         trainable_positional_encodings, use_memory_mask,
                         query_position_rate, key_position_rate,
                         window_backward, window_ahead, key_projection,
                         value_projection, downsample_factor, linear_dim,
                         use_decoder_states, converter_channels, dropout)

        # =========================loss=========================
        loss_config = config["loss"]
        masked_weight = loss_config["masked_loss_weight"]
        priority_freq = loss_config["priority_freq"]  # Hz
        # map the priority frequency onto an index among the linear bins
        priority_bin = int(priority_freq / (0.5 * sample_rate) * linear_dim)
        priority_freq_weight = loss_config["priority_freq_weight"]
        binary_divergence_weight = loss_config["binary_divergence_weight"]
        guided_attention_sigma = loss_config["guided_attention_sigma"]
        criterion = TTSLoss(masked_weight=masked_weight,
                            priority_bin=priority_bin,
                            priority_weight=priority_freq_weight,
                            binary_divergence_weight=binary_divergence_weight,
                            guided_attention_sigma=guided_attention_sigma,
                            downsample_factor=downsample_factor,
                            r=r)

        # =========================lr_scheduler=========================
        lr_config = config["lr_scheduler"]
        warmup_steps = lr_config["warmup_steps"]
        peak_learning_rate = lr_config["peak_learning_rate"]
        lr_scheduler = dg.NoamDecay(
            1 / (warmup_steps * (peak_learning_rate)**2), warmup_steps)

        # =========================optimizer=========================
        optim_config = config["optimizer"]
        beta1 = optim_config["beta1"]
        beta2 = optim_config["beta2"]
        epsilon = optim_config["epsilon"]
        optim = fluid.optimizer.Adam(lr_scheduler,
                                     beta1,
                                     beta2,
                                     epsilon=epsilon,
                                     parameter_list=dv3.parameters())
        gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1)

        # generation
        synthesis_config = config["synthesis"]
        power = synthesis_config["power"]
        n_iter = synthesis_config["n_iter"]

        # =========================link(dataloader, paddle)=========================
        # CAUTION: it does not return a DataLoader
        loader = fluid.io.DataLoader.from_generator(capacity=10,
                                                    return_list=True)
        loader.set_batch_generator(ljspeech_loader, places=place)

        # tensorboard & checkpoint preparation
        output_dir = args.output
        ckpt_dir = os.path.join(output_dir, "checkpoints")
        log_dir = os.path.join(output_dir, "log")
        state_dir = os.path.join(output_dir, "states")
        make_output_tree(output_dir)
        writer = SummaryWriter(logdir=log_dir)

        # Close the writer even when training is interrupted or fails, so
        # that events still buffered in it are not lost.
        try:
            # load model parameters
            # NOTE: only the model state is restored here; the second value
            # returned by load_dygraph is discarded.
            resume_path = args.resume
            if resume_path is not None:
                state, _ = dg.load_dygraph(resume_path)
                dv3.set_dict(state)

            # =========================train=========================
            epoch = train_config["epochs"]
            snap_interval = train_config["snap_interval"]
            save_interval = train_config["save_interval"]
            eval_interval = train_config["eval_interval"]

            global_step = 1
            for j in range(1, 1 + epoch):
                epoch_loss = 0.
                # number of batches seen in this epoch; stays 0 if the loader
                # yields nothing, which guards the average computed below
                steps_in_epoch = 0
                for i, batch in tqdm.tqdm(enumerate(loader, 1)):
                    steps_in_epoch = i
                    dv3.train()  # CAUTION: don't forget to switch to train
                    (text_sequences, text_lengths, text_positions, mel_specs,
                     lin_specs, frames, decoder_positions, done_flags) = batch
                    # keep every `downsample_factor`-th entry along axis 1
                    downsampled_mel_specs = F.strided_slice(
                        mel_specs,
                        axes=[1],
                        starts=[0],
                        ends=[mel_specs.shape[1]],
                        strides=[downsample_factor])
                    mel_outputs, linear_outputs, alignments, done = dv3(
                        text_sequences, text_positions, text_lengths, None,
                        downsampled_mel_specs, decoder_positions)

                    losses = criterion(mel_outputs, linear_outputs, done,
                                       alignments, downsampled_mel_specs,
                                       lin_specs, done_flags, text_lengths,
                                       frames)
                    loss = losses["loss"]
                    loss.backward()
                    # record learning rate before updating
                    writer.add_scalar("learning_rate",
                                      optim._learning_rate.step().numpy(),
                                      global_step)
                    optim.minimize(loss, grad_clip=gradient_clipper)
                    optim.clear_gradients()

                    # ==================all kinds of tedious things=================
                    # record step loss into tensorboard
                    epoch_loss += loss.numpy()[0]
                    step_loss = {k: v.numpy()[0] for k, v in losses.items()}
                    for k, v in step_loss.items():
                        writer.add_scalar(k, v, global_step)

                    # TODO: clean code
                    # train state saving, the first sentence in the batch
                    if global_step % snap_interval == 0:
                        save_state(state_dir,
                                   writer,
                                   global_step,
                                   mel_input=downsampled_mel_specs,
                                   mel_output=mel_outputs,
                                   lin_input=lin_specs,
                                   lin_output=linear_outputs,
                                   alignments=alignments,
                                   win_length=win_length,
                                   hop_length=hop_length,
                                   min_level_db=min_level_db,
                                   ref_level_db=ref_level_db,
                                   power=power,
                                   n_iter=n_iter,
                                   preemphasis=preemphasis,
                                   sample_rate=sample_rate)

                    # evaluation
                    if global_step % eval_interval == 0:
                        for idx, sent in enumerate(EVAL_SENTENCES):
                            wav, attn = eval_model(
                                dv3, sent, replace_pronounciation_prob,
                                min_level_db, ref_level_db, power, n_iter,
                                win_length, hop_length, preemphasis)
                            # The sentence index is part of the file name;
                            # with the step alone every sentence would be
                            # written to the same path and overwrite the
                            # previous one.
                            wav_path = os.path.join(
                                state_dir, "waveform",
                                "eval_sample_{}_{:09d}.wav".format(
                                    idx, global_step))
                            sf.write(wav_path, wav, sample_rate)
                            writer.add_audio("eval_sample_{}".format(idx),
                                             wav,
                                             global_step,
                                             sample_rate=sample_rate)
                            attn_path = os.path.join(
                                state_dir, "alignments",
                                "eval_sample_attn_{}_{:09d}.png".format(
                                    idx, global_step))
                            plot_alignment(attn, attn_path)
                            writer.add_image(
                                "eval_sample_attn{}".format(idx),
                                cm.viridis(attn),
                                global_step,
                                dataformats="HWC")

                    # save checkpoint
                    # Model and optimizer state are saved under the same path
                    # prefix; presumably save_dygraph tells the two apart by
                    # file suffix -- confirm against the Paddle docs.
                    if global_step % save_interval == 0:
                        ckpt_prefix = os.path.join(
                            ckpt_dir, "model_step_{}".format(global_step))
                        dg.save_dygraph(dv3.state_dict(), ckpt_prefix)
                        dg.save_dygraph(optim.state_dict(), ckpt_prefix)

                    global_step += 1

                # epoch report; skipped when the epoch produced no batches
                if steps_in_epoch > 0:
                    writer.add_scalar("epoch_average_loss",
                                      epoch_loss / steps_in_epoch, j)
        finally:
            writer.close()


if __name__ == "__main__":
    main()