# train_transformer.py
import os
import numpy as np
from tqdm import tqdm
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
from network import *
from tensorboardX import SummaryWriter
from pathlib import Path
import jsonargparse
from parse import add_config_options_to_parser
from pprint import pprint
from matplotlib import cm
from data import LJSpeechLoader

class MyDataParallel(dg.parallel.DataParallel):
    """
    A data parallel proxy for model.
    """

    def __init__(self, layers, strategy):
        super(MyDataParallel, self).__init__(layers, strategy)

    def __getattr__(self, key):
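        # Fall through to the wrapped model so attributes such as
        # `model.encoder` resolve the same with or without DataParallel.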
        if key in self.__dict__:
            return object.__getattribute__(self, key)
        elif key == "_layers":
            return object.__getattribute__(self, "_sub_layers")["_layers"]
        else:
            return getattr(
                object.__getattribute__(self, "_sub_layers")["_layers"], key)


def main(cfg):
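    # Rank and world size come from Paddle's parallel environment when
    # data parallelism is enabled; otherwise run as a single process.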
    local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1

    if local_rank == 0:
        # Print the whole config setting.
        pprint(jsonargparse.namespace_to_dict(cfg))

    global_step = 0
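    # Pick the execution place: this worker's GPU under data parallelism,
    # GPU 0 for single-card training, or the CPU.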
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if cfg.use_data_parallel else fluid.CUDAPlace(0)
             if cfg.use_gpu else fluid.CPUPlace())

    if not os.path.exists(cfg.log_dir):
        os.mkdir(cfg.log_dir)
    path = os.path.join(cfg.log_dir, 'transformer')

    # Only the rank-0 process writes TensorBoard summaries.
    writer = SummaryWriter(path) if local_rank == 0 else None
    
    with dg.guard(place):
        model = Model('transtts', cfg)

        model.train()
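        # Noam schedule: linear warmup over 4000 steps, then inverse
        # square-root decay; the scale is chosen so the peak learning
        # rate (reached at step 4000) equals cfg.lr.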
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(1 / (4000 * (cfg.lr ** 2)), 4000))
        
        reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
        
        # Optionally resume from a saved checkpoint.
        if cfg.checkpoint_path is not None:
            model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            print("Loaded checkpoint from %s" % cfg.checkpoint_path)

        if cfg.use_data_parallel:
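            # Initialize the multi-GPU context and wrap the model so its
            # gradients can be synchronized across devices.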
            strategy = dg.parallel.prepare_context()
            model = MyDataParallel(model, strategy)

        for epoch in range(cfg.epochs):
            pbar = tqdm(reader)
            for i, data in enumerate(pbar):
                pbar.set_description('Processing at epoch %d' % epoch)
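                # A batch holds text ids, the target mel spectrogram, the
                # shifted mel decoder input, text/mel position indices, and
                # text lengths.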
                character, mel, mel_input, pos_text, pos_mel, text_length = data

                global_step += 1
                
                mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
                    character, mel_input, pos_text, pos_mel)

                # L1 losses between predicted and ground-truth mel spectrograms,
                # before and after the postnet.
                mel_loss = layers.mean(layers.abs(layers.elementwise_sub(mel_pred, mel)))
                post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel)))
                loss = mel_loss + post_mel_loss

                if local_rank == 0:
                    writer.add_scalars('training_loss', {
                        'mel_loss': mel_loss.numpy(),
                        'post_mel_loss': post_mel_loss.numpy(),
                    }, global_step)

                    writer.add_scalars('alphas', {
                        'encoder_alpha': model.encoder.alpha.numpy(),
                        'decoder_alpha': model.decoder.alpha.numpy(),
                    }, global_step)

                    writer.add_scalar('learning_rate',
                                      optimizer._learning_rate.step().numpy(),
                                      global_step)

                    if global_step % cfg.image_step == 1:
                        # Log a sample of attention maps as heat-map images.
                        for k, prob in enumerate(attn_probs):
                            for j in range(4):
                                x = np.uint8(cm.viridis(prob.numpy()[j * 16]) * 255)
                                writer.add_image('Attention_%d_0' % global_step, x, k * 4 + j, dataformats="HWC")

                        for k, prob in enumerate(attn_enc):
                            for j in range(4):
                                x = np.uint8(cm.viridis(prob.numpy()[j * 16]) * 255)
                                writer.add_image('Attention_enc_%d_0' % global_step, x, k * 4 + j, dataformats="HWC")

                        for k, prob in enumerate(attn_dec):
                            for j in range(4):
                                x = np.uint8(cm.viridis(prob.numpy()[j * 16]) * 255)
                                writer.add_image('Attention_dec_%d_0' % global_step, x, k * 4 + j, dataformats="HWC")

                if cfg.use_data_parallel:
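                    # Scale the loss for gradient averaging and all-reduce
                    # gradients across trainers.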
                    loss = model.scale_loss(loss)
                    loss.backward()
                    model.apply_collective_grads()
                else:
                    loss.backward()
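                # Clip gradients by global norm before applying the update.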
                optimizer.minimize(
                    loss, grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(1))
                model.clear_gradients()

                # save checkpoint
                if local_rank == 0 and global_step % cfg.save_step == 0:
                    if not os.path.exists(cfg.save_path):
                        os.mkdir(cfg.save_path)
                    save_path = os.path.join(cfg.save_path, 'transformer/%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)
        if local_rank == 0:
            writer.close()
                    

if __name__ == '__main__':
    parser = jsonargparse.ArgumentParser(
        description="Train TransformerTTS model",
        formatter_class='default_argparse')
    add_config_options_to_parser(parser)
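    # Options are read from the YAML config file rather than sys.argv.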
    cfg = parser.parse_args('-c ./config/train_transformer.yaml'.split())
    main(cfg)