train_transformer.py 8.5 KB
Newer Older
L
lifuchen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
L
lifuchen 已提交
14 15 16
import os
from tqdm import tqdm
from tensorboardX import SummaryWriter
L
lifuchen 已提交
17
#from pathlib import Path
L
lifuchen 已提交
18
from collections import OrderedDict
L
lifuchen 已提交
19
import argparse
L
lifuchen 已提交
20 21
from parse import add_config_options_to_parser
from pprint import pprint
L
lifuchen 已提交
22
from ruamel import yaml
L
lifuchen 已提交
23
from matplotlib import cm
L
lifuchen 已提交
24 25
import numpy as np
import paddle.fluid as fluid
L
lifuchen 已提交
26 27
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
28
from parakeet.models.transformer_tts.utils import cross_entropy
L
lifuchen 已提交
29 30
from data import LJSpeechLoader
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
L
lifuchen 已提交
31

L
lifuchen 已提交
32

L
lifuchen 已提交
33
def load_checkpoint(step, model_path):
    """Load a model/optimizer checkpoint saved by ``dg.save_dygraph``.

    Args:
        step (str): checkpoint step id; joined with ``model_path`` to form
            the file prefix passed to ``fluid.dygraph.load_dygraph``.
        model_path (str): directory holding the checkpoint files.

    Returns:
        tuple: ``(model_state_dict, optimizer_state_dict)`` where the model
        state dict has any DataParallel ``'_layers.'`` key prefix stripped,
        so the weights can be loaded into a bare (non-parallel) model.
    """
    model_dict, opti_dict = fluid.dygraph.load_dygraph(
        os.path.join(model_path, step))

    # Checkpoints written from a DataParallel wrapper prefix every parameter
    # name with '_layers.'; remove it so model.set_dict() matches.
    prefix = '_layers.'
    new_state_dict = OrderedDict()
    for param, value in model_dict.items():
        if param.startswith(prefix):
            new_state_dict[param[len(prefix):]] = value
        else:
            new_state_dict[param] = value
    return new_state_dict, opti_dict
L
lifuchen 已提交
43

L
lifuchen 已提交
44

L
lifuchen 已提交
45 46 47
def _log_attention_images(writer, tag, global_step, probs):
    """Render attention maps to TensorBoard as colormapped images.

    For each layer tensor in ``probs``, 4 attention maps are sampled at
    rows 0, 16, 32, 48 of the tensor's leading axis and rendered through
    the viridis colormap as HWC uint8 images.
    NOTE(review): the stride of 16 presumably picks one map per head for a
    fixed batch element — confirm against the model's attention layout.
    """
    for layer_idx, prob in enumerate(probs):
        for head_idx in range(4):
            image = np.uint8(cm.viridis(prob.numpy()[head_idx * 16]) * 255)
            writer.add_image(
                '%s_%d_0' % (tag, global_step),
                image,
                layer_idx * 4 + head_idx,
                dataformats="HWC")


