# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
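
# Example invocation (flag names inferred from their uses below and defined
# in parse.py via add_config_options_to_parser; the values are illustrative
# only):
#
#   python train_transformer.py --config_path=./config/train_transformer.yaml \
#       --use_gpu=1 --log_dir=./log --save_path=./checkpoint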
import os
from tqdm import tqdm
from tensorboardX import SummaryWriter
from collections import OrderedDict
import argparse
from parse import add_config_options_to_parser
from pprint import pprint
from ruamel import yaml
from matplotlib import cm
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
from parakeet.models.transformer_tts.utils import cross_entropy
from data import LJSpeechLoader
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS


def load_checkpoint(step, model_path):
    model_dict, opti_dict = fluid.dygraph.load_dygraph(
        os.path.join(model_path, step))
    # Checkpoints saved from a DataParallel-wrapped model prefix every
    # parameter name with '_layers.'; strip the prefix so the weights also
    # load into an unwrapped model.
    new_state_dict = OrderedDict()
    for param in model_dict:
        if param.startswith('_layers.'):
            new_state_dict[param[8:]] = model_dict[param]
        else:
            new_state_dict[param] = model_dict[param]
    return new_state_dict, opti_dict


def main(args):
    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1
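    # Rank 0 alone writes logs and checkpoints; nranks is the number of
    # trainer processes under data parallelism.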

    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    # Run on this process's GPU under data parallelism, on GPU 0 for
    # single-card GPU training, or on CPU.
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else fluid.CUDAPlace(0)
             if args.use_gpu else fluid.CPUPlace())

    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)
    path = os.path.join(args.log_dir, 'transformer')

    # Only rank 0 writes TensorBoard summaries.
    writer = SummaryWriter(path) if local_rank == 0 else None

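    # Everything below runs in dygraph (imperative) mode on `place`.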
    with dg.guard(place):
        model = TransformerTTS(cfg)

        model.train()
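        # Noam LR schedule: with d_model = 1 / (warm_up_step * lr**2), the
        # learning rate grows linearly for warm_up_step steps, peaks at
        # args.lr, then decays as 1 / sqrt(step).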
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(1 / (
                cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
            parameter_list=model.parameters())

        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.transformer_step),
                os.path.join(args.checkpoint_path, "transformer"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.transformer_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

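        # The loader shards LJSpeech across `nranks` trainers, so each
        # process reads only the shard belonging to its local_rank.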
        reader = LJSpeechLoader(
            cfg, args, nranks, local_rank, shuffle=True).reader()

        for epoch in range(args.epochs):
            pbar = tqdm(reader)
            for data in pbar:
                pbar.set_description('Processing at epoch %d' % epoch)
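                # One batch: text ids, target mel, shifted decoder input,
                # text/mel position ids, text lengths, and the precomputed
                # attention masks.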
                character, mel, mel_input, pos_text, pos_mel, text_length, _, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask = data

                global_step += 1

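                # Forward pass returns decoder and postnet mel predictions,
                # stop-token logits, and the attention weights used for
                # visualization below.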
                mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
                    character,
                    mel_input,
                    pos_text,
                    pos_mel,
                    dec_slf_mask=dec_slf_mask,
                    enc_slf_mask=enc_slf_mask,
                    enc_query_mask=enc_query_mask,
                    enc_dec_mask=enc_dec_mask,
                    dec_query_slf_mask=dec_query_slf_mask,
                    dec_query_mask=dec_query_mask)

                # L1 losses against the ground-truth mel spectrogram, before
                # and after the postnet.
                mel_loss = layers.mean(
                    layers.abs(layers.elementwise_sub(mel_pred, mel)))
                post_mel_loss = layers.mean(
                    layers.abs(layers.elementwise_sub(postnet_pred, mel)))
                loss = mel_loss + post_mel_loss

                # Note: training did not converge when the stop-token loss
                # was enabled.
                if args.stop_token:
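                    # pos_mel == 0 marks positions past the end of the
                    # utterance; these serve as the stop-token targets.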
                    label = (pos_mel == 0).astype(np.float32)
                    stop_loss = cross_entropy(stop_preds, label)
                    loss = loss + stop_loss

                if local_rank == 0:
                    writer.add_scalars('training_loss', {
                        'mel_loss': mel_loss.numpy(),
                        'post_mel_loss': post_mel_loss.numpy()
                    }, global_step)

                    if args.stop_token:
                        writer.add_scalar('stop_loss',
                                          stop_loss.numpy(), global_step)

                    if args.use_data_parallel:
                        writer.add_scalars('alphas', {
                            'encoder_alpha':
                            model._layers.encoder.alpha.numpy(),
                            'decoder_alpha':
                            model._layers.decoder.alpha.numpy(),
                        }, global_step)
                    else:
                        writer.add_scalars('alphas', {
                            'encoder_alpha': model.encoder.alpha.numpy(),
                            'decoder_alpha': model.decoder.alpha.numpy(),
                        }, global_step)

                    writer.add_scalar('learning_rate',
                                      optimizer._learning_rate.step().numpy(),
                                      global_step)

                    if global_step % args.image_step == 1:
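                        # Dump attention maps as viridis heatmaps: four
                        # slices per layer of the stacked batch/head axis.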
                        for i, prob in enumerate(attn_probs):
                            for j in range(4):
                                x = np.uint8(
                                    cm.viridis(prob.numpy()[j * args.batch_size
                                                            // 2]) * 255)
                                writer.add_image(
                                    'Attention_%d_0' % global_step,
                                    x,
                                    i * 4 + j,
                                    dataformats="HWC")

                        for i, prob in enumerate(attn_enc):
                            for j in range(4):
                                x = np.uint8(
                                    cm.viridis(prob.numpy()[j * args.batch_size
                                                            // 2]) * 255)
                                writer.add_image(
                                    'Attention_enc_%d_0' % global_step,
                                    x,
                                    i * 4 + j,
                                    dataformats="HWC")

                        for i, prob in enumerate(attn_dec):
                            for j in range(4):
                                x = np.uint8(
                                    cm.viridis(prob.numpy()[j * args.batch_size
                                                            // 2]) * 255)
                                writer.add_image(
                                    'Attention_dec_%d_0' % global_step,
                                    x,
                                    i * 4 + j,
                                    dataformats="HWC")

                if args.use_data_parallel:
                    loss = model.scale_loss(loss)
                    loss.backward()
                    model.apply_collective_grads()
                else:
                    loss.backward()
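                # In this fluid version, gradient clipping is passed to
                # minimize() rather than set on the optimizer.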
                optimizer.minimize(
                    loss,
                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
                        'grad_clip_thresh']))
                model.clear_gradients()

                # save checkpoint
                if local_rank == 0 and global_step % args.save_step == 0:
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    save_path = os.path.join(args.save_path,
                                             'transformer/%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)
        if local_rank == 0:
            writer.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train TransformerTTS model")
    add_config_options_to_parser(parser)

    args = parser.parse_args()
    # Print the whole config setting.
    pprint(args)
    main(args)