# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import division import os import argparse import ruamel.yaml import numpy as np import matplotlib matplotlib.use("agg") from matplotlib import cm import matplotlib.pyplot as plt import tqdm import librosa from librosa import display import soundfile as sf from tensorboardX import SummaryWriter from paddle import fluid import paddle.fluid.layers as F import paddle.fluid.dygraph as dg from parakeet.g2p import en from parakeet.data import FilterDataset, TransformDataset, FilterDataset from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec from parakeet.models.deepvoice3.loss import TTSLoss from parakeet.utils.layer_tools import summary from parakeet.utils import io from data import LJSpeechMetaData, DataCollector, Transform from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment if __name__ == "__main__": parser = argparse.ArgumentParser( description="Train a Deep Voice 3 model with LJSpeech dataset.") parser.add_argument("--config", type=str, help="experimrnt config") parser.add_argument( "--data", type=str, default="/workspace/datasets/LJSpeech-1.1/", help="The path of the LJSpeech dataset.") parser.add_argument("--device", type=int, default=-1, help="device to use") g = parser.add_mutually_exclusive_group() g.add_argument("--checkpoint", type=str, help="checkpoint to resume from.") g.add_argument( "--iteration", type=int, help="the iteration of the checkpoint to load from output directory") parser.add_argument( "output", type=str, default="experiment", help="path to save results") args, _ = parser.parse_known_args() with open(args.config, 'rt') as f: config = ruamel.yaml.safe_load(f) print("Command Line Args: ") for k, v in vars(args).items(): print("{}: {}".format(k, v)) # =========================dataset========================= # construct meta data data_root = args.data meta = LJSpeechMetaData(data_root) # filter it! min_text_length = config["meta_data"]["min_text_length"] meta = FilterDataset(meta, lambda x: len(x[2]) >= min_text_length) # transform meta data into meta data transform_config = config["transform"] replace_pronounciation_prob = transform_config[ "replace_pronunciation_prob"] sample_rate = transform_config["sample_rate"] preemphasis = transform_config["preemphasis"] n_fft = transform_config["n_fft"] win_length = transform_config["win_length"] hop_length = transform_config["hop_length"] fmin = transform_config["fmin"] fmax = transform_config["fmax"] n_mels = transform_config["n_mels"] min_level_db = transform_config["min_level_db"] ref_level_db = transform_config["ref_level_db"] max_norm = transform_config["max_norm"] clip_norm = transform_config["clip_norm"] transform = Transform(replace_pronounciation_prob, sample_rate, preemphasis, n_fft, win_length, hop_length, fmin, fmax, n_mels, min_level_db, ref_level_db, max_norm, clip_norm) ljspeech = TransformDataset(meta, transform) # =========================dataiterator========================= # use meta data's text length as a sort key for the sampler train_config = config["train"] batch_size = train_config["batch_size"] text_lengths = [len(example[2]) for example in meta] sampler = PartialyRandomizedSimilarTimeLengthSampler(text_lengths, batch_size) # some hyperparameters affect how we process data, so create a data collector! model_config = config["model"] downsample_factor = model_config["downsample_factor"] r = model_config["outputs_per_step"] collector = DataCollector(downsample_factor=downsample_factor, r=r) ljspeech_loader = DataCargo( ljspeech, batch_fn=collector, batch_size=batch_size, sampler=sampler) # =========================model========================= if args.device == -1: place = fluid.CPUPlace() else: place = fluid.CUDAPlace(args.device) with dg.guard(place): # =========================model========================= n_speakers = model_config["n_speakers"] speaker_dim = model_config["speaker_embed_dim"] speaker_embed_std = model_config["speaker_embedding_weight_std"] n_vocab = en.n_vocab embed_dim = model_config["text_embed_dim"] linear_dim = 1 + n_fft // 2 use_decoder_states = model_config[ "use_decoder_state_for_postnet_input"] filter_size = model_config["kernel_size"] encoder_channels = model_config["encoder_channels"] decoder_channels = model_config["decoder_channels"] converter_channels = model_config["converter_channels"] dropout = model_config["dropout"] padding_idx = model_config["padding_idx"] embedding_std = model_config["embedding_weight_std"] max_positions = model_config["max_positions"] freeze_embedding = model_config["freeze_embedding"] trainable_positional_encodings = model_config[ "trainable_positional_encodings"] use_memory_mask = model_config["use_memory_mask"] query_position_rate = model_config["query_position_rate"] key_position_rate = model_config["key_position_rate"] window_backward = model_config["window_backward"] window_ahead = model_config["window_ahead"] key_projection = model_config["key_projection"] value_projection = model_config["value_projection"] dv3 = make_model( n_speakers, speaker_dim, speaker_embed_std, embed_dim, padding_idx, embedding_std, max_positions, n_vocab, freeze_embedding, filter_size, encoder_channels, n_mels, decoder_channels, r, trainable_positional_encodings, use_memory_mask, query_position_rate, key_position_rate, window_backward, window_ahead, key_projection, value_projection, downsample_factor, linear_dim, use_decoder_states, converter_channels, dropout) summary(dv3) # =========================loss========================= loss_config = config["loss"] masked_weight = loss_config["masked_loss_weight"] priority_freq = loss_config["priority_freq"] # Hz priority_bin = int(priority_freq / (0.5 * sample_rate) * linear_dim) priority_freq_weight = loss_config["priority_freq_weight"] binary_divergence_weight = loss_config["binary_divergence_weight"] guided_attention_sigma = loss_config["guided_attention_sigma"] criterion = TTSLoss( masked_weight=masked_weight, priority_bin=priority_bin, priority_weight=priority_freq_weight, binary_divergence_weight=binary_divergence_weight, guided_attention_sigma=guided_attention_sigma, downsample_factor=downsample_factor, r=r) # =========================lr_scheduler========================= lr_config = config["lr_scheduler"] warmup_steps = lr_config["warmup_steps"] peak_learning_rate = lr_config["peak_learning_rate"] lr_scheduler = dg.NoamDecay( 1 / (warmup_steps * (peak_learning_rate)**2), warmup_steps) # =========================optimizer========================= optim_config = config["optimizer"] beta1 = optim_config["beta1"] beta2 = optim_config["beta2"] epsilon = optim_config["epsilon"] optim = fluid.optimizer.Adam( lr_scheduler, beta1, beta2, epsilon=epsilon, parameter_list=dv3.parameters(), grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1)) # generation synthesis_config = config["synthesis"] power = synthesis_config["power"] n_iter = synthesis_config["n_iter"] # =========================link(dataloader, paddle)========================= loader = fluid.io.DataLoader.from_generator( capacity=10, return_list=True) loader.set_batch_generator(ljspeech_loader, places=place) # tensorboard & checkpoint preparation output_dir = args.output ckpt_dir = os.path.join(output_dir, "checkpoints") log_dir = os.path.join(output_dir, "log") state_dir = os.path.join(output_dir, "states") make_output_tree(output_dir) writer = SummaryWriter(logdir=log_dir) # load parameters and optimizer, and opdate iterations done sofar if args.checkpoint is not None: iteration = io.load_parameters( dv3, optim, checkpoint_path=args.checkpoint) else: iteration = io.load_parameters( dv3, optim, checkpoint_dir=ckpt_dir, iteration=args.iteration) # =========================train========================= max_iter = train_config["max_iteration"] snap_interval = train_config["snap_interval"] save_interval = train_config["save_interval"] eval_interval = train_config["eval_interval"] global_step = iteration + 1 iterator = iter(tqdm.tqdm(loader)) while global_step <= max_iter: try: batch = next(iterator) except StopIteration as e: iterator = iter(tqdm.tqdm(loader)) batch = next(iterator) dv3.train() (text_sequences, text_lengths, text_positions, mel_specs, lin_specs, frames, decoder_positions, done_flags) = batch downsampled_mel_specs = F.strided_slice( mel_specs, axes=[1], starts=[0], ends=[mel_specs.shape[1]], strides=[downsample_factor]) mel_outputs, linear_outputs, alignments, done = dv3( text_sequences, text_positions, text_lengths, None, downsampled_mel_specs, decoder_positions) losses = criterion(mel_outputs, linear_outputs, done, alignments, downsampled_mel_specs, lin_specs, done_flags, text_lengths, frames) l = losses["loss"] l.backward() # record learning rate before updating writer.add_scalar("learning_rate", optim._learning_rate.step().numpy(), global_step) optim.minimize(l) optim.clear_gradients() # ==================all kinds of tedious things================= # record step loss into tensorboard step_loss = { k: v.numpy()[0] for k, v in losses.items() if v is not None } tqdm.tqdm.write("global_step: {}\tloss: {}".format( global_step, step_loss["loss"])) for k, v in step_loss.items(): writer.add_scalar(k, v, global_step) # train state saving, the first sentence in the batch if global_step % snap_interval == 0: save_state( state_dir, writer, global_step, mel_input=downsampled_mel_specs, mel_output=mel_outputs, lin_input=lin_specs, lin_output=linear_outputs, alignments=alignments, win_length=win_length, hop_length=hop_length, min_level_db=min_level_db, ref_level_db=ref_level_db, power=power, n_iter=n_iter, preemphasis=preemphasis, sample_rate=sample_rate) # evaluation if global_step % eval_interval == 0: sentences = [ "Scientists at the CERN laboratory say they have discovered a new particle.", "There's a way to measure the acute emotional intelligence that has never gone out of style.", "President Trump met with other leaders at the Group of 20 conference.", "Generative adversarial network or variational auto-encoder.", "Please call Stella.", "Some have accepted this as a miracle without any physical explanation.", ] for idx, sent in enumerate(sentences): wav, attn = eval_model( dv3, sent, replace_pronounciation_prob, min_level_db, ref_level_db, power, n_iter, win_length, hop_length, preemphasis) wav_path = os.path.join( state_dir, "waveform", "eval_sample_{:09d}.wav".format(global_step)) sf.write(wav_path, wav, sample_rate) writer.add_audio( "eval_sample_{}".format(idx), wav, global_step, sample_rate=sample_rate) attn_path = os.path.join( state_dir, "alignments", "eval_sample_attn_{:09d}.png".format(global_step)) plot_alignment(attn, attn_path) writer.add_image( "eval_sample_attn{}".format(idx), cm.viridis(attn), global_step, dataformats="HWC") # save checkpoint if global_step % save_interval == 0: io.save_parameters(ckpt_dir, global_step, dv3, optim) global_step += 1