# train_transformer.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from collections import OrderedDict
from pprint import pprint

from tqdm import tqdm
from tensorboardX import SummaryWriter
from ruamel import yaml
from matplotlib import cm
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers

from parakeet.models.transformer_tts.utils import cross_entropy
from parakeet.models.transformer_tts import TransformerTTS
from parakeet.utils import io

from data import LJSpeechLoader


def add_config_options_to_parser(parser):
    """Register this training script's command-line options on ``parser``.

    The parser is modified in place; nothing is returned.

    ``--checkpoint`` and ``--iteration`` are registered in a mutually
    exclusive group, so at most one of them may be given on the command line.

    Args:
        parser (argparse.ArgumentParser): the parser to extend.
    """
    add = parser.add_argument
    add("--config", type=str, help="path of the config file")
    add("--use_gpu", type=int, default=0, help="device to use")
    add("--data", type=str, help="path of LJspeech dataset")

    # Two alternative ways of choosing which checkpoint to resume from.
    resume_group = parser.add_mutually_exclusive_group()
    resume_group.add_argument(
        "--checkpoint", type=str, help="checkpoint to resume from")
    resume_group.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    add("--output",
        type=str,
        default="experiment",
        help="path to save experiment results")


def _log_attention_images(writer, attentions, tag_prefix, num_head,
                          batch_size, nranks, global_step):
    """Write one viridis-coloured attention heatmap per (layer, head).

    Args:
        writer (SummaryWriter): TensorBoard writer to log to.
        attentions (list): per-layer attention tensors returned by the model.
        tag_prefix (str): image tag prefix; the tag is
            ``'<tag_prefix>_<global_step>_0'``.
        num_head (int): number of attention heads per layer.
        batch_size (int): ``cfg['train']['batch_size']``.
        nranks (int): number of data-parallel ranks.
        global_step (int): current training step, embedded in the tag.
    """
    tag = '%s_%d_0' % (tag_prefix, global_step)
    for i, prob in enumerate(attentions):
        # Convert once per layer instead of once per head.
        prob = prob.numpy()
        for j in range(num_head):
            # NOTE(review): assumes axis 0 is laid out head-major
            # (num_head x per-rank batch), so this selects the first sample
            # of head j -- confirm against the model's attention layout.
            x = np.uint8(cm.viridis(prob[j * batch_size // nranks]) * 255)
            # Index by i * num_head + j. The previous hard-coded 4 made the
            # indices of different layers collide whenever num_head > 4.
            writer.add_image(tag, x, i * num_head + j, dataformats="HWC")


def main(args):
    """Train a TransformerTTS model on LJSpeech.

    Reads the YAML config at ``args.config`` and builds the model and an Adam
    optimizer with a Noam learning-rate schedule and global-norm gradient
    clipping. It then loads parameters via ``io.load_parameters``, from
    ``args.checkpoint`` or from ``args.iteration`` under
    ``<args.output>/checkpoints``, and continues from the step that call
    returns. Training runs until step ``cfg['train']['max_iteration']``.

    Rank 0 writes TensorBoard logs to ``<args.output>/log`` and saves a
    checkpoint every ``cfg['train']['checkpoint_interval']`` steps. When the
    Paddle environment reports more than one rank, the model is wrapped in
    ``DataParallel``.

    Args:
        args (argparse.Namespace): options parsed by a parser that was set up
            with ``add_config_options_to_parser``.
    """
    local_rank = dg.parallel.Env().local_rank
    nranks = dg.parallel.Env().nranks
    parallel = nranks > 1

    with open(args.config) as f:
        # NOTE: yaml.Loader can construct arbitrary Python objects; only pass
        # config files you trust.
        cfg = yaml.load(f, Loader=yaml.Loader)
    network_cfg = cfg['network']
    train_cfg = cfg['train']

    place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()

    # makedirs(exist_ok=True) tolerates another rank creating the directory
    # first (os.mkdir raised FileExistsError in that race). It also creates
    # any missing parent directories.
    os.makedirs(args.output, exist_ok=True)

    # Only rank 0 logs to TensorBoard; every other rank has writer = None.
    writer = SummaryWriter(os.path.join(args.output,
                                        'log')) if local_rank == 0 else None

    fluid.enable_dygraph(place)
    model = TransformerTTS(
        network_cfg['embedding_size'], network_cfg['hidden_size'],
        network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
        cfg['audio']['num_mels'], network_cfg['outputs_per_step'],
        network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])

    model.train()
    # Per Paddle's NoamDecay formula,
    #   rate = d_model**-0.5 * min(step**-0.5, step * warm_up**-1.5).
    # Choosing d_model = 1 / (warm_up * lr**2) should make the peak rate,
    # reached at step == warm_up, equal cfg['train']['learning_rate'].
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=dg.NoamDecay(1 / (train_cfg['warm_up_step'] *
                                        (train_cfg['learning_rate']**2)),
                                   train_cfg['warm_up_step']),
        parameter_list=model.parameters(),
        grad_clip=fluid.clip.GradientClipByGlobalNorm(train_cfg[
            'grad_clip_thresh']))

    # Load parameters; training resumes after the step returned here.
    global_step = io.load_parameters(
        model=model,
        optimizer=optimizer,
        checkpoint_dir=os.path.join(args.output, 'checkpoints'),
        iteration=args.iteration,
        checkpoint_path=args.checkpoint)
    print("Rank {}: checkpoint loaded.".format(local_rank))

    if parallel:
        strategy = dg.parallel.prepare_context()
        model = fluid.dygraph.parallel.DataParallel(model, strategy)
    # The unwrapped network, used to read per-layer attributes such as alpha.
    core_model = model._layers if parallel else model

    reader = LJSpeechLoader(
        cfg['audio'],
        place,
        args.data,
        train_cfg['batch_size'],
        nranks,
        local_rank,
        shuffle=True).reader

    iterator = iter(tqdm(reader))

    global_step += 1
    try:
        while global_step <= train_cfg['max_iteration']:
            try:
                batch = next(iterator)
            except StopIteration:
                # The reader is exhausted: restart it and keep training.
                iterator = iter(tqdm(reader))
                batch = next(iterator)

            character, mel, mel_input, pos_text, pos_mel = batch

            (mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc,
             attn_dec) = model(character, mel_input, pos_text, pos_mel)

            # Mean absolute error of the decoder and postnet outputs.
            mel_loss = layers.mean(
                layers.abs(layers.elementwise_sub(mel_pred, mel)))
            post_mel_loss = layers.mean(
                layers.abs(layers.elementwise_sub(postnet_pred, mel)))
            loss = mel_loss + post_mel_loss

            # Note: learning did not work when the stop-token loss was used.
            if network_cfg['stop_token']:
                # The target is 1.0 where pos_mel == 0, otherwise 0.0.
                label = (pos_mel == 0).astype(np.float32)
                stop_loss = cross_entropy(stop_preds, label)
                loss = loss + stop_loss

            if local_rank == 0:
                writer.add_scalars('training_loss', {
                    'mel_loss': mel_loss.numpy(),
                    'post_mel_loss': post_mel_loss.numpy()
                }, global_step)

                if network_cfg['stop_token']:
                    writer.add_scalar('stop_loss', stop_loss.numpy(),
                                      global_step)

                writer.add_scalars('alphas', {
                    'encoder_alpha': core_model.encoder.alpha.numpy(),
                    'decoder_alpha': core_model.decoder.alpha.numpy(),
                }, global_step)

                writer.add_scalar('learning_rate',
                                  optimizer._learning_rate.step().numpy(),
                                  global_step)

                # Log attention maps at steps 1, 1 + k, 1 + 2k, ... This
                # matches the old `global_step % k == 1` test for k > 1, and
                # unlike it, also fires when k == 1.
                if (global_step - 1) % train_cfg['image_interval'] == 0:
                    _log_attention_images(
                        writer, attn_probs, 'Attention',
                        network_cfg['decoder_num_head'],
                        train_cfg['batch_size'], nranks, global_step)
                    _log_attention_images(
                        writer, attn_enc, 'Attention_enc',
                        network_cfg['encoder_num_head'],
                        train_cfg['batch_size'], nranks, global_step)
                    _log_attention_images(
                        writer, attn_dec, 'Attention_dec',
                        network_cfg['decoder_num_head'],
                        train_cfg['batch_size'], nranks, global_step)

            if parallel:
                # DataParallel: scale the loss before backward, then apply
                # the collective gradient step across ranks.
                loss = model.scale_loss(loss)
                loss.backward()
                model.apply_collective_grads()
            else:
                loss.backward()
            optimizer.minimize(loss)
            model.clear_gradients()

            # Save a checkpoint (rank 0 only).
            if local_rank == 0 and global_step % train_cfg[
                    'checkpoint_interval'] == 0:
                io.save_parameters(
                    os.path.join(args.output, 'checkpoints'), global_step,
                    model, optimizer)
            global_step += 1
    finally:
        # Close the TensorBoard writer even if training raises.
        if writer is not None:
            writer.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train TransformerTTS model")
    add_config_options_to_parser(parser)
    args = parser.parse_args()
    # Print the parsed command-line arguments (not the YAML config).
    pprint(vars(args))
    main(args)