train_transformer.py 8.0 KB
Newer Older
L
lifuchen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
lifuchen 已提交
14 15
import os
from tqdm import tqdm
走神的阿圆's avatar
走神的阿圆 已提交
16
from visualdl import LogWriter
L
lifuchen 已提交
17
from collections import OrderedDict
L
lifuchen 已提交
18
import argparse
L
lifuchen 已提交
19
from pprint import pprint
L
lifuchen 已提交
20
from ruamel import yaml
L
lifuchen 已提交
21
from matplotlib import cm
L
lifuchen 已提交
22 23
import numpy as np
import paddle.fluid as fluid
L
lifuchen 已提交
24 25
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
26
from parakeet.models.transformer_tts.utils import cross_entropy
L
lifuchen 已提交
27
from data import LJSpeechLoader
28 29
from parakeet.models.transformer_tts import TransformerTTS
from parakeet.utils import io
L
lifuchen 已提交
30

L
lifuchen 已提交
31

32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
def add_config_options_to_parser(parser):
    """Register this training script's command-line options on `parser`.

    Adds ``--config``, ``--use_gpu``, ``--data`` and ``--output``, plus a
    mutually exclusive pair ``--checkpoint`` / ``--iteration`` that selects
    what to resume from.

    Args:
        parser (argparse.ArgumentParser): the parser to extend in place.
    """
    # Plain options, registered in the order they should appear in --help.
    leading_options = (
        ("--config", dict(type=str, help="path of the config file")),
        ("--use_gpu", dict(type=int, default=0, help="device to use")),
        ("--data", dict(type=str, help="path of LJspeech dataset")),
    )
    for flag, spec in leading_options:
        parser.add_argument(flag, **spec)

    # At most one of the two resume sources may be given on the command line.
    resume_group = parser.add_mutually_exclusive_group()
    resume_group.add_argument(
        "--checkpoint", type=str, help="checkpoint to resume from")
    resume_group.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument(
        "--output",
        type=str,
        default="experiment",
        help="path to save experiment results")
L
lifuchen 已提交
49

L
lifuchen 已提交
50

L
lifuchen 已提交
51
def _log_attention_images(writer, tag, attentions, num_heads, batch_size,
                          nranks):
    """Write one colour-mapped attention image per (layer, head) under `tag`.

    Args:
        writer: VisualDL ``LogWriter`` receiving the images.
        tag (str): image tag shared by every image written by this call.
        attentions: iterable of per-layer attention tensors exposing
            ``.numpy()``.
        num_heads (int): number of heads to plot for each layer.
        batch_size (int): ``batch_size`` value from the training config.
        nranks (int): number of parallel training processes.
    """
    for layer_idx, prob in enumerate(attentions):
        # Convert to numpy once per layer instead of once per head.
        prob_np = prob.numpy()
        for head_idx in range(num_heads):
            # NOTE(review): presumably axis 0 is laid out head-major, so this
            # row is the first sample of head `head_idx` on this rank --
            # TODO confirm against TransformerTTS and LJSpeechLoader.
            row = head_idx * batch_size // nranks
            image = np.uint8(cm.viridis(prob_np[row]) * 255)
            # Stride by the real head count so every (layer, head) pair gets
            # its own step; a fixed stride of 4 collides for > 4 heads.
            writer.add_image(tag, image, layer_idx * num_heads + head_idx)


