# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
import os
import argparse
import ruamel.yaml
import tqdm
from tensorboardX import SummaryWriter
from paddle import fluid
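# fail fast if the installed paddle is older than 1.8.0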
fluid.require_version('1.8.0')
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
from parakeet.utils.io import load_parameters, save_parameters

from data import make_data_loader
from model import make_model, make_criterion, make_optimizer
from utils import (make_output_tree, add_options, get_place, Evaluator,
                   StateSaver, make_evaluator, make_state_saver)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a Deep Voice 3 model with LJSpeech dataset.")
    add_options(parser)
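    # parse_known_args ignores unrecognized flags instead of raising an error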
    args, _ = parser.parse_known_args()

    # args.device is only used when training in a single process; with
    # paddle.distributed.launch, devices are assigned through its
    # `--selected_gpus` option
    env = dg.parallel.ParallelEnv()
    device_id = env.dev_id if env.nranks > 1 else args.device
    place = get_place(device_id)
    # start dygraph
    dg.enable_dygraph(place)

    with open(args.config, 'rt') as f:
        config = ruamel.yaml.safe_load(f)

    print("Command Line Args: ")
    for k, v in vars(args).items():
        print("{}: {}".format(k, v))

    data_loader = make_data_loader(args.data, config)
    model = make_model(config)
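    # when launched with multiple processes (e.g. paddle.distributed.launch),
    # wrap the model for data-parallel training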
    if env.nranks > 1:
        strategy = dg.parallel.prepare_context()
        model = dg.DataParallel(model, strategy)
    criterion = make_criterion(config)
    optim = make_optimizer(model, config)

    # generation
    synthesis_config = config["synthesis"]
    power = synthesis_config["power"]
    n_iter = synthesis_config["n_iter"]

    # tensorboard & checkpoint preparation
    output_dir = args.output
    ckpt_dir = os.path.join(output_dir, "checkpoints")
    log_dir = os.path.join(output_dir, "log")
    state_dir = os.path.join(output_dir, "states")
    eval_dir = os.path.join(output_dir, "eval")
    if env.local_rank == 0:
        make_output_tree(output_dir)
        writer = SummaryWriter(logdir=log_dir)
    else:
        writer = None
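    # fixed sentences synthesized at every evaluation as a qualitative check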
    sentences = [
        "Scientists at the CERN laboratory say they have discovered a new particle.",
        "There's a way to measure the acute emotional intelligence that has never gone out of style.",
        "President Trump met with other leaders at the Group of 20 conference.",
        "Generative adversarial network or variational auto-encoder.",
        "Please call Stella.",
        "Some have accepted this as a miracle without any physical explanation.",
    ]
    evaluator = make_evaluator(config, sentences, eval_dir, writer)
    state_saver = make_state_saver(config, state_dir, writer)

    # load model and optimizer parameters, and recover the number of
    # iterations done so far
    if args.checkpoint is not None:
        iteration = load_parameters(
            model, optim, checkpoint_path=args.checkpoint)
    else:
        iteration = load_parameters(
            model, optim, checkpoint_dir=ckpt_dir, iteration=args.iteration)

    # =========================train=========================
    train_config = config["train"]
    max_iter = train_config["max_iteration"]
    snap_interval = train_config["snap_interval"]
    save_interval = train_config["save_interval"]
    eval_interval = train_config["eval_interval"]

    # resume the step count from the restored checkpoint
    # (iteration is 0 when training from scratch)
    global_step = iteration + 1
    iterator = iter(tqdm.tqdm(data_loader))
    downsample_factor = config["model"]["downsample_factor"]
    while global_step <= max_iter:
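        # the loader is exhausted at the end of each epoch: rebuild the
        # iterator (wrapped in tqdm for a progress bar) and continue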
        try:
            batch = next(iterator)
        except StopIteration:
            iterator = iter(tqdm.tqdm(data_loader))
            batch = next(iterator)

        model.train()
        (text_sequences, text_lengths, text_positions, mel_specs, lin_specs,
         frames, decoder_positions, done_flags) = batch
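        # keep every `downsample_factor`-th mel frame (axis 1 is time), so
        # the decoder is conditioned at its coarser time resolution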
        downsampled_mel_specs = F.strided_slice(
            mel_specs,
            axes=[1],
            starts=[0],
            ends=[mel_specs.shape[1]],
            strides=[downsample_factor])
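        # teacher-forced forward pass: ground-truth downsampled mel frames
        # are fed to the decoder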
        outputs = model(
            text_sequences,
            text_positions,
            text_lengths,
            None,
            downsampled_mel_specs,
            decoder_positions, )
        # mel_outputs, linear_outputs, alignments, done
        inputs = (downsampled_mel_specs, lin_specs, done_flags, text_lengths,
                  frames)
        losses = criterion(outputs, inputs)

        loss = losses["loss"]
        if env.nranks > 1:
            # data-parallel training: scale the loss for gradient averaging,
            # then allreduce gradients across trainers
            loss = model.scale_loss(loss)
            loss.backward()
            model.apply_collective_grads()
        else:
            loss.backward()

        # record learning rate before updating
        if env.local_rank == 0:
            writer.add_scalar("learning_rate",
                              optim._learning_rate.step().numpy(), global_step)
        optim.minimize(loss)
        optim.clear_gradients()

        # record step losses
        step_loss = {k: v.numpy()[0] for k, v in losses.items()}

        if env.local_rank == 0:
            tqdm.tqdm.write("[Train] global_step: {}\tloss: {}".format(
                global_step, step_loss["loss"]))
            for k, v in step_loss.items():
                writer.add_scalar(k, v, global_step)

        # save a training-state snapshot for the first sentence in the batch
        if env.local_rank == 0 and global_step % snap_interval == 0:
            input_specs = (mel_specs, lin_specs)
            state_saver(outputs, input_specs, global_step)

        # evaluation
        if env.local_rank == 0 and global_step % eval_interval == 0:
            evaluator(model, global_step)

        # save checkpoint
        if env.local_rank == 0 and global_step % save_interval == 0:
            save_parameters(ckpt_dir, global_step, model, optim)

        global_step += 1