import numpy as np
import argparse
import os
import time
import math
from pathlib import Path
from parse import add_config_options_to_parser
from pprint import pprint
from ruamel import yaml
from tqdm import tqdm
from collections import OrderedDict
from tensorboardX import SummaryWriter
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
import paddle.fluid as fluid
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
from parakeet.models.fastspeech.fastspeech import FastSpeech
from parakeet.models.fastspeech.utils import get_alignment
# Reuse the LJSpeech data loader from the transformer_tts example.
import sys
sys.path.append("../transformer_tts")
from data import LJSpeechLoader

def load_checkpoint(step, model_path):
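    """Load a dygraph checkpoint saved at `step` under `model_path` and strip the
    '_layers.' prefix that DataParallel adds to parameter names, so the weights
    can be loaded into a plain (non-parallel) model."""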
    model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
    new_state_dict = OrderedDict()
    for param in model_dict:
        if param.startswith('_layers.'):
            new_state_dict[param[8:]] = model_dict[param]
        else:
            new_state_dict[param] = model_dict[param]
    return new_state_dict, opti_dict

def main(args):
    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1

    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else fluid.CUDAPlace(0)
             if args.use_gpu else fluid.CPUPlace())

    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)
    path = os.path.join(args.log_dir, 'fastspeech')

    writer = SummaryWriter(path) if local_rank == 0 else None

    with dg.guard(place):
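        # Everything below runs in PaddlePaddle dygraph (imperative) mode on `place`.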
        with fluid.unique_name.guard():
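            # Load a pretrained TransformerTTS in eval mode. It is used only as a
            # teacher: its attention maps provide duration targets for FastSpeech.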
            transformerTTS = TransformerTTS(cfg)
            model_dict, _ = load_checkpoint(
                str(args.transformer_step),
                os.path.join(args.transtts_path, "transformer"))
            transformerTTS.set_dict(model_dict)
            transformerTTS.eval()

        model = FastSpeech(cfg)
        model.train()
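        # Adam with the Noam warm-up learning-rate schedule.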
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(
                1 / (cfg['warm_up_step'] * (args.lr ** 2)), cfg['warm_up_step']),
            parameter_list=model.parameters())
        reader = LJSpeechLoader(cfg, args, nranks, local_rank, shuffle=True).reader()
        
        # Optionally resume FastSpeech training from a saved checkpoint.
        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.fastspeech_step),
                os.path.join(args.checkpoint_path, "fastspeech"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.fastspeech_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        for epoch in range(args.epochs):
            pbar = tqdm(reader)

            for i, data in enumerate(pbar):
                pbar.set_description('Processing at epoch %d' % epoch)
                character, mel, mel_input, pos_text, pos_mel, text_length, mel_lens = data

                # Run the pretrained TransformerTTS teacher to obtain the
                # encoder-decoder attention probabilities for this batch.
                _, _, attn_probs, _, _, _ = transformerTTS(character, mel_input, pos_text, pos_mel)
                # Convert the teacher's attention maps into per-character durations.
                alignment = dg.to_variable(get_alignment(attn_probs, mel_lens, cfg['transformer_head'])).astype(np.float32)

                global_step += 1

                # Forward pass through FastSpeech, using the teacher alignment as
                # the duration target for the length regulator.
                result = model(character,
                               pos_text,
                               mel_pos=pos_mel,
                               length_target=alignment)
                mel_output, mel_output_postnet, duration_predictor_output, _, _ = result

                # MSE losses on the mel spectrogram before and after the postnet,
                # plus an L1 loss on the predicted durations.
                mel_loss = layers.mse_loss(mel_output, mel)
                mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
                duration_loss = layers.mean(layers.abs(layers.elementwise_sub(duration_predictor_output, alignment)))
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Log training curves from the rank-0 process only.
                if local_rank == 0:
                    writer.add_scalar('mel_loss', mel_loss.numpy(), global_step)
                    writer.add_scalar('post_mel_loss', mel_postnet_loss.numpy(), global_step)
                    writer.add_scalar('duration_loss', duration_loss.numpy(), global_step)
                    writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step)


                if args.use_data_parallel:
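                    # Multi-GPU: scale the loss and all-reduce gradients across
                    # devices before the optimizer step.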
                    total_loss = model.scale_loss(total_loss)
                    total_loss.backward()
                    model.apply_collective_grads()
                else:
                    total_loss.backward()
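                # Gradient clipping by global norm is applied as part of minimize().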
                optimizer.minimize(
                    total_loss,
                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg['grad_clip_thresh']))
                model.clear_gradients()

                # save checkpoint
                if local_rank == 0 and global_step % args.save_step == 0:
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    save_path = os.path.join(args.save_path, 'fastspeech/%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)
        if local_rank == 0:
            writer.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train FastSpeech model")
    add_config_options_to_parser(parser)
    args = parser.parse_args()
    # Print the whole config setting.
    pprint(args)
    main(args)