diff --git a/README.md b/README.md
index 9eddef4bf5b70768fee034efea6482158966299f..aef196366d4b133d48aab68983dec7f57f298f78 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ Parakeet aims to provide a flexible, efficient and state-of-the-art text-to-spee
 In particular, it features the latest [WaveFlow](https://arxiv.org/abs/1912.01219) model proposed by Baidu Research.
 - WaveFlow can synthesize 22.05 kHz high-fidelity speech around 40x faster than real-time on an Nvidia V100 GPU without engineered inference kernels, which is faster than [WaveGlow](https://github.com/NVIDIA/waveglow) and several orders of magnitude faster than WaveNet.
 - WaveFlow is a small-footprint flow-based model for raw audio. It has only 5.9M parameters, which is 15x smaller than WaveGlow (87.9M) and comparable to WaveNet (4.6M).
-- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development. 
+- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development.
 
 ### Setup
 
@@ -45,8 +45,10 @@ nltk.download("cmudict")
 
 - [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654)
 - [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895)
-- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263).
+- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263)
 - [WaveFlow: A Compact Flow-based Model for Raw Audio](https://arxiv.org/abs/1912.01219)
+- [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499)
+- [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](https://arxiv.org/abs/1807.07281)
 
 ## Examples
 
@@ -54,6 +56,8 @@ nltk.download("cmudict")
 - [Train a TransformerTTS model with ljspeech dataset](./examples/transformer_tts)
 - [Train a FastSpeech model with ljspeech dataset](./examples/fastspeech)
 - [Train a WaveFlow model with ljspeech dataset](./examples/waveflow)
+- [Train a WaveNet model with ljspeech dataset](./examples/wavenet)
+- [Train a ClariNet model with ljspeech dataset](./examples/clarinet)
 
 ## Copyright and License
 
diff --git a/examples/clarinet/configs/clarinet_ljspeech.yaml b/examples/clarinet/configs/clarinet_ljspeech.yaml
index 7ceedcc6f6b3e24fc536da9d3bac33cdcac410e7..2e571e5cb1c7195acff94f457187e6db0aa56bcb 100644
--- a/examples/clarinet/configs/clarinet_ljspeech.yaml
+++ b/examples/clarinet/configs/clarinet_ljspeech.yaml
@@ -1,5 +1,5 @@
 data:
-  batch_size: 4
+  batch_size: 8
   train_clip_seconds: 0.5
   sample_rate: 22050
   hop_length: 256
diff --git a/examples/clarinet/synthesis.py b/examples/clarinet/synthesis.py
new file mode 100644
index 0000000000000000000000000000000000000000..e22723798e89a419fd39771db87146e22c1e287e
--- /dev/null
+++ b/examples/clarinet/synthesis.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import argparse
+import ruamel.yaml
+import random
+from tqdm import tqdm
+import pickle
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import paddle.fluid.dygraph as dg
+from paddle import fluid
+
+from parakeet.models.wavenet import WaveNet, UpsampleNet
+from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
+from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
+from parakeet.utils.layer_tools import summary, freeze
+
+from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model
+sys.path.append("../wavenet")
+from data import LJSpeechMetaData, Transform, DataCollector
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="synthesize audio files from mel spectrograms in the validation set."
+    )
+    parser.add_argument("--config", type=str, help="path of the config file.")
+    parser.add_argument(
+        "--device", type=int, default=-1, help="device to use.")
+    parser.add_argument("--data", type=str, help="path of the LJSpeech dataset.")
+    parser.add_argument(
+        "checkpoint", type=str, help="checkpoint to load from.")
+    parser.add_argument(
+        "output", type=str, default="experiment", help="path to save the synthesized audio.")
+
+    args = parser.parse_args()
+    with open(args.config, 'rt') as f:
+        config = ruamel.yaml.safe_load(f)
+
+    ljspeech_meta = LJSpeechMetaData(args.data)
+
+    data_config = config["data"]
+    sample_rate = data_config["sample_rate"]
+    n_fft = data_config["n_fft"]
+    win_length = data_config["win_length"]
+    hop_length = data_config["hop_length"]
+    n_mels = data_config["n_mels"]
+    train_clip_seconds = data_config["train_clip_seconds"]
+    transform = Transform(sample_rate, n_fft, win_length, hop_length, n_mels)
+    ljspeech = TransformDataset(ljspeech_meta, transform)
+
+    valid_size = data_config["valid_size"]
+    ljspeech_valid = SliceDataset(ljspeech, 0, valid_size)
+    ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech))
+
+    teacher_config = config["teacher"]
+    n_loop = teacher_config["n_loop"]
+    n_layer = teacher_config["n_layer"]
+    filter_size = teacher_config["filter_size"]
+    context_size = 1 + n_layer * sum([filter_size**i for i in range(n_loop)])
+    print("context size is {} samples".format(context_size))
+    train_batch_fn = DataCollector(context_size, sample_rate, hop_length,
+                                   train_clip_seconds)
+    valid_batch_fn = DataCollector(
+        context_size, sample_rate, hop_length, train_clip_seconds, valid=True)
+
+    batch_size = data_config["batch_size"]
+    train_cargo = DataCargo(
+        ljspeech_train,
+        train_batch_fn,
+        batch_size,
+        sampler=RandomSampler(ljspeech_train))
+
+    # only batch_size=1 is supported for validation
+    valid_cargo = DataCargo(
+        ljspeech_valid,
+        valid_batch_fn,
+        batch_size=1,
+        sampler=SequentialSampler(ljspeech_valid))
+
+    if args.device == -1:
+        place = fluid.CPUPlace()
+    else:
+        place = fluid.CUDAPlace(args.device)
+
+    with dg.guard(place):
+        # conditioner (upsampling net)
+        conditioner_config = config["conditioner"]
+        upsampling_factors = conditioner_config["upsampling_factors"]
+        upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
+        freeze(upsample_net)
+
+        residual_channels = teacher_config["residual_channels"]
+        loss_type = teacher_config["loss_type"]
+        output_dim = teacher_config["output_dim"]
+        log_scale_min = teacher_config["log_scale_min"]
+        assert loss_type == "mog" and output_dim == 3, \
+            "the teacher WaveNet should use a single-Gaussian output (loss_type='mog' with output_dim=3)"
+
+        teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim,
+                          n_mels, filter_size, loss_type, log_scale_min)
+        # upsample_net and teacher stay frozen; their weights are loaded below via load_model
+        freeze(teacher)
+
+        student_config = config["student"]
+        n_loops = student_config["n_loops"]
+        n_layers = student_config["n_layers"]
+        student_residual_channels = student_config["residual_channels"]
+        student_filter_size = student_config["filter_size"]
+        student_log_scale_min = student_config["log_scale_min"]
+        student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
+                                  n_mels, student_filter_size)
+
+        stft_config = config["stft"]
+        stft = STFT(
+            n_fft=stft_config["n_fft"],
+            hop_length=stft_config["hop_length"],
+            win_length=stft_config["win_length"])
+
+        lmd = config["loss"]["lmd"]
+        model = Clarinet(upsample_net, teacher, student, stft,
+                         student_log_scale_min, lmd)
+        summary(model)
+        load_model(model, args.checkpoint)
+
+        # loader
+        train_loader = fluid.io.DataLoader.from_generator(
+            capacity=10, return_list=True)
+        train_loader.set_batch_generator(train_cargo, place)
+
+        valid_loader = fluid.io.DataLoader.from_generator(
+            capacity=10, return_list=True)
+        valid_loader.set_batch_generator(valid_cargo, place)
+
+        if not os.path.exists(args.output):
+            os.makedirs(args.output)
+        eval_model(model, valid_loader, args.output, sample_rate)
diff --git a/examples/clarinet/train.py b/examples/clarinet/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ceb05c2d35af0f1efc3fe909af92d0dbe2057bc
--- /dev/null
+++ b/examples/clarinet/train.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import argparse
+import ruamel.yaml
+import random
+from tqdm import tqdm
+import pickle
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import paddle.fluid.dygraph as dg
+from paddle import fluid
+
+from parakeet.models.wavenet import WaveNet, UpsampleNet
+from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
+from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
+from parakeet.utils.layer_tools import summary, freeze
+
+from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet
+sys.path.append("../wavenet")
+from data import LJSpeechMetaData, Transform, DataCollector
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="train a ClariNet model with LJSpeech and a trained WaveNet model."
+    )
+    parser.add_argument("--config", type=str, help="path of the config file.")
+    parser.add_argument(
+        "--device", type=int, default=-1, help="device to use.")
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="experiment",
+        help="path to save the trained student (checkpoints, states and logs).")
+    parser.add_argument("--data", type=str, help="path of the LJSpeech dataset.")
+    parser.add_argument("--resume", type=str, help="checkpoint to load from.")
+    parser.add_argument(
+        "--wavenet", type=str, help="wavenet checkpoint to use.")
+    args = parser.parse_args()
+    with open(args.config, 'rt') as f:
+        config = ruamel.yaml.safe_load(f)
+
+    ljspeech_meta = LJSpeechMetaData(args.data)
+
+    data_config = config["data"]
+    sample_rate = data_config["sample_rate"]
+    n_fft = data_config["n_fft"]
+    win_length = data_config["win_length"]
+    hop_length = data_config["hop_length"]
+    n_mels = data_config["n_mels"]
+    train_clip_seconds = data_config["train_clip_seconds"]
+    transform = Transform(sample_rate, n_fft, win_length, hop_length, n_mels)
+    ljspeech = TransformDataset(ljspeech_meta, transform)
+
+    valid_size = data_config["valid_size"]
+    ljspeech_valid = SliceDataset(ljspeech, 0, valid_size)
+    ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech))
+
+    teacher_config = config["teacher"]
+    n_loop = teacher_config["n_loop"]
+    n_layer = teacher_config["n_layer"]
+    filter_size = teacher_config["filter_size"]
+    context_size = 1 + n_layer * sum([filter_size**i for i in range(n_loop)])
+    print("context size is {} samples".format(context_size))
+    train_batch_fn = DataCollector(context_size, sample_rate, hop_length,
+                                   train_clip_seconds)
+    valid_batch_fn = DataCollector(
+        context_size, sample_rate, hop_length, train_clip_seconds, valid=True)
+
+    batch_size = data_config["batch_size"]
+    train_cargo = DataCargo(
+        ljspeech_train,
+        train_batch_fn,
+        batch_size,
+        sampler=RandomSampler(ljspeech_train))
+
+    # only batch_size=1 is supported for validation
+    valid_cargo = DataCargo(
+        ljspeech_valid,
+        valid_batch_fn,
+        batch_size=1,
+        sampler=SequentialSampler(ljspeech_valid))
+
+    make_output_tree(args.output)
+
+    if args.device == -1:
+        place = fluid.CPUPlace()
+    else:
+        place = fluid.CUDAPlace(args.device)
+
+    with dg.guard(place):
+        # conditioner (upsampling net)
+        conditioner_config = config["conditioner"]
+        upsampling_factors = conditioner_config["upsampling_factors"]
+        upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
+        freeze(upsample_net)
+
+        residual_channels = teacher_config["residual_channels"]
+        loss_type = teacher_config["loss_type"]
+        output_dim = teacher_config["output_dim"]
+        log_scale_min = teacher_config["log_scale_min"]
+        assert loss_type == "mog" and output_dim == 3, \
+            "the teacher WaveNet should use a single-Gaussian output (loss_type='mog' with output_dim=3)"
+
+        teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim,
+                          n_mels, filter_size, loss_type, log_scale_min)
+        freeze(teacher)
+
+        student_config = config["student"]
+        n_loops = student_config["n_loops"]
+        n_layers = student_config["n_layers"]
+        student_residual_channels = student_config["residual_channels"]
+        student_filter_size = student_config["filter_size"]
+        student_log_scale_min = student_config["log_scale_min"]
+        student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
+                                  n_mels, student_filter_size)
+
+        stft_config = config["stft"]
+        stft = STFT(
+            n_fft=stft_config["n_fft"],
+            hop_length=stft_config["hop_length"],
+            win_length=stft_config["win_length"])
+
+        lmd = config["loss"]["lmd"]
+        model = Clarinet(upsample_net, teacher, student, stft,
+                         student_log_scale_min, lmd)
+        summary(model)
+
+        # optim
+        train_config = config["train"]
+        learning_rate = train_config["learning_rate"]
+        anneal_rate = train_config["anneal_rate"]
+        anneal_interval = train_config["anneal_interval"]
+        lr_scheduler = dg.ExponentialDecay(
+            learning_rate, anneal_interval, anneal_rate, staircase=True)
+        optim = fluid.optimizer.Adam(
+            lr_scheduler, parameter_list=model.parameters())
+        gradient_max_norm = train_config["gradient_max_norm"]
+        clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
+            gradient_max_norm)
+
+        assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
+        if args.wavenet:
+            load_wavenet(model, args.wavenet)
+
+        if args.resume:
+            load_checkpoint(model, optim, args.resume)
+
+        # loader
+        train_loader = fluid.io.DataLoader.from_generator(
+            capacity=10, return_list=True)
+        train_loader.set_batch_generator(train_cargo, place)
+
+        valid_loader = fluid.io.DataLoader.from_generator(
+            capacity=10, return_list=True)
+        valid_loader.set_batch_generator(valid_cargo, place)
+
+        # train
+        max_iterations = train_config["max_iterations"]
+        checkpoint_interval = train_config["checkpoint_interval"]
+        eval_interval = train_config["eval_interval"]
+        checkpoint_dir = os.path.join(args.output, "checkpoints")
+        state_dir = os.path.join(args.output, "states")
+        log_dir = os.path.join(args.output, "log")
+        writer = SummaryWriter(log_dir)
+
+        # training loop
+        global_step = 1
+        global_epoch = 1
+        while global_step < max_iterations:
+            epoch_loss = 0.
+            for j, batch in tqdm(enumerate(train_loader), desc="[train]"):
+                audios, mels, audio_starts = batch
+                model.train()
+                loss_dict = model(
+                    audios, mels, audio_starts, clip_kl=global_step > 500)
+
+                writer.add_scalar("learning_rate",
+                                  optim._learning_rate.step().numpy()[0],
+                                  global_step)
+                for k, v in loss_dict.items():
+                    writer.add_scalar("loss/{}".format(k),
+                                      v.numpy()[0], global_step)
+
+                loss_var = loss_dict["loss"]
+                step_loss = loss_var.numpy()[0]
+                print("[train] loss: {:<8.6f}".format(step_loss))
+                epoch_loss += step_loss
+
+                loss_var.backward()
+                optim.minimize(loss_var, grad_clip=clipper)
+                optim.clear_gradients()
+
+                if global_step % eval_interval == 0:
+                    # evaluate on the validation set
+                    valid_model(model, valid_loader, state_dir, global_step,
+                                sample_rate)
+                if global_step % checkpoint_interval == 0:
+                    save_checkpoint(model, optim, checkpoint_dir, global_step)
+
+                global_step += 1
+
+            # epoch loss (j is zero-based, so the epoch saw j + 1 batches)
+            average_loss = epoch_loss / (j + 1)
+            writer.add_scalar("average_loss", average_loss, global_epoch)
+            global_epoch += 1
diff --git a/examples/clarinet/utils.py b/examples/clarinet/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ec74615ece50a2d2f5ae9a934c96f1334e30f8
--- /dev/null
+++ b/examples/clarinet/utils.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import soundfile as sf
+from tensorboardX import SummaryWriter
+from collections import OrderedDict
+
+from paddle import fluid
+import paddle.fluid.dygraph as dg
+
+
+def make_output_tree(output_dir):
+    checkpoint_dir = os.path.join(output_dir, "checkpoints")
+    if not os.path.exists(checkpoint_dir):
+        os.makedirs(checkpoint_dir)
+
+    state_dir = os.path.join(output_dir, "states")
+    if not os.path.exists(state_dir):
+        os.makedirs(state_dir)
+
+
+def valid_model(model, valid_loader, output_dir, global_step, sample_rate):
+    model.eval()
+    for i, batch in enumerate(valid_loader):
+        path = os.path.join(output_dir,
+                            "step_{}_sentence_{}.wav".format(global_step, i))
+        audio_clips, mel_specs, audio_starts = batch
+        wav_var = model.synthesis(mel_specs)
+        wav_np = wav_var.numpy()[0]
+        sf.write(path, wav_np, samplerate=sample_rate)
+        print("generated {}".format(path))
+
+
+def eval_model(model, valid_loader, output_dir, sample_rate):
+    model.eval()
+    for i, batch in enumerate(valid_loader):
+        path = os.path.join(output_dir, "sentence_{}.wav".format(i))
+        audio_clips, mel_specs, audio_starts = batch
+        wav_var = model.synthesis(mel_specs)
+        wav_np = wav_var.numpy()[0]
+        sf.write(path, wav_np, samplerate=sample_rate)
+        print("generated {}".format(path))
+
+
+def save_checkpoint(model, optim, checkpoint_dir, global_step):
+    path = os.path.join(checkpoint_dir, "step_{}".format(global_step))
+    dg.save_dygraph(model.state_dict(), path)
+    print("saving model to {}".format(path + ".pdparams"))
+    if optim:
+        dg.save_dygraph(optim.state_dict(), path)
+        print("saving optimizer to {}".format(path + ".pdopt"))
+
+
+def load_model(model, path):
+    model_dict, _ = dg.load_dygraph(path)
+    model.set_dict(model_dict)
+    print("loaded model from {}.pdparams".format(path))
+
+
+def load_checkpoint(model, optim, path):
+    model_dict, optim_dict = dg.load_dygraph(path)
+    model.set_dict(model_dict)
+    print("loaded model from {}.pdparams".format(path))
+    if optim_dict:
+        optim.set_dict(optim_dict)
+        print("loaded optimizer from {}.pdopt".format(path))
+
+
+def load_wavenet(model, path):
+    wavenet_dict, _ = dg.load_dygraph(path)
+    encoder_dict = OrderedDict()
+    teacher_dict = OrderedDict()
+    for k, v in wavenet_dict.items():
+        if k.startswith("encoder."):
+            encoder_dict[k.split('.', 1)[1]] = v
+        else:
+            # k starts with "decoder."
+            teacher_dict[k.split('.', 1)[1]] = v
+
+    model.encoder.set_dict(encoder_dict)
+    model.teacher.set_dict(teacher_dict)
+    print("loaded the encoder and teacher parts from the wavenet model.")
diff --git a/examples/deepvoice3/README.md b/examples/deepvoice3/README.md
index 80434ce9e7f6ed66eb95c8e8d0f2d1bbec9b5f7f..fa7a5e44649cbd35a29ae1e53c1174f46e8309eb 100644
--- a/examples/deepvoice3/README.md
+++ b/examples/deepvoice3/README.md
@@ -23,7 +23,7 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed
 ```text
 ├── data.py              data processing
-├── ljspeech.yaml        (example) configuration file
+├── configs/             (example) configuration files
 ├── sentences.txt        sample sentences
 ├── synthesis.py         script to synthesize waveform from text
 ├── train.py             script to train a model
@@ -72,7 +72,7 @@ optional arguments:
 Example script:
 
 ```bash
-python train.py --config=./ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
+python train.py --config=configs/ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
 ```
 
 You can monitor the training log via TensorBoard, using the script below.
@@ -110,5 +110,5 @@ optional arguments:
 Example script:
 
 ```bash
-python synthesis.py --config=./ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated
+python synthesis.py --config=configs/ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated
 ```
diff --git a/examples/deepvoice3/ljspeech.yaml b/examples/deepvoice3/configs/ljspeech.yaml
similarity index 100%
rename from examples/deepvoice3/ljspeech.yaml
rename to examples/deepvoice3/configs/ljspeech.yaml
diff --git a/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml b/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml
index a848a52ec5b2541ddffb1f3960dcc9d07c2071ad..68936ee9311eb77586af6b2b2b73f6874b57d43e 100644
--- a/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml
+++ b/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml
@@ -1,5 +1,5 @@
 data:
-  batch_size: 4
+  batch_size: 16
   train_clip_seconds: 0.5
   sample_rate: 22050
   hop_length: 256
@@ -30,7 +30,7 @@ train:
   snap_interval: 10000
   eval_interval: 10000
 
-  max_iterations: 200000
+  max_iterations: 2000000
diff --git a/examples/wavenet/configs/wavenet_single_gaussian.yaml b/examples/wavenet/configs/wavenet_single_gaussian.yaml
index 8e333492900f5a48bf0a8d1312ca9a6e0ee00fe9..484db0bdf6b14d00532165aef2a27f42b1ba6de7 100644
--- a/examples/wavenet/configs/wavenet_single_gaussian.yaml
+++ b/examples/wavenet/configs/wavenet_single_gaussian.yaml
@@ -1,5 +1,5 @@
 data:
-  batch_size: 4
+  batch_size: 16
   train_clip_seconds: 0.5
   sample_rate: 22050
   hop_length: 256
@@ -30,7 +30,7 @@ train:
   snap_interval: 10000
   eval_interval: 10000
 
-  max_iterations: 200000
+  max_iterations: 2000000
diff --git a/examples/wavenet/configs/wavenet_softmax.yaml b/examples/wavenet/configs/wavenet_softmax.yaml
index 98018ee25d99ff685f91de734d1357bd359013bf..7e9d7567dbaecb8968ec2e82438a71c2fd27ec70 100644
--- a/examples/wavenet/configs/wavenet_softmax.yaml
+++ b/examples/wavenet/configs/wavenet_softmax.yaml
@@ -1,5 +1,5 @@
 data:
-  batch_size: 4
+  batch_size: 16
   train_clip_seconds: 0.5
   sample_rate: 22050
   hop_length: 256
@@ -30,7 +30,7 @@ train:
   snap_interval: 10000
   eval_interval: 10000
 
-  max_iterations: 200000
+  max_iterations: 2000000
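
A note on the `context_size` expression shared by the new `train.py` and `synthesis.py`: for a filter size of 2 it equals the receptive field, in samples, of the teacher WaveNet's dilated convolutions, and `DataCollector` uses it to pad each training clip with past context. Below is a minimal worked example in plain Python; the `n_loop`, `n_layer`, and `filter_size` values are illustrative assumptions, not values read from `clarinet_ljspeech.yaml` (the `0.5` and `22050` do come from the config fragments shown above).

```python
# Worked example of the context_size formula from train.py / synthesis.py.

def context_size(n_loop, n_layer, filter_size):
    # Mirrors the expression the scripts compute: n_layer stacks of n_loop
    # dilated convolutions, where the i-th convolution in each stack uses
    # dilation filter_size ** i.
    return 1 + n_layer * sum(filter_size**i for i in range(n_loop))


size = context_size(n_loop=10, n_layer=3, filter_size=2)  # assumed values
print(size)              # 1 + 3 * (2**10 - 1) = 3070 samples
print(size / 22050)      # ~0.139 s of past context at 22.05 kHz
clip = int(0.5 * 22050)  # train_clip_seconds * sample_rate from the config
print(clip)              # 11025 samples per training clip, before padding
```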
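
Likewise, the `dg.ExponentialDecay(learning_rate, anneal_interval, anneal_rate, staircase=True)` scheduler built in `train.py` keeps the learning rate constant within each interval and multiplies it by `anneal_rate` at every interval boundary. A plain-Python sketch of that staircase schedule follows; the numeric values are hypothetical, not taken from the config.

```python
# Sketch of staircase exponential decay as train.py configures it.
def staircase_lr(base_lr, anneal_rate, anneal_interval, step):
    # With staircase=True the exponent is floored, so the rate drops by a
    # factor of anneal_rate once every anneal_interval steps.
    return base_lr * anneal_rate ** (step // anneal_interval)


# Hypothetical values, for illustration only:
for step in (0, 199_999, 200_000, 400_000):
    print(step, staircase_lr(1e-3, 0.5, 200_000, step))
# -> 0.001, 0.001, 0.0005, 0.00025
```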