Commit 77674353 authored by chenfeiyu

update save & load for Deep Voice 3, WaveNet and ClariNet; remove the concept of epoch in training

Parent 64790853
@@ -28,24 +28,24 @@ Train the model using train.py, follow the usage displayed by `python train.py --help`.
 
 ```text
 usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT]
-                [--data DATA] [--resume RESUME] [--wavenet WAVENET]
+                [--data DATA] [--checkpoint CHECKPOINT] [--wavenet WAVENET]
 
 train a ClariNet model with LJspeech and a trained WaveNet model.
 
 optional arguments:
   -h, --help            show this help message and exit
   --config CONFIG       path of the config file.
   --device DEVICE       device to use.
   --output OUTPUT       path to save student.
   --data DATA           path of LJspeech dataset.
-  --resume RESUME       checkpoint to load from.
+  --checkpoint CHECKPOINT
+                        checkpoint to load from.
   --wavenet WAVENET     wavenet checkpoint to use.
 ```
 
 - `--config` is the configuration file to use. The provided configurations can be used directly. You can also change some values in the configuration file and train the model with a different config.
 - `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
-- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
+- `--checkpoint` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training.
 - `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
 
 ```text
 ├── checkpoints      # checkpoint
@@ -53,6 +53,8 @@ optional arguments:
 └── log              # tensorboard log
 ```
 
+If `checkpoints` is not empty and the argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
+
 - `--device` is the device (gpu id) to use for training. `-1` means CPU.
-- `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--resume`, then this must be provided.
+- `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--checkpoint`, then this must be provided.
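The auto-resume note above implies the checkpoint directory keeps a record of the last saved step. Below is a minimal sketch of how such record-keeping helpers could be implemented; the record file name and its format are assumptions for illustration, not necessarily what `parakeet.utils.io` actually does (only the function names and the return conventions are taken from this commit).

```python
import os

RECORD_FILE = "checkpoint"  # assumed name of the record file


def save_latest_checkpoint(checkpoint_dir, iteration):
    """Record the step of the most recently saved checkpoint."""
    with open(os.path.join(checkpoint_dir, RECORD_FILE), "wt") as f:
        f.write("step-{}".format(iteration))


def load_latest_checkpoint(checkpoint_dir):
    """Return the step of the latest checkpoint, or 0 if there is none."""
    record_path = os.path.join(checkpoint_dir, RECORD_FILE)
    if not os.path.exists(record_path):
        return 0
    with open(record_path, "rt") as f:
        latest = f.readline().strip()  # e.g. "step-500"
    return int(latest.split("-")[-1])
```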
...
@@ -31,7 +31,7 @@ from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
 from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
 from parakeet.utils.layer_tools import summary, freeze
-from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model
+from utils import valid_model, eval_model, load_model
 
 sys.path.append("../wavenet")
 from data import LJSpeechMetaData, Transform, DataCollector
...
@@ -30,14 +30,15 @@ from parakeet.models.wavenet import WaveNet, UpsampleNet
 from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
 from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
 from parakeet.utils.layer_tools import summary, freeze
+from parakeet.utils import io
-from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet
+from utils import make_output_tree, valid_model, load_wavenet
 
 sys.path.append("../wavenet")
 from data import LJSpeechMetaData, Transform, DataCollector
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="train a clarinet model with LJspeech and a trained wavenet model."
+        description="train a ClariNet model with LJspeech and a trained WaveNet model."
     )
     parser.add_argument("--config", type=str, help="path of the config file.")
     parser.add_argument(
@@ -48,13 +49,18 @@ if __name__ == "__main__":
         default="experiment",
         help="path to save student.")
     parser.add_argument("--data", type=str, help="path of LJspeech dataset.")
-    parser.add_argument("--resume", type=str, help="checkpoint to load from.")
+    parser.add_argument(
+        "--checkpoint", type=str, help="checkpoint to load from.")
     parser.add_argument(
         "--wavenet", type=str, help="wavenet checkpoint to use.")
     args = parser.parse_args()
 
     with open(args.config, 'rt') as f:
         config = ruamel.yaml.safe_load(f)
 
+    print("Command Line args: ")
+    for k, v in vars(args).items():
+        print("{}: {}".format(k, v))
+
     ljspeech_meta = LJSpeechMetaData(args.data)
     data_config = config["data"]
@@ -154,12 +160,38 @@ if __name__ == "__main__":
     clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
         gradiant_max_norm)
 
-    assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
-    if args.wavenet:
+    # train
+    max_iterations = train_config["max_iterations"]
+    checkpoint_interval = train_config["checkpoint_interval"]
+    eval_interval = train_config["eval_interval"]
+    checkpoint_dir = os.path.join(args.output, "checkpoints")
+    state_dir = os.path.join(args.output, "states")
+    log_dir = os.path.join(args.output, "log")
+    writer = SummaryWriter(log_dir)
+
+    # load wavenet/checkpoint, determine iterations done
+    if args.checkpoint is not None:
+        iteration = int(os.path.basename(args.checkpoint).split('-')[-1])
+    else:
+        iteration = io.load_latest_checkpoint(checkpoint_dir)
+
+    if iteration == 0 and args.wavenet is None:
+        raise Exception(
+            "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
+        )
+
+    if args.wavenet is not None and iteration > 0:
+        if args.checkpoint is None:
+            print("Resume training, --wavenet ignored")
+        else:
+            print("--checkpoint provided, --wavenet ignored")
+
+    if args.wavenet is not None and iteration == 0:
         load_wavenet(model, args.wavenet)
-    if args.resume:
-        load_checkpoint(model, optim, args.resume)
+
+    # it may overwrite the wavenet loaded
+    io.load_parameters(
+        checkpoint_dir, 0, model, optim, file_path=args.checkpoint)
 
     # loader
     train_loader = fluid.io.DataLoader.from_generator(
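Note how the `--checkpoint` branch above recovers the step count from the file name itself, so checkpoint names are expected to end in `-<iteration>`. A quick check of that parsing with a hypothetical path:

```python
import os

# hypothetical checkpoint path; the trailing "-<iteration>" is the naming
# convention the training scripts in this commit rely on
checkpoint_path = "experiment/checkpoints/step-500"
iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
print(iteration)  # 500
```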
@@ -170,52 +202,43 @@ if __name__ == "__main__":
         capacity=10, return_list=True)
     valid_loader.set_batch_generator(valid_cargo, place)
 
-    # train
-    max_iterations = train_config["max_iterations"]
-    checkpoint_interval = train_config["checkpoint_interval"]
-    eval_interval = train_config["eval_interval"]
-    checkpoint_dir = os.path.join(args.output, "checkpoints")
-    state_dir = os.path.join(args.output, "states")
-    log_dir = os.path.join(args.output, "log")
-    writer = SummaryWriter(log_dir)
-
     # training loop
-    global_step = 1
-    global_epoch = 1
+    global_step = iteration + 1
+    iterator = iter(tqdm(train_loader))
     while global_step < max_iterations:
-        epoch_loss = 0.
-        for j, batch in tqdm(enumerate(train_loader), desc="[train]"):
-            audios, mels, audio_starts = batch
-            model.train()
-            loss_dict = model(
-                audios, mels, audio_starts, clip_kl=global_step > 500)
-
-            writer.add_scalar("learning_rate",
-                              optim._learning_rate.step().numpy()[0],
-                              global_step)
-            for k, v in loss_dict.items():
-                writer.add_scalar("loss/{}".format(k),
-                                  v.numpy()[0], global_step)
-
-            l = loss_dict["loss"]
-            step_loss = l.numpy()[0]
-            print("[train] loss: {:<8.6f}".format(step_loss))
-            epoch_loss += step_loss
-
-            l.backward()
-            optim.minimize(l, grad_clip=clipper)
-            optim.clear_gradients()
-
-            if global_step % eval_interval == 0:
-                # evaluate on valid dataset
-                valid_model(model, valid_loader, state_dir, global_step,
-                            sample_rate)
-            if global_step % checkpoint_interval == 0:
-                save_checkpoint(model, optim, checkpoint_dir, global_step)
-
-            global_step += 1
-
-        # epoch loss
-        average_loss = epoch_loss / j
-        writer.add_scalar("average_loss", average_loss, global_epoch)
-        global_epoch += 1
+        try:
+            batch = next(iterator)
+        except StopIteration as e:
+            iterator = iter(tqdm(train_loader))
+            batch = next(iterator)
+
+        audios, mels, audio_starts = batch
+        model.train()
+        loss_dict = model(
+            audios, mels, audio_starts, clip_kl=global_step > 500)
+
+        writer.add_scalar("learning_rate",
+                          optim._learning_rate.step().numpy()[0],
+                          global_step)
+        for k, v in loss_dict.items():
+            writer.add_scalar("loss/{}".format(k),
+                              v.numpy()[0], global_step)
+
+        l = loss_dict["loss"]
+        step_loss = l.numpy()[0]
+        print("[train] loss: {:<8.6f}".format(step_loss))
+
+        l.backward()
+        optim.minimize(l, grad_clip=clipper)
+        optim.clear_gradients()
+
+        if global_step % eval_interval == 0:
+            # evaluate on valid dataset
+            valid_model(model, valid_loader, state_dir, global_step,
+                        sample_rate)
+        if global_step % checkpoint_interval == 0:
+            io.save_latest_parameters(checkpoint_dir, global_step, model,
+                                      optim)
+            io.save_latest_checkpoint(checkpoint_dir, global_step)
+
+        global_step += 1
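Dropping the epoch loop hinges on one pattern: when the data iterator is exhausted, rebuild it and keep counting steps. The same idea can be factored into a small generator; this is a sketch of the pattern, not code from the commit:

```python
def endless(loader):
    """Yield batches forever, rebuilding the loader's iterator whenever
    it is exhausted (one pass over the data = one former 'epoch')."""
    while True:
        for batch in loader:
            yield batch

# usage sketch: a purely step-based loop with no epoch bookkeeping
# batches = endless(train_loader)
# for global_step in range(start_step, max_iterations):
#     batch = next(batches)
#     ...
```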
@@ -35,26 +35,23 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed
 
 Train the model using train.py, follow the usage displayed by `python train.py --help`.
 
 ```text
-usage: train.py [-h] [-c CONFIG] [-s DATA] [-r RESUME] [-o OUTPUT] [-g DEVICE]
+usage: train.py [-h] [-c CONFIG] [-s DATA] [--checkpoint CHECKPOINT]
+                [-o OUTPUT] [-g DEVICE]
 
 Train a Deep Voice 3 model with LJSpeech dataset.
 
 optional arguments:
   -h, --help            show this help message and exit
-  -c CONFIG, --config CONFIG
-                        experimrnt config
-  -s DATA, --data DATA  The path of the LJSpeech dataset.
-  -r RESUME, --resume RESUME
-                        checkpoint to load
-  -o OUTPUT, --output OUTPUT
-                        The directory to save result.
-  -g DEVICE, --device DEVICE
-                        device to use
+  -c CONFIG, --config CONFIG  experiment config
+  -s DATA, --data DATA        The path of the LJSpeech dataset.
+  --checkpoint CHECKPOINT     checkpoint to load
+  -o OUTPUT, --output OUTPUT  The directory to save result.
+  -g DEVICE, --device DEVICE  device to use
 ```
 
 - `--config` is the configuration file to use. The provided `ljspeech.yaml` can be used directly. You can also change some values in the configuration file and train the model with a different config.
 - `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
-- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
+- `--checkpoint` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training.
 - `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
 
 ```text
@@ -67,6 +64,8 @@ optional arguments:
 └── waveform             # waveform (.wav files)
 ```
 
+If `checkpoints` is not empty and the argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
+
 - `--device` is the device (gpu id) to use for training. `-1` means CPU.
 
 Example script:
...
@@ -83,7 +83,7 @@ lr_scheduler:
 
 train:
   batch_size: 16
-  epochs: 2000
+  max_iteration: 2000000
   snap_interval: 1000
   eval_interval: 10000
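Changing `epochs` to `max_iteration` switches the unit of training length from passes over the dataset to optimizer steps. A rough conversion between the two, assuming LJSpeech-1.1's 13,100 clips and the batch size above (the arithmetic is illustrative; the committed value of 2,000,000 is simply a round number of the same order):

```python
import math

num_examples = 13100  # clips in LJSpeech-1.1
batch_size = 16       # from the config above
epochs = 2000         # the old setting

steps_per_epoch = math.ceil(num_examples / batch_size)  # 819
max_iteration = epochs * steps_per_epoch                # 1638000
print(steps_per_epoch, max_iteration)
```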
...
@@ -25,8 +25,9 @@ import paddle.fluid.dygraph as dg
 from tensorboardX import SummaryWriter
 
 from parakeet.g2p import en
-from parakeet.utils.layer_tools import summary
 from parakeet.modules.weight_norm import WeightNormWrapper
+from parakeet.utils.layer_tools import summary
+from parakeet.utils.io import load_parameters
 from utils import make_model, eval_model, plot_alignment
@@ -44,6 +45,10 @@ if __name__ == "__main__":
     with open(args.config, 'rt') as f:
         config = ruamel.yaml.safe_load(f)
 
+    print("Command Line Args: ")
+    for k, v in vars(args).items():
+        print("{}: {}".format(k, v))
+
     if args.device == -1:
         place = fluid.CPUPlace()
     else:
...
@@ -17,6 +17,8 @@ import os
 import argparse
 import ruamel.yaml
 import numpy as np
+import matplotlib
+matplotlib.use("agg")
 from matplotlib import cm
 import matplotlib.pyplot as plt
 import tqdm
@@ -35,13 +37,14 @@ from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler,
 from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
 from parakeet.models.deepvoice3.loss import TTSLoss
 from parakeet.utils.layer_tools import summary
+from parakeet.utils import io
 
 from data import LJSpeechMetaData, DataCollector, Transform
 from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Train a deepvoice 3 model with LJSpeech dataset.")
+        description="Train a Deep Voice 3 model with LJSpeech dataset.")
     parser.add_argument("-c", "--config", type=str, help="experiment config")
     parser.add_argument(
         "-s",
@@ -49,7 +52,7 @@ if __name__ == "__main__":
         type=str,
         default="/workspace/datasets/LJSpeech-1.1/",
         help="The path of the LJSpeech dataset.")
-    parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
+    parser.add_argument("--checkpoint", type=str, help="checkpoint to load")
     parser.add_argument(
         "-o",
         "--output",
@@ -62,6 +65,10 @@ if __name__ == "__main__":
     with open(args.config, 'rt') as f:
         config = ruamel.yaml.safe_load(f)
 
+    print("Command Line Args: ")
+    for k, v in vars(args).items():
+        print("{}: {}".format(k, v))
+
     # =========================dataset=========================
     # construct meta data
     data_root = args.data
@@ -151,6 +158,7 @@ if __name__ == "__main__":
         query_position_rate, key_position_rate, window_backward,
         window_ahead, key_projection, value_projection, downsample_factor,
         linear_dim, use_decoder_states, converter_channels, dropout)
+    summary(dv3)
 
     # =========================loss=========================
     loss_config = config["loss"]
@@ -195,7 +203,6 @@ if __name__ == "__main__":
     n_iter = synthesis_config["n_iter"]
 
     # =========================link(dataloader, paddle)=========================
-    # CAUTION: it does not return a DataLoader
     loader = fluid.io.DataLoader.from_generator(
         capacity=10, return_list=True)
     loader.set_batch_generator(ljspeech_loader, places=place)
@@ -208,122 +215,117 @@ if __name__ == "__main__":
     make_output_tree(output_dir)
     writer = SummaryWriter(logdir=log_dir)
 
-    # load model parameters
-    resume_path = args.resume
-    if resume_path is not None:
-        state, _ = dg.load_dygraph(args.resume)
-        dv3.set_dict(state)
+    # load parameters and optimizer, and update the number of iterations done so far
+    io.load_parameters(ckpt_dir, 0, dv3, optim, file_path=args.checkpoint)
+    if args.checkpoint is not None:
+        iteration = int(os.path.basename(args.checkpoint).split("-")[-1])
+    else:
+        iteration = io.load_latest_checkpoint(ckpt_dir)
 
     # =========================train=========================
-    epoch = train_config["epochs"]
+    max_iter = train_config["max_iteration"]
     snap_interval = train_config["snap_interval"]
     save_interval = train_config["save_interval"]
     eval_interval = train_config["eval_interval"]
 
-    global_step = 1
-
-    for j in range(1, 1 + epoch):
-        epoch_loss = 0.
-        for i, batch in tqdm.tqdm(enumerate(loader, 1)):
-            dv3.train()  # CAUTION: don't forget to switch to train
-            (text_sequences, text_lengths, text_positions, mel_specs,
-             lin_specs, frames, decoder_positions, done_flags) = batch
-            downsampled_mel_specs = F.strided_slice(
-                mel_specs,
-                axes=[1],
-                starts=[0],
-                ends=[mel_specs.shape[1]],
-                strides=[downsample_factor])
-            mel_outputs, linear_outputs, alignments, done = dv3(
-                text_sequences, text_positions, text_lengths, None,
-                downsampled_mel_specs, decoder_positions)
-
-            losses = criterion(mel_outputs, linear_outputs, done,
-                               alignments, downsampled_mel_specs,
-                               lin_specs, done_flags, text_lengths, frames)
-            l = losses["loss"]
-            l.backward()
-            # record learning rate before updating
-            writer.add_scalar("learning_rate",
-                              optim._learning_rate.step().numpy(),
-                              global_step)
-            optim.minimize(l, grad_clip=gradient_clipper)
-            optim.clear_gradients()
-
-            # ==================all kinds of tedious things=================
-            # record step loss into tensorboard
-            epoch_loss += l.numpy()[0]
-            step_loss = {k: v.numpy()[0] for k, v in losses.items()}
-            for k, v in step_loss.items():
-                writer.add_scalar(k, v, global_step)
-
-            # TODO: clean code
-            # train state saving, the first sentence in the batch
-            if global_step % snap_interval == 0:
-                save_state(
-                    state_dir,
-                    writer,
-                    global_step,
-                    mel_input=downsampled_mel_specs,
-                    mel_output=mel_outputs,
-                    lin_input=lin_specs,
-                    lin_output=linear_outputs,
-                    alignments=alignments,
-                    win_length=win_length,
-                    hop_length=hop_length,
-                    min_level_db=min_level_db,
-                    ref_level_db=ref_level_db,
-                    power=power,
-                    n_iter=n_iter,
-                    preemphasis=preemphasis,
-                    sample_rate=sample_rate)
-
-            # evaluation
-            if global_step % eval_interval == 0:
-                sentences = [
-                    "Scientists at the CERN laboratory say they have discovered a new particle.",
-                    "There's a way to measure the acute emotional intelligence that has never gone out of style.",
-                    "President Trump met with other leaders at the Group of 20 conference.",
-                    "Generative adversarial network or variational auto-encoder.",
-                    "Please call Stella.",
-                    "Some have accepted this as a miracle without any physical explanation.",
-                ]
-                for idx, sent in enumerate(sentences):
-                    wav, attn = eval_model(
-                        dv3, sent, replace_pronounciation_prob,
-                        min_level_db, ref_level_db, power, n_iter,
-                        win_length, hop_length, preemphasis)
-                    wav_path = os.path.join(
-                        state_dir, "waveform",
-                        "eval_sample_{:09d}.wav".format(global_step))
-                    sf.write(wav_path, wav, sample_rate)
-                    writer.add_audio(
-                        "eval_sample_{}".format(idx),
-                        wav,
-                        global_step,
-                        sample_rate=sample_rate)
-                    attn_path = os.path.join(
-                        state_dir, "alignments",
-                        "eval_sample_attn_{:09d}.png".format(global_step))
-                    plot_alignment(attn, attn_path)
-                    writer.add_image(
-                        "eval_sample_attn{}".format(idx),
-                        cm.viridis(attn),
-                        global_step,
-                        dataformats="HWC")
-
-            # save checkpoint
-            if global_step % save_interval == 0:
-                dg.save_dygraph(
-                    dv3.state_dict(),
-                    os.path.join(ckpt_dir,
-                                 "model_step_{}".format(global_step)))
-                dg.save_dygraph(
-                    optim.state_dict(),
-                    os.path.join(ckpt_dir,
-                                 "model_step_{}".format(global_step)))
-
-            global_step += 1
-
-        # epoch report
-        writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
-        epoch_loss = 0.
+    global_step = iteration + 1
+    iterator = iter(tqdm.tqdm(loader))
+    while global_step <= max_iter:
+        try:
+            batch = next(iterator)
+        except StopIteration as e:
+            iterator = iter(tqdm.tqdm(loader))
+            batch = next(iterator)
+
+        dv3.train()
+        (text_sequences, text_lengths, text_positions, mel_specs,
+         lin_specs, frames, decoder_positions, done_flags) = batch
+        downsampled_mel_specs = F.strided_slice(
+            mel_specs,
+            axes=[1],
+            starts=[0],
+            ends=[mel_specs.shape[1]],
+            strides=[downsample_factor])
+        mel_outputs, linear_outputs, alignments, done = dv3(
+            text_sequences, text_positions, text_lengths, None,
+            downsampled_mel_specs, decoder_positions)
+
+        losses = criterion(mel_outputs, linear_outputs, done, alignments,
+                           downsampled_mel_specs, lin_specs, done_flags,
+                           text_lengths, frames)
+        l = losses["loss"]
+        l.backward()
+        # record learning rate before updating
+        writer.add_scalar("learning_rate",
+                          optim._learning_rate.step().numpy(), global_step)
+        optim.minimize(l, grad_clip=gradient_clipper)
+        optim.clear_gradients()
+
+        # ==================all kinds of tedious things=================
+        # record step loss into tensorboard
+        step_loss = {k: v.numpy()[0] for k, v in losses.items()}
+        tqdm.tqdm.write("global_step: {}\tloss: {}".format(
+            global_step, step_loss["loss"]))
+        for k, v in step_loss.items():
+            writer.add_scalar(k, v, global_step)
+
+        # train state saving, the first sentence in the batch
+        if global_step % snap_interval == 0:
+            save_state(
+                state_dir,
+                writer,
+                global_step,
+                mel_input=downsampled_mel_specs,
+                mel_output=mel_outputs,
+                lin_input=lin_specs,
+                lin_output=linear_outputs,
+                alignments=alignments,
+                win_length=win_length,
+                hop_length=hop_length,
+                min_level_db=min_level_db,
+                ref_level_db=ref_level_db,
+                power=power,
+                n_iter=n_iter,
+                preemphasis=preemphasis,
+                sample_rate=sample_rate)
+
+        # evaluation
+        if global_step % eval_interval == 0:
+            sentences = [
+                "Scientists at the CERN laboratory say they have discovered a new particle.",
+                "There's a way to measure the acute emotional intelligence that has never gone out of style.",
+                "President Trump met with other leaders at the Group of 20 conference.",
+                "Generative adversarial network or variational auto-encoder.",
+                "Please call Stella.",
+                "Some have accepted this as a miracle without any physical explanation.",
+            ]
+            for idx, sent in enumerate(sentences):
+                wav, attn = eval_model(
+                    dv3, sent, replace_pronounciation_prob, min_level_db,
+                    ref_level_db, power, n_iter, win_length, hop_length,
+                    preemphasis)
+                wav_path = os.path.join(
+                    state_dir, "waveform",
+                    "eval_sample_{:09d}.wav".format(global_step))
+                sf.write(wav_path, wav, sample_rate)
+                writer.add_audio(
+                    "eval_sample_{}".format(idx),
+                    wav,
+                    global_step,
+                    sample_rate=sample_rate)
+                attn_path = os.path.join(
+                    state_dir, "alignments",
+                    "eval_sample_attn_{:09d}.png".format(global_step))
+                plot_alignment(attn, attn_path)
+                writer.add_image(
+                    "eval_sample_attn{}".format(idx),
+                    cm.viridis(attn),
+                    global_step,
+                    dataformats="HWC")
+
+        # save checkpoint
+        if global_step % save_interval == 0:
+            io.save_latest_parameters(ckpt_dir, global_step, dv3, optim)
+            io.save_latest_checkpoint(ckpt_dir, global_step)
+
+        global_step += 1
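One detail worth noting in the loop above: the `F.strided_slice` call keeps every `downsample_factor`-th mel frame so the decoder operates at a reduced frame rate. It is equivalent to plain strided indexing, as this small numpy check shows (shapes are illustrative):

```python
import numpy as np

downsample_factor = 4
mel_specs = np.random.randn(16, 800, 80)  # (batch, frames, mel bins)

# every downsample_factor-th frame along axis 1, matching
# F.strided_slice(mel_specs, axes=[1], starts=[0],
#                 ends=[mel_specs.shape[1]], strides=[downsample_factor])
downsampled = mel_specs[:, ::downsample_factor, :]
print(downsampled.shape)  # (16, 200, 80)
```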
@@ -28,22 +28,22 @@ Train the model using train.py. For help on usage, try `python train.py --help`.
 
 ```text
 usage: train.py [-h] [--data DATA] [--config CONFIG] [--output OUTPUT]
-                [--device DEVICE] [--resume RESUME]
+                [--device DEVICE] [--checkpoint CHECKPOINT]
 
 Train a WaveNet model with LJSpeech.
 
 optional arguments:
   -h, --help            show this help message and exit
   --data DATA           path of the LJspeech dataset.
   --config CONFIG       path of the config file.
   --output OUTPUT       path to save results.
   --device DEVICE       device to use.
-  --resume RESUME       checkpoint to resume from.
+  --checkpoint CHECKPOINT
+                        checkpoint to resume from.
 ```
 
 - `--config` is the configuration file to use. The provided configurations can be used directly. You can also change some values in the configuration file and train the model with a different config.
 - `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
-- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training.
+- `--checkpoint` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training.
 - `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
 
 ```text
@@ -51,6 +51,8 @@ optional arguments:
 └── log              # tensorboard log
 ```
 
+If `checkpoints` is not empty and the argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
+
 - `--device` is the device (gpu id) to use for training. `-1` means CPU.
 
 Example script:
...
@@ -27,7 +27,7 @@ from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
 from parakeet.utils.layer_tools import summary
 
 from data import LJSpeechMetaData, Transform, DataCollector
-from utils import make_output_tree, valid_model, eval_model, save_checkpoint
+from utils import make_output_tree, valid_model, eval_model
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -87,7 +87,8 @@ if __name__ == "__main__":
         batch_size=1,
         sampler=SequentialSampler(ljspeech_valid))
 
-    make_output_tree(args.output)
+    if not os.path.exists(args.output):
+        os.makedirs(args.output)
 
     if args.device == -1:
         place = fluid.CPUPlace()
...
@@ -16,7 +16,7 @@ from __future__ import division
 import os
 import ruamel.yaml
 import argparse
-from tqdm import tqdm
+import tqdm
 from tensorboardX import SummaryWriter
 from paddle import fluid
 import paddle.fluid.dygraph as dg
@@ -24,13 +24,14 @@ import paddle.fluid.dygraph as dg
 from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
 from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
 from parakeet.utils.layer_tools import summary
+from parakeet.utils import io
 
 from data import LJSpeechMetaData, Transform, DataCollector
-from utils import make_output_tree, valid_model, save_checkpoint
+from utils import make_output_tree, valid_model
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Train a wavenet model with LJSpeech.")
+        description="Train a WaveNet model with LJSpeech.")
     parser.add_argument(
         "--data", type=str, help="path of the LJspeech dataset.")
     parser.add_argument("--config", type=str, help="path of the config file.")
@@ -42,12 +43,16 @@ if __name__ == "__main__":
     parser.add_argument(
         "--device", type=int, default=-1, help="device to use.")
     parser.add_argument(
-        "--resume", type=str, help="checkpoint to resume from.")
+        "--checkpoint", type=str, help="checkpoint to resume from.")
 
     args = parser.parse_args()
     with open(args.config, 'rt') as f:
         config = ruamel.yaml.safe_load(f)
 
+    print("Command Line Args: ")
+    for k, v in vars(args).items():
+        print("{}: {}".format(k, v))
+
     ljspeech_meta = LJSpeechMetaData(args.data)
 
     data_config = config["data"]
@@ -126,14 +131,6 @@ if __name__ == "__main__":
     clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
         gradiant_max_norm)
 
-    if args.resume:
-        model_dict, optim_dict = dg.load_dygraph(args.resume)
-        print("Loading from {}.pdparams".format(args.resume))
-        model.set_dict(model_dict)
-        if optim_dict:
-            optim.set_dict(optim_dict)
-            print("Loading from {}.pdopt".format(args.resume))
-
     train_loader = fluid.io.DataLoader.from_generator(
         capacity=10, return_list=True)
     train_loader.set_batch_generator(train_cargo, place)
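The resume block removed here is essentially what `io.load_parameters` now centralizes: restore the parameter dict, and restore the optimizer state only when a matching `.pdopt` file exists. A minimal sketch of such a helper, based on the `dg.load_dygraph` semantics visible in the deleted code (the real `parakeet.utils.io` implementation may differ):

```python
import paddle.fluid.dygraph as dg


def load_parameters_sketch(model, optim, checkpoint_path):
    """Restore model (and, if present, optimizer) state from a checkpoint
    path given without the .pdparams/.pdopt extension."""
    model_dict, optim_dict = dg.load_dygraph(checkpoint_path)
    model.set_dict(model_dict)
    print("Loading from {}.pdparams".format(checkpoint_path))
    if optim_dict:  # None when no .pdopt file was saved alongside
        optim.set_dict(optim_dict)
        print("Loading from {}.pdopt".format(checkpoint_path))
```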
@@ -150,33 +147,48 @@ if __name__ == "__main__":
     log_dir = os.path.join(args.output, "log")
     writer = SummaryWriter(log_dir)
 
-    global_step = 1
+    # load parameters and optimizer, and update the number of iterations done so far
+    io.load_parameters(
+        checkpoint_dir, 0, model, optim, file_path=args.checkpoint)
+    if args.checkpoint is not None:
+        iteration = int(os.path.basename(args.checkpoint).split("-")[-1])
+    else:
+        iteration = io.load_latest_checkpoint(checkpoint_dir)
+
+    global_step = iteration + 1
+    iterator = iter(tqdm.tqdm(train_loader))
     while global_step <= max_iterations:
-        epoch_loss = 0.
-        for i, batch in tqdm(enumerate(train_loader)):
-            audio_clips, mel_specs, audio_starts = batch
-
-            model.train()
-            y_var = model(audio_clips, mel_specs, audio_starts)
-            loss_var = model.loss(y_var, audio_clips)
-            loss_var.backward()
-            loss_np = loss_var.numpy()
-
-            epoch_loss += loss_np[0]
-
-            writer.add_scalar("loss", loss_np[0], global_step)
-            writer.add_scalar("learning_rate",
-                              optim._learning_rate.step().numpy()[0],
-                              global_step)
-            optim.minimize(loss_var, grad_clip=clipper)
-            optim.clear_gradients()
-            print("loss: {:<8.6f}".format(loss_np[0]))
-
-            if global_step % snap_interval == 0:
-                valid_model(model, valid_loader, writer, global_step,
-                            sample_rate)
-            if global_step % checkpoint_interval == 0:
-                save_checkpoint(model, optim, checkpoint_dir, global_step)
-
-            global_step += 1
+        print(global_step)
+        try:
+            batch = next(iterator)
+        except StopIteration as e:
+            iterator = iter(tqdm.tqdm(train_loader))
+            batch = next(iterator)
+
+        audio_clips, mel_specs, audio_starts = batch
+
+        model.train()
+        y_var = model(audio_clips, mel_specs, audio_starts)
+        loss_var = model.loss(y_var, audio_clips)
+        loss_var.backward()
+        loss_np = loss_var.numpy()
+
+        writer.add_scalar("loss", loss_np[0], global_step)
+        writer.add_scalar("learning_rate",
+                          optim._learning_rate.step().numpy()[0],
+                          global_step)
+        optim.minimize(loss_var, grad_clip=clipper)
+        optim.clear_gradients()
+        print("global_step: {}\tloss: {:<8.6f}".format(global_step,
+                                                       loss_np[0]))
+
+        if global_step % snap_interval == 0:
+            valid_model(model, valid_loader, writer, global_step,
+                        sample_rate)
+
+        if global_step % checkpoint_interval == 0:
+            io.save_latest_parameters(checkpoint_dir, global_step, model,
+                                      optim)
+            io.save_latest_checkpoint(checkpoint_dir, global_step)
+
+        global_step += 1
@@ -59,10 +59,3 @@ def eval_model(model, valid_loader, output_dir, sample_rate):
         wav_np = wav_var.numpy()[0]
         sf.write(path, wav_np, samplerate=sample_rate)
         print("generated {}".format(path))
-
-
-def save_checkpoint(model, optim, checkpoint_dir, global_step):
-    checkpoint_path = os.path.join(checkpoint_dir,
-                                   "step_{:09d}".format(global_step))
-    dg.save_dygraph(model.state_dict(), checkpoint_path)
-    dg.save_dygraph(optim.state_dict(), checkpoint_path)
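For reference, the deleted helper relied on `dg.save_dygraph` writing two files from one path prefix: `.pdparams` for a parameter dict and `.pdopt` for an optimizer state. The `io.save_latest_parameters` / `io.save_latest_checkpoint` pair that replaces it presumably does the same plus recording the latest step for auto-resume; a sketch of that combination, with the record-file format an assumption for illustration:

```python
import os
import paddle.fluid.dygraph as dg


def save_checkpoint_with_record(model, optim, checkpoint_dir, global_step):
    """Save model/optimizer states and record the latest step."""
    prefix = os.path.join(checkpoint_dir, "step-{}".format(global_step))
    dg.save_dygraph(model.state_dict(), prefix)  # writes step-<N>.pdparams
    dg.save_dygraph(optim.state_dict(), prefix)  # writes step-<N>.pdopt
    # assumed record format; what matters is that the step can be
    # recovered later with int(name.split("-")[-1])
    with open(os.path.join(checkpoint_dir, "checkpoint"), "wt") as f:
        f.write("step-{}".format(global_step))
```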