def _run_training(args, cfg, place, writer, local_rank, nranks):
    """Build model, optimizer and data reader, then run the training loop.

    Args:
        args: parsed command-line namespace (reads ``output``, ``iteration``,
            ``checkpoint`` and ``data``).
        cfg (dict): configuration loaded from the YAML config file.
        place: paddle device place to run on.
        writer: VisualDL ``LogWriter`` on rank 0, ``None`` on other ranks.
        local_rank (int): rank of this process.
        nranks (int): total number of training processes.
    """
    parallel = nranks > 1
    train_cfg = cfg['train']
    network_cfg = cfg['network']
    checkpoint_dir = os.path.join(args.output, 'checkpoints')

    fluid.enable_dygraph(place)
    model = TransformerTTS(
        network_cfg['embedding_size'], network_cfg['hidden_size'],
        network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
        cfg['audio']['num_mels'], network_cfg['outputs_per_step'],
        network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])

    model.train()
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=dg.NoamDecay(
            1 / (train_cfg['warm_up_step'] *
                 (train_cfg['learning_rate']**2)),
            train_cfg['warm_up_step']),
        parameter_list=model.parameters(),
        grad_clip=fluid.clip.GradientClipByGlobalNorm(
            train_cfg['grad_clip_thresh']))

    # Load parameters; the returned value is used as the current step count.
    global_step = io.load_parameters(
        model=model,
        optimizer=optimizer,
        checkpoint_dir=checkpoint_dir,
        iteration=args.iteration,
        checkpoint_path=args.checkpoint)
    print("Rank {}: checkpoint loaded.".format(local_rank))

    if parallel:
        strategy = dg.parallel.prepare_context()
        model = fluid.dygraph.parallel.DataParallel(model, strategy)

    # When wrapped in DataParallel the network is reached through `_layers`.
    core_model = model._layers if parallel else model

    reader = LJSpeechLoader(
        cfg['audio'],
        place,
        args.data,
        train_cfg['batch_size'],
        nranks,
        local_rank,
        shuffle=True).reader

    iterator = iter(tqdm(reader))

    # Continue with the step after the one returned by io.load_parameters.
    global_step += 1

    while global_step <= train_cfg['max_iteration']:
        try:
            batch = next(iterator)
        except StopIteration:
            # The reader is exhausted: start another pass over the data.
            iterator = iter(tqdm(reader))
            batch = next(iterator)

        character, mel, mel_input, pos_text, pos_mel, stop_tokens = batch

        (mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc,
         attn_dec) = model(character, mel_input, pos_text, pos_mel)

        # L1 losses on the mel prediction and on the postnet prediction.
        mel_loss = layers.mean(
            layers.abs(layers.elementwise_sub(mel_pred, mel)))
        post_mel_loss = layers.mean(
            layers.abs(layers.elementwise_sub(postnet_pred, mel)))
        loss = mel_loss + post_mel_loss

        stop_loss = cross_entropy(
            stop_preds, stop_tokens, weight=network_cfg['stop_loss_weight'])
        loss = loss + stop_loss

        # Only rank 0 owns a writer, so only rank 0 logs.
        if writer is not None:
            writer.add_scalar('training_loss/mel_loss',
                              mel_loss.numpy(),
                              global_step)
            writer.add_scalar('training_loss/post_mel_loss',
                              post_mel_loss.numpy(),
                              global_step)
            writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)

            writer.add_scalar('alphas/encoder_alpha',
                              core_model.encoder.alpha.numpy(),
                              global_step)
            writer.add_scalar('alphas/decoder_alpha',
                              core_model.decoder.alpha.numpy(),
                              global_step)

            writer.add_scalar('learning_rate',
                              optimizer._learning_rate.step().numpy(),
                              global_step)

            if global_step % train_cfg['image_interval'] == 1:
                image_groups = (
                    ('Attention_%d_0', attn_probs,
                     network_cfg['decoder_num_head']),
                    ('Attention_enc_%d_0', attn_enc,
                     network_cfg['encoder_num_head']),
                    ('Attention_dec_%d_0', attn_dec,
                     network_cfg['decoder_num_head']),
                )
                for tag_fmt, attentions, num_heads in image_groups:
                    _log_attention_images(writer, tag_fmt % global_step,
                                          attentions, num_heads,
                                          train_cfg['batch_size'], nranks)

        if parallel:
            loss = model.scale_loss(loss)
            loss.backward()
            model.apply_collective_grads()
        else:
            loss.backward()
        optimizer.minimize(loss)
        model.clear_gradients()

        # save checkpoint
        if local_rank == 0 and global_step % train_cfg[
                'checkpoint_interval'] == 0:
            io.save_parameters(checkpoint_dir, global_step, model, optimizer)
        global_step += 1


def main(args):
    """Train TransformerTTS with the settings in `args` and its config file.

    Loads the YAML config, creates the output directory, opens a VisualDL
    writer on rank 0, runs training, and always closes the writer afterwards.

    Args:
        args: parsed command-line namespace; reads ``config``, ``use_gpu``,
            ``data``, ``output``, ``checkpoint`` and ``iteration``.
    """
    env = dg.parallel.Env()
    local_rank = env.local_rank
    nranks = env.nranks

    with open(args.config) as f:
        # NOTE(review): yaml.Loader is the full (unsafe) loader; only pass
        # trusted config files via --config.
        cfg = yaml.load(f, Loader=yaml.Loader)

    place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()

    # Every rank runs this, so tolerate the directory appearing concurrently;
    # makedirs also handles a nested --output path.
    try:
        os.makedirs(args.output)
    except OSError:
        if not os.path.isdir(args.output):
            raise

    writer = LogWriter(os.path.join(args.output,
                                    'log')) if local_rank == 0 else None

    try:
        _run_training(args, cfg, place, writer, local_rank, nranks)
    finally:
        # Close the writer even when training raises or is interrupted.
        if writer is not None:
            writer.close()
L
lifuchen 已提交
211

L
lifuchen 已提交
212 213

if __name__ == '__main__':
    # Script entry point: build the CLI, echo what was parsed, then train.
    cli = argparse.ArgumentParser(description="Train TransformerTTS model")
    add_config_options_to_parser(cli)
    args = cli.parse_args()
    # Print the parsed command-line arguments (not the YAML config contents).
    pprint(vars(args))
    main(args)