def main(args):
    """Train a TransformerTTS model on LJSpeech.

    Reads the YAML config at ``args.config_path``, builds the model and a
    Noam-decay Adam optimizer, optionally restores a checkpoint, and runs
    the training loop for ``args.epochs`` epochs, logging losses/attention
    images to TensorBoard (rank 0 only) and saving checkpoints every
    ``args.save_step`` steps.

    Args:
        args (argparse.Namespace): parsed command-line options (see
            ``add_config_options_to_parser``).
    """
    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1

    # Config file is trusted local input; ruamel's Loader is fine here.
    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else fluid.CUDAPlace(0)
             if args.use_gpu else fluid.CPUPlace())

    # makedirs(exist_ok=True) is idempotent and race-free, unlike the
    # exists()-then-mkdir pattern (which also fails on missing parents).
    os.makedirs(args.log_dir, exist_ok=True)
    path = os.path.join(args.log_dir, 'transformer')

    # Only the rank-0 trainer writes TensorBoard events.
    writer = SummaryWriter(path) if local_rank == 0 else None

    with dg.guard(place):
        model = TransformerTTS(cfg)
        model.train()

        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(1 / (
                cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
            parameter_list=model.parameters())

        reader = LJSpeechLoader(
            cfg, args, nranks, local_rank, shuffle=True).reader()

        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.transformer_step),
                os.path.join(args.checkpoint_path, "transformer"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.transformer_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        for epoch in range(args.epochs):
            pbar = tqdm(reader)
            for data in pbar:
                pbar.set_description('Processing at epoch %d' % epoch)
                character, mel, mel_input, pos_text, pos_mel, text_length, _, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask = data

                global_step += 1

                mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
                    character,
                    mel_input,
                    pos_text,
                    pos_mel,
                    dec_slf_mask=dec_slf_mask,
                    enc_slf_mask=enc_slf_mask,
                    enc_query_mask=enc_query_mask,
                    enc_dec_mask=enc_dec_mask,
                    dec_query_slf_mask=dec_query_slf_mask,
                    dec_query_mask=dec_query_mask)

                # L1 reconstruction losses on both the decoder output and the
                # postnet-refined output.
                mel_loss = layers.mean(
                    layers.abs(layers.elementwise_sub(mel_pred, mel)))
                post_mel_loss = layers.mean(
                    layers.abs(layers.elementwise_sub(postnet_pred, mel)))
                loss = mel_loss + post_mel_loss

                # Note: When used stop token loss the learning did not work.
                if args.stop_token:
                    # Padded frames (pos_mel == 0) are the stop targets.
                    label = (pos_mel == 0).astype(np.float32)
                    stop_loss = cross_entropy(stop_preds, label)
                    loss = loss + stop_loss

                if local_rank == 0:
                    writer.add_scalars('training_loss', {
                        'mel_loss': mel_loss.numpy(),
                        'post_mel_loss': post_mel_loss.numpy()
                    }, global_step)

                    if args.stop_token:
                        writer.add_scalar('stop_loss',
                                          stop_loss.numpy(), global_step)

                    # Under DataParallel the real model sits behind _layers.
                    if args.use_data_parallel:
                        writer.add_scalars('alphas', {
                            'encoder_alpha':
                            model._layers.encoder.alpha.numpy(),
                            'decoder_alpha':
                            model._layers.decoder.alpha.numpy(),
                        }, global_step)
                    else:
                        writer.add_scalars('alphas', {
                            'encoder_alpha': model.encoder.alpha.numpy(),
                            'decoder_alpha': model.decoder.alpha.numpy(),
                        }, global_step)

                    writer.add_scalar('learning_rate',
                                      optimizer._learning_rate.step().numpy(),
                                      global_step)

                    if global_step % args.image_step == 1:
                        _log_attention_images(writer, 'Attention',
                                              global_step, attn_probs)
                        _log_attention_images(writer, 'Attention_enc',
                                              global_step, attn_enc)
                        _log_attention_images(writer, 'Attention_dec',
                                              global_step, attn_dec)

                if args.use_data_parallel:
                    # Scale the loss and all-reduce gradients across trainers.
                    loss = model.scale_loss(loss)
                    loss.backward()
                    model.apply_collective_grads()
                else:
                    loss.backward()
                optimizer.minimize(
                    loss,
                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
                        'grad_clip_thresh']))
                model.clear_gradients()

                # save checkpoint
                if local_rank == 0 and global_step % args.save_step == 0:
                    os.makedirs(args.save_path, exist_ok=True)
                    save_path = os.path.join(args.save_path,
                                             'transformer/%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)
        if local_rank == 0:
            writer.close()

L
lifuchen 已提交
201 202

if __name__ == '__main__':
    # Build the CLI, resolve all options, echo them, then train.
    arg_parser = argparse.ArgumentParser(
        description="Train TransformerTTS model")
    add_config_options_to_parser(arg_parser)
    args = arg_parser.parse_args()
    # Print the whole config setting.
    pprint(args)
    main(args)