# train_transformer.py — TransformerTTS training script (Parakeet).
import os
from tqdm import tqdm
from tensorboardX import SummaryWriter
from pathlib import Path
from collections import OrderedDict
import argparse
from parse import add_config_options_to_parser
from pprint import pprint
from ruamel import yaml
from matplotlib import cm
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
from parakeet.modules.utils import cross_entropy
from data import LJSpeechLoader
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
def load_checkpoint(step, model_path):
    """Restore a dygraph checkpoint saved under ``model_path``.

    Loads the parameter and optimizer dicts for checkpoint ``step`` and
    strips the ``_layers.`` prefix that ``DataParallel`` adds to parameter
    names, so the result can be fed to a plain (non-parallel) model.

    Args:
        step: checkpoint step identifier (string), joined onto model_path.
        model_path: directory holding the saved checkpoints.

    Returns:
        (state_dict, optimizer_dict) with DataParallel prefixes removed
        from the parameter names.
    """
    checkpoint = os.path.join(model_path, step)
    model_dict, opti_dict = fluid.dygraph.load_dygraph(checkpoint)

    prefix = '_layers.'
    cleaned = OrderedDict()
    for key, value in model_dict.items():
        # Drop the DataParallel wrapper prefix when present.
        plain_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[plain_key] = value
    return cleaned, opti_dict


def _add_attention_images(writer, tag_pattern, probs, global_step):
    """Log attention maps for one attention stack as colorized images.

    For each layer tensor in ``probs``, renders 4 maps taken at rows
    0, 16, 32, 48 of ``prob.numpy()`` (assumes heads are laid out with
    stride 16 — TODO confirm against the model's attention layout).
    ``tag_pattern`` must contain one ``%d`` slot for the global step.
    """
    for layer_idx, prob in enumerate(probs):
        for head in range(4):
            img = np.uint8(cm.viridis(prob.numpy()[head * 16]) * 255)
            writer.add_image(tag_pattern % global_step, img,
                             layer_idx * 4 + head, dataformats="HWC")


def main(args):
    """Train a TransformerTTS model.

    Reads the yaml config at ``args.config_path``, builds the model and
    Noam-decay Adam optimizer, optionally restores a checkpoint, then runs
    the training loop with TensorBoard logging and periodic checkpointing
    (both on local rank 0 only). Supports single-device and data-parallel
    training.
    """
    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1

    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else fluid.CUDAPlace(0)
             if args.use_gpu else fluid.CPUPlace())

    # makedirs(exist_ok=True) instead of bare mkdir: safe when several
    # data-parallel ranks race to create the directory, and when it
    # already exists from a previous run.
    os.makedirs(args.log_dir, exist_ok=True)
    path = os.path.join(args.log_dir, 'transformer')

    writer = SummaryWriter(path) if local_rank == 0 else None

    with dg.guard(place):
        model = TransformerTTS(cfg)
        model.train()
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(
                1 / (cfg['warm_up_step'] * (args.lr ** 2)),
                cfg['warm_up_step']),
            parameter_list=model.parameters())

        # Loop-invariant: build the gradient-clip strategy once instead of
        # re-constructing it on every optimizer step.
        grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
            cfg['grad_clip_thresh'])

        reader = LJSpeechLoader(cfg, args, nranks, local_rank,
                                shuffle=True).reader()

        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.transformer_step),
                os.path.join(args.checkpoint_path, "transformer"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.transformer_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        for epoch in range(args.epochs):
            pbar = tqdm(reader)
            # `batch_idx` (was `i`) no longer shadowed by the image loops.
            for batch_idx, data in enumerate(pbar):
                pbar.set_description('Processing at epoch %d' % epoch)
                character, mel, mel_input, pos_text, pos_mel, text_length, _ = data

                global_step += 1
                mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
                    character, mel_input, pos_text, pos_mel)

                # Stop-token target: 1.0 where pos_mel == 0 (presumably the
                # padded/end positions — verify against the data loader).
                label = (pos_mel == 0).astype(np.float32)

                # L1 reconstruction losses on pre- and post-net predictions.
                mel_loss = layers.mean(
                    layers.abs(layers.elementwise_sub(mel_pred, mel)))
                post_mel_loss = layers.mean(
                    layers.abs(layers.elementwise_sub(postnet_pred, mel)))
                loss = mel_loss + post_mel_loss
                # Note: When used stop token loss the learning did not work.
                if args.stop_token:
                    stop_loss = cross_entropy(stop_preds, label)
                    loss = loss + stop_loss

                if local_rank == 0:
                    writer.add_scalars('training_loss', {
                        'mel_loss': mel_loss.numpy(),
                        'post_mel_loss': post_mel_loss.numpy()
                    }, global_step)

                    if args.stop_token:
                        writer.add_scalar('stop_loss', stop_loss.numpy(),
                                          global_step)

                    # DataParallel wraps the model, so the alpha parameters
                    # live under model._layers in that case.
                    if args.use_data_parallel:
                        writer.add_scalars('alphas', {
                            'encoder_alpha': model._layers.encoder.alpha.numpy(),
                            'decoder_alpha': model._layers.decoder.alpha.numpy(),
                        }, global_step)
                    else:
                        writer.add_scalars('alphas', {
                            'encoder_alpha': model.encoder.alpha.numpy(),
                            'decoder_alpha': model.decoder.alpha.numpy(),
                        }, global_step)

                    # NoamDecay exposes the current lr via its private
                    # _learning_rate schedule object.
                    writer.add_scalar(
                        'learning_rate',
                        optimizer._learning_rate.step().numpy(), global_step)

                    if global_step % args.image_step == 1:
                        _add_attention_images(writer, 'Attention_%d_0',
                                              attn_probs, global_step)
                        _add_attention_images(writer, 'Attention_enc_%d_0',
                                              attn_enc, global_step)
                        _add_attention_images(writer, 'Attention_dec_%d_0',
                                              attn_dec, global_step)

                if args.use_data_parallel:
                    loss = model.scale_loss(loss)
                    loss.backward()
                    model.apply_collective_grads()
                else:
                    loss.backward()
                optimizer.minimize(loss, grad_clip=grad_clip)
                model.clear_gradients()

                # save checkpoint
                if local_rank == 0 and global_step % args.save_step == 0:
                    # Create the nested 'transformer' directory too — the
                    # original only made args.save_path, leaving the subdir
                    # that save_path points into potentially missing.
                    save_dir = os.path.join(args.save_path, 'transformer')
                    os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, '%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)
        if local_rank == 0:
            writer.close()
                    

if __name__ == '__main__':
    # CLI entry point: build the argument parser, attach all config
    # options, then hand the parsed namespace to the trainer.
    cli = argparse.ArgumentParser(description="Train TransformerTTS model")
    add_config_options_to_parser(cli)
    parsed = cli.parse_args()
    # Print the whole config setting.
    pprint(parsed)
    main(parsed)