提交 77674353 编写于 作者: C chenfeiyu

update save & load for deep voicde 3, wavenet and clarinet, remove the concept of epoch in training

上级 64790853
......@@ -28,24 +28,24 @@ Train the model using train.py, follow the usage displayed by `python train.py -
```text
usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT]
[--data DATA] [--resume RESUME] [--wavenet WAVENET]
[--data DATA] [--checkpoint CHECKPOINT] [--wavenet WAVENET]
train a ClariNet model with LJspeech and a trained WaveNet model.
optional arguments:
-h, --help show this help message and exit
--config CONFIG path of the config file.
--device DEVICE device to use.
--output OUTPUT path to save student.
--data DATA path of LJspeech dataset.
--resume RESUME checkpoint to load from.
--wavenet WAVENET wavenet checkpoint to use.
-h, --help show this help message and exit
--config CONFIG path of the config file.
--device DEVICE device to use.
--output OUTPUT path to save student.
--data DATA path of LJspeech dataset.
--checkpoint CHECKPOINT checkpoint to load from.
--wavenet WAVENET wavenet checkpoint to use.
```
- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
- `--checkpoint` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
```text
├── checkpoints # checkpoint
......@@ -53,6 +53,8 @@ optional arguments:
└── log # tensorboard log
```
If `checkpoints` is not empty and argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--resume`, then this must be provided.
......
......@@ -31,7 +31,7 @@ from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
from parakeet.utils.layer_tools import summary, freeze
from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model
from utils import valid_model, eval_model, load_model
sys.path.append("../wavenet")
from data import LJSpeechMetaData, Transform, DataCollector
......
......@@ -30,14 +30,15 @@ from parakeet.models.wavenet import WaveNet, UpsampleNet
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
from parakeet.utils.layer_tools import summary, freeze
from parakeet.utils import io
from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet
from utils import make_output_tree, valid_model, load_wavenet
sys.path.append("../wavenet")
from data import LJSpeechMetaData, Transform, DataCollector
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="train a clarinet model with LJspeech and a trained wavenet model."
description="train a ClariNet model with LJspeech and a trained WaveNet model."
)
parser.add_argument("--config", type=str, help="path of the config file.")
parser.add_argument(
......@@ -48,13 +49,18 @@ if __name__ == "__main__":
default="experiment",
help="path to save student.")
parser.add_argument("--data", type=str, help="path of LJspeech dataset.")
parser.add_argument("--resume", type=str, help="checkpoint to load from.")
parser.add_argument(
"--checkpoint", type=str, help="checkpoint to load from.")
parser.add_argument(
"--wavenet", type=str, help="wavenet checkpoint to use.")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
ljspeech_meta = LJSpeechMetaData(args.data)
data_config = config["data"]
......@@ -154,12 +160,38 @@ if __name__ == "__main__":
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
gradiant_max_norm)
assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
if args.wavenet:
# train
max_iterations = train_config["max_iterations"]
checkpoint_interval = train_config["checkpoint_interval"]
eval_interval = train_config["eval_interval"]
checkpoint_dir = os.path.join(args.output, "checkpoints")
state_dir = os.path.join(args.output, "states")
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
# load wavenet/checkpoint, determine iterations done
if args.checkpoint is not None:
iteration = int(os.path.basename(args.checkpoint).split('-')[-1])
else:
iteration = io.load_latest_checkpoint(checkpoint_dir)
if iteration == 0 and args.wavenet is None:
raise Exception(
"you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
)
if args.wavenet is not None and iteration > 0:
if args.checkpoint is None:
print("Resume training, --wavenet ignored")
else:
print("--checkpoint provided, --wavenet ignored")
if args.wavenet is not None and iteration == 0:
load_wavenet(model, args.wavenet)
if args.resume:
load_checkpoint(model, optim, args.resume)
# it may overwrite the wavenet loaded
io.load_parameters(
checkpoint_dir, 0, model, optim, file_path=args.checkpoint)
# loader
train_loader = fluid.io.DataLoader.from_generator(
......@@ -170,52 +202,43 @@ if __name__ == "__main__":
capacity=10, return_list=True)
valid_loader.set_batch_generator(valid_cargo, place)
# train
max_iterations = train_config["max_iterations"]
checkpoint_interval = train_config["checkpoint_interval"]
eval_interval = train_config["eval_interval"]
checkpoint_dir = os.path.join(args.output, "checkpoints")
state_dir = os.path.join(args.output, "states")
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
# training loop
global_step = 1
global_epoch = 1
global_step = iteration + 1
iterator = iter(tqdm(train_loader))
while global_step < max_iterations:
epoch_loss = 0.
for j, batch in tqdm(enumerate(train_loader), desc="[train]"):
audios, mels, audio_starts = batch
model.train()
loss_dict = model(
audios, mels, audio_starts, clip_kl=global_step > 500)
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy()[0],
global_step)
for k, v in loss_dict.items():
writer.add_scalar("loss/{}".format(k),
v.numpy()[0], global_step)
l = loss_dict["loss"]
step_loss = l.numpy()[0]
print("[train] loss: {:<8.6f}".format(step_loss))
epoch_loss += step_loss
l.backward()
optim.minimize(l, grad_clip=clipper)
optim.clear_gradients()
if global_step % eval_interval == 0:
# evaluate on valid dataset
valid_model(model, valid_loader, state_dir, global_step,
sample_rate)
if global_step % checkpoint_interval == 0:
save_checkpoint(model, optim, checkpoint_dir, global_step)
global_step += 1
# epoch loss
average_loss = epoch_loss / j
writer.add_scalar("average_loss", average_loss, global_epoch)
global_epoch += 1
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm(train_loader))
batch = next(iterator)
audios, mels, audio_starts = batch
model.train()
loss_dict = model(
audios, mels, audio_starts, clip_kl=global_step > 500)
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy()[0],
global_step)
for k, v in loss_dict.items():
writer.add_scalar("loss/{}".format(k),
v.numpy()[0], global_step)
l = loss_dict["loss"]
step_loss = l.numpy()[0]
print("[train] loss: {:<8.6f}".format(step_loss))
l.backward()
optim.minimize(l, grad_clip=clipper)
optim.clear_gradients()
if global_step % eval_interval == 0:
# evaluate on valid dataset
valid_model(model, valid_loader, state_dir, global_step,
sample_rate)
if global_step % checkpoint_interval == 0:
io.save_latest_parameters(checkpoint_dir, global_step, model,
optim)
io.save_latest_checkpoint(checkpoint_dir, global_step)
global_step += 1
......@@ -35,26 +35,23 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed
Train the model using train.py, follow the usage displayed by `python train.py --help`.
```text
usage: train.py [-h] [-c CONFIG] [-s DATA] [-r RESUME] [-o OUTPUT] [-g DEVICE]
usage: train.py [-h] [-c CONFIG] [-s DATA] [--checkpoint CHECKPOINT]
[-o OUTPUT] [-g DEVICE]
Train a Deep Voice 3 model with LJSpeech dataset.
optional arguments:
-h, --help show this help message and exit
-c CONFIG, --config CONFIG
experimrnt config
-s DATA, --data DATA The path of the LJSpeech dataset.
-r RESUME, --resume RESUME
checkpoint to load
-o OUTPUT, --output OUTPUT
The directory to save result.
-g DEVICE, --device DEVICE
device to use
-h, --help show this help message and exit
-c CONFIG, --config CONFIG experimrnt config
-s DATA, --data DATA The path of the LJSpeech dataset.
--checkpoint CHECKPOINT checkpoint to load
-o OUTPUT, --output OUTPUT The directory to save result.
-g DEVICE, --device DEVICE device to use
```
- `--config` is the configuration file to use. The provided `ljspeech.yaml` can be used directly. And you can change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
- `--checkpoint` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
- `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
```text
......@@ -67,6 +64,8 @@ optional arguments:
└── waveform # waveform (.wav files)
```
If `checkpoints` is not empty and argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
Example script:
......
......@@ -83,7 +83,7 @@ lr_scheduler:
train:
batch_size: 16
epochs: 2000
max_iteration: 2000000
snap_interval: 1000
eval_interval: 10000
......
......@@ -25,8 +25,9 @@ import paddle.fluid.dygraph as dg
from tensorboardX import SummaryWriter
from parakeet.g2p import en
from parakeet.utils.layer_tools import summary
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.utils.layer_tools import summary
from parakeet.utils.io import load_parameters
from utils import make_model, eval_model, plot_alignment
......@@ -44,6 +45,10 @@ if __name__ == "__main__":
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line Args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
if args.device == -1:
place = fluid.CPUPlace()
else:
......
......@@ -17,6 +17,8 @@ import os
import argparse
import ruamel.yaml
import numpy as np
import matplotlib
matplotlib.use("agg")
from matplotlib import cm
import matplotlib.pyplot as plt
import tqdm
......@@ -35,13 +37,14 @@ from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler,
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
from parakeet.models.deepvoice3.loss import TTSLoss
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from data import LJSpeechMetaData, DataCollector, Transform
from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a deepvoice 3 model with LJSpeech dataset.")
description="Train a Deep Voice 3 model with LJSpeech dataset.")
parser.add_argument("-c", "--config", type=str, help="experimrnt config")
parser.add_argument(
"-s",
......@@ -49,7 +52,7 @@ if __name__ == "__main__":
type=str,
default="/workspace/datasets/LJSpeech-1.1/",
help="The path of the LJSpeech dataset.")
parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
parser.add_argument("--checkpoint", type=str, help="checkpoint to load")
parser.add_argument(
"-o",
"--output",
......@@ -62,6 +65,10 @@ if __name__ == "__main__":
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line Args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
# =========================dataset=========================
# construct meta data
data_root = args.data
......@@ -151,6 +158,7 @@ if __name__ == "__main__":
query_position_rate, key_position_rate, window_backward,
window_ahead, key_projection, value_projection, downsample_factor,
linear_dim, use_decoder_states, converter_channels, dropout)
summary(dv3)
# =========================loss=========================
loss_config = config["loss"]
......@@ -195,7 +203,6 @@ if __name__ == "__main__":
n_iter = synthesis_config["n_iter"]
# =========================link(dataloader, paddle)=========================
# CAUTION: it does not return a DataLoader
loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
loader.set_batch_generator(ljspeech_loader, places=place)
......@@ -208,122 +215,117 @@ if __name__ == "__main__":
make_output_tree(output_dir)
writer = SummaryWriter(logdir=log_dir)
# load model parameters
resume_path = args.resume
if resume_path is not None:
state, _ = dg.load_dygraph(args.resume)
dv3.set_dict(state)
# load parameters and optimizer, and opdate iterations done sofar
io.load_parameters(ckpt_dir, 0, dv3, optim, file_path=args.checkpoint)
if args.checkpoint is not None:
iteration = int(os.path.basename(args.checkpoint).split("-")[-1])
else:
iteration = io.load_latest_checkpoint(ckpt_dir)
# =========================train=========================
epoch = train_config["epochs"]
max_iter = train_config["max_iteration"]
snap_interval = train_config["snap_interval"]
save_interval = train_config["save_interval"]
eval_interval = train_config["eval_interval"]
global_step = 1
global_step = iteration + 1
iterator = iter(tqdm.tqdm(loader))
while global_step <= max_iter:
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm.tqdm(loader))
batch = next(iterator)
dv3.train()
(text_sequences, text_lengths, text_positions, mel_specs,
lin_specs, frames, decoder_positions, done_flags) = batch
downsampled_mel_specs = F.strided_slice(
mel_specs,
axes=[1],
starts=[0],
ends=[mel_specs.shape[1]],
strides=[downsample_factor])
mel_outputs, linear_outputs, alignments, done = dv3(
text_sequences, text_positions, text_lengths, None,
downsampled_mel_specs, decoder_positions)
for j in range(1, 1 + epoch):
epoch_loss = 0.
for i, batch in tqdm.tqdm(enumerate(loader, 1)):
dv3.train() # CAUTION: don't forget to switch to train
(text_sequences, text_lengths, text_positions, mel_specs,
lin_specs, frames, decoder_positions, done_flags) = batch
downsampled_mel_specs = F.strided_slice(
mel_specs,
axes=[1],
starts=[0],
ends=[mel_specs.shape[1]],
strides=[downsample_factor])
mel_outputs, linear_outputs, alignments, done = dv3(
text_sequences, text_positions, text_lengths, None,
downsampled_mel_specs, decoder_positions)
losses = criterion(mel_outputs, linear_outputs, done, alignments,
downsampled_mel_specs, lin_specs, done_flags,
text_lengths, frames)
l = losses["loss"]
l.backward()
# record learning rate before updating
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy(), global_step)
optim.minimize(l, grad_clip=gradient_clipper)
optim.clear_gradients()
losses = criterion(mel_outputs, linear_outputs, done,
alignments, downsampled_mel_specs,
lin_specs, done_flags, text_lengths, frames)
l = losses["loss"]
l.backward()
# record learning rate before updating
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy(),
global_step)
optim.minimize(l, grad_clip=gradient_clipper)
optim.clear_gradients()
# ==================all kinds of tedious things=================
# record step loss into tensorboard
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
tqdm.tqdm.write("global_step: {}\tloss: {}".format(
global_step, step_loss["loss"]))
for k, v in step_loss.items():
writer.add_scalar(k, v, global_step)
# ==================all kinds of tedious things=================
# record step loss into tensorboard
epoch_loss += l.numpy()[0]
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
for k, v in step_loss.items():
writer.add_scalar(k, v, global_step)
# train state saving, the first sentence in the batch
if global_step % snap_interval == 0:
save_state(
state_dir,
writer,
global_step,
mel_input=downsampled_mel_specs,
mel_output=mel_outputs,
lin_input=lin_specs,
lin_output=linear_outputs,
alignments=alignments,
win_length=win_length,
hop_length=hop_length,
min_level_db=min_level_db,
ref_level_db=ref_level_db,
power=power,
n_iter=n_iter,
preemphasis=preemphasis,
sample_rate=sample_rate)
# TODO: clean code
# train state saving, the first sentence in the batch
if global_step % snap_interval == 0:
save_state(
state_dir,
writer,
# evaluation
if global_step % eval_interval == 0:
sentences = [
"Scientists at the CERN laboratory say they have discovered a new particle.",
"There's a way to measure the acute emotional intelligence that has never gone out of style.",
"President Trump met with other leaders at the Group of 20 conference.",
"Generative adversarial network or variational auto-encoder.",
"Please call Stella.",
"Some have accepted this as a miracle without any physical explanation.",
]
for idx, sent in enumerate(sentences):
wav, attn = eval_model(
dv3, sent, replace_pronounciation_prob, min_level_db,
ref_level_db, power, n_iter, win_length, hop_length,
preemphasis)
wav_path = os.path.join(
state_dir, "waveform",
"eval_sample_{:09d}.wav".format(global_step))
sf.write(wav_path, wav, sample_rate)
writer.add_audio(
"eval_sample_{}".format(idx),
wav,
global_step,
mel_input=downsampled_mel_specs,
mel_output=mel_outputs,
lin_input=lin_specs,
lin_output=linear_outputs,
alignments=alignments,
win_length=win_length,
hop_length=hop_length,
min_level_db=min_level_db,
ref_level_db=ref_level_db,
power=power,
n_iter=n_iter,
preemphasis=preemphasis,
sample_rate=sample_rate)
attn_path = os.path.join(
state_dir, "alignments",
"eval_sample_attn_{:09d}.png".format(global_step))
plot_alignment(attn, attn_path)
writer.add_image(
"eval_sample_attn{}".format(idx),
cm.viridis(attn),
global_step,
dataformats="HWC")
# evaluation
if global_step % eval_interval == 0:
sentences = [
"Scientists at the CERN laboratory say they have discovered a new particle.",
"There's a way to measure the acute emotional intelligence that has never gone out of style.",
"President Trump met with other leaders at the Group of 20 conference.",
"Generative adversarial network or variational auto-encoder.",
"Please call Stella.",
"Some have accepted this as a miracle without any physical explanation.",
]
for idx, sent in enumerate(sentences):
wav, attn = eval_model(
dv3, sent, replace_pronounciation_prob,
min_level_db, ref_level_db, power, n_iter,
win_length, hop_length, preemphasis)
wav_path = os.path.join(
state_dir, "waveform",
"eval_sample_{:09d}.wav".format(global_step))
sf.write(wav_path, wav, sample_rate)
writer.add_audio(
"eval_sample_{}".format(idx),
wav,
global_step,
sample_rate=sample_rate)
attn_path = os.path.join(
state_dir, "alignments",
"eval_sample_attn_{:09d}.png".format(global_step))
plot_alignment(attn, attn_path)
writer.add_image(
"eval_sample_attn{}".format(idx),
cm.viridis(attn),
global_step,
dataformats="HWC")
# save checkpoint
if global_step % save_interval == 0:
dg.save_dygraph(
dv3.state_dict(),
os.path.join(ckpt_dir,
"model_step_{}".format(global_step)))
dg.save_dygraph(
optim.state_dict(),
os.path.join(ckpt_dir,
"model_step_{}".format(global_step)))
# save checkpoint
if global_step % save_interval == 0:
io.save_latest_parameters(ckpt_dir, global_step, dv3, optim)
io.save_latest_checkpoint(ckpt_dir, global_step)
global_step += 1
# epoch report
writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
epoch_loss = 0.
global_step += 1
......@@ -28,22 +28,22 @@ Train the model using train.py. For help on usage, try `python train.py --help`.
```text
usage: train.py [-h] [--data DATA] [--config CONFIG] [--output OUTPUT]
[--device DEVICE] [--resume RESUME]
[--device DEVICE] [--checkpoint CHECKPOINT]
Train a WaveNet model with LJSpeech.
optional arguments:
-h, --help show this help message and exit
--data DATA path of the LJspeech dataset.
--config CONFIG path of the config file.
--output OUTPUT path to save results.
--device DEVICE device to use.
--resume RESUME checkpoint to resume from.
-h, --help show this help message and exit
--data DATA path of the LJspeech dataset.
--config CONFIG path of the config file.
--output OUTPUT path to save results.
--device DEVICE device to use.
--checkpoint CHECKPOINT checkpoint to resume from.
```
- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training.
- `--checkpoint` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training.
- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
```text
......@@ -51,6 +51,8 @@ optional arguments:
└── log # tensorboard log
```
If `checkpoints` is not empty and argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
Example script:
......
......@@ -27,7 +27,7 @@ from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, eval_model, save_checkpoint
from utils import make_output_tree, valid_model, eval_model
if __name__ == "__main__":
parser = argparse.ArgumentParser(
......@@ -87,7 +87,8 @@ if __name__ == "__main__":
batch_size=1,
sampler=SequentialSampler(ljspeech_valid))
make_output_tree(args.output)
if not os.path.exists(args.output):
os.makedirs(args.output)
if args.device == -1:
place = fluid.CPUPlace()
......
......@@ -16,7 +16,7 @@ from __future__ import division
import os
import ruamel.yaml
import argparse
from tqdm import tqdm
import tqdm
from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.dygraph as dg
......@@ -24,13 +24,14 @@ import paddle.fluid.dygraph as dg
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, save_checkpoint
from utils import make_output_tree, valid_model
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a wavenet model with LJSpeech.")
description="Train a WaveNet model with LJSpeech.")
parser.add_argument(
"--data", type=str, help="path of the LJspeech dataset.")
parser.add_argument("--config", type=str, help="path of the config file.")
......@@ -42,12 +43,16 @@ if __name__ == "__main__":
parser.add_argument(
"--device", type=int, default=-1, help="device to use.")
parser.add_argument(
"--resume", type=str, help="checkpoint to resume from.")
"--checkpoint", type=str, help="checkpoint to resume from.")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line Args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
ljspeech_meta = LJSpeechMetaData(args.data)
data_config = config["data"]
......@@ -126,14 +131,6 @@ if __name__ == "__main__":
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
gradiant_max_norm)
if args.resume:
model_dict, optim_dict = dg.load_dygraph(args.resume)
print("Loading from {}.pdparams".format(args.resume))
model.set_dict(model_dict)
if optim_dict:
optim.set_dict(optim_dict)
print("Loading from {}.pdopt".format(args.resume))
train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place)
......@@ -150,33 +147,48 @@ if __name__ == "__main__":
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
global_step = 1
# load parameters and optimizer, and opdate iterations done sofar
io.load_parameters(
checkpoint_dir, 0, model, optim, file_path=args.checkpoint)
if args.checkpoint is not None:
iteration = int(os.path.basename(args.checkpoint).split("-")[-1])
else:
iteration = io.load_latest_checkpoint(checkpoint_dir)
global_step = iteration + 1
iterator = iter(tqdm.tqdm(train_loader))
while global_step <= max_iterations:
epoch_loss = 0.
for i, batch in tqdm(enumerate(train_loader)):
audio_clips, mel_specs, audio_starts = batch
model.train()
y_var = model(audio_clips, mel_specs, audio_starts)
loss_var = model.loss(y_var, audio_clips)
loss_var.backward()
loss_np = loss_var.numpy()
epoch_loss += loss_np[0]
writer.add_scalar("loss", loss_np[0], global_step)
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy()[0],
global_step)
optim.minimize(loss_var, grad_clip=clipper)
optim.clear_gradients()
print("loss: {:<8.6f}".format(loss_np[0]))
if global_step % snap_interval == 0:
valid_model(model, valid_loader, writer, global_step,
sample_rate)
if global_step % checkpoint_interval == 0:
save_checkpoint(model, optim, checkpoint_dir, global_step)
global_step += 1
print(global_step)
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm.tqdm(train_loader))
batch = next(iterator)
audio_clips, mel_specs, audio_starts = batch
model.train()
y_var = model(audio_clips, mel_specs, audio_starts)
loss_var = model.loss(y_var, audio_clips)
loss_var.backward()
loss_np = loss_var.numpy()
writer.add_scalar("loss", loss_np[0], global_step)
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy()[0],
global_step)
optim.minimize(loss_var, grad_clip=clipper)
optim.clear_gradients()
print("global_step: {}\tloss: {:<8.6f}".format(global_step,
loss_np[0]))
if global_step % snap_interval == 0:
valid_model(model, valid_loader, writer, global_step,
sample_rate)
if global_step % checkpoint_interval == 0:
io.save_latest_parameters(checkpoint_dir, global_step, model,
optim)
io.save_latest_checkpoint(checkpoint_dir, global_step)
global_step += 1
......@@ -59,10 +59,3 @@ def eval_model(model, valid_loader, output_dir, sample_rate):
wav_np = wav_var.numpy()[0]
sf.write(path, wav_np, samplerate=sample_rate)
print("generated {}".format(path))
def save_checkpoint(model, optim, checkpoint_dir, global_step):
checkpoint_path = os.path.join(checkpoint_dir,
"step_{:09d}".format(global_step))
dg.save_dygraph(model.state_dict(), checkpoint_path)
dg.save_dygraph(optim.state_dict(), checkpoint_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册