# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import os
import time
import math
from pathlib import Path
from parse import add_config_options_to_parser
from pprint import pprint
from ruamel import yaml
from tqdm import tqdm
from matplotlib import cm
from collections import OrderedDict
from tensorboardX import SummaryWriter
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
import paddle.fluid as fluid
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
from parakeet.models.fastspeech.fastspeech import FastSpeech
from parakeet.models.fastspeech.utils import get_alignment
import sys
sys.path.append("../transformer_tts")
from data import LJSpeechLoader


def load_checkpoint(step, model_path):
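    """Load a dygraph checkpoint and strip the DataParallel '_layers.' prefix from parameter names."""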
    model_dict, opti_dict = fluid.dygraph.load_dygraph(
        os.path.join(model_path, step))
    new_state_dict = OrderedDict()
    for param in model_dict:
        if param.startswith('_layers.'):
            new_state_dict[param[8:]] = model_dict[param]
        else:
            new_state_dict[param] = model_dict[param]
    return new_state_dict, opti_dict


def main(args):
    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1

    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else fluid.CUDAPlace(0)
             if args.use_gpu else fluid.CPUPlace())

    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)
    path = os.path.join(args.log_dir, 'fastspeech')

    writer = SummaryWriter(path) if local_rank == 0 else None

    with dg.guard(place):
        with fluid.unique_name.guard():
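            # Pre-trained TransformerTTS used as a teacher: its attention maps are turned into
            # per-character duration targets for FastSpeech further down in the training loop.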
            transformer_tts = TransformerTTS(cfg)
            model_dict, _ = load_checkpoint(
                str(args.transformer_step),
                os.path.join(args.transtts_path, "transformer"))
            transformer_tts.set_dict(model_dict)
            transformer_tts.eval()

        model = FastSpeech(cfg)
        model.train()
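        # Adam optimizer with a Noam warm-up learning-rate schedule.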
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(1 / (
                cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
            parameter_list=model.parameters())
        reader = LJSpeechLoader(
            cfg, args, nranks, local_rank, shuffle=True).reader()

        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.fastspeech_step),
                os.path.join(args.checkpoint_path, "fastspeech"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.fastspeech_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        for epoch in range(args.epochs):
            pbar = tqdm(reader)

            for i, data in enumerate(pbar):
                pbar.set_description('Processing at epoch %d' % epoch)
                (character, mel, mel_input, pos_text, pos_mel, text_length,
                 mel_lens, enc_slf_mask, enc_query_mask, dec_slf_mask,
                 enc_dec_mask, dec_query_slf_mask, dec_query_mask) = data

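                # Run the teacher TransformerTTS to obtain attention probabilities, then derive
                # per-character duration targets (alignments) from them.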
                _, _, attn_probs, _, _, _ = transformer_tts(
                    character,
                    mel_input,
                    pos_text,
                    pos_mel,
                    dec_slf_mask=dec_slf_mask,
                    enc_slf_mask=enc_slf_mask,
                    enc_query_mask=enc_query_mask,
                    enc_dec_mask=enc_dec_mask,
                    dec_query_slf_mask=dec_query_slf_mask,
                    dec_query_mask=dec_query_mask)
                alignment, max_attn = get_alignment(attn_probs, mel_lens,
                                                    cfg['transformer_head'])
                alignment = dg.to_variable(alignment).astype(np.float32)

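                # Periodically log one teacher attention map (sample index 8 of the batch) to TensorBoard.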
                if local_rank == 0 and global_step % 5 == 1:
                    x = np.uint8(
                        cm.viridis(max_attn[8, :mel_lens.numpy()[8]]) * 255)
                    writer.add_image(
                        'Attention_%d_0' % global_step,
                        x,
                        0,
                        dataformats="HWC")

                global_step += 1

                # Forward pass through FastSpeech with the teacher-derived duration targets.
                result = model(
                    character,
                    pos_text,
                    mel_pos=pos_mel,
                    length_target=alignment,
                    enc_non_pad_mask=enc_query_mask,
                    enc_slf_attn_mask=enc_slf_mask,
                    dec_non_pad_mask=dec_query_slf_mask,
                    dec_slf_attn_mask=dec_slf_mask)
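                # Losses: MSE on the mel outputs before and after the postnet, plus an L1 loss on the predicted durations.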
                mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
                mel_loss = layers.mse_loss(mel_output, mel)
                mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
                duration_loss = layers.mean(
                    layers.abs(
                        layers.elementwise_sub(duration_predictor_output,
                                               alignment)))
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                if local_rank == 0:
                    writer.add_scalar('mel_loss',
                                      mel_loss.numpy(), global_step)
                    writer.add_scalar('post_mel_loss',
                                      mel_postnet_loss.numpy(), global_step)
                    writer.add_scalar('duration_loss',
                                      duration_loss.numpy(), global_step)
                    writer.add_scalar('learning_rate',
                                      optimizer._learning_rate.step().numpy(),
                                      global_step)

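                # Backward pass and optimizer step; with data parallelism the loss is scaled and
                # gradients are all-reduced across workers before the update.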
                if args.use_data_parallel:
                    total_loss = model.scale_loss(total_loss)
                    total_loss.backward()
                    model.apply_collective_grads()
                else:
                    total_loss.backward()
                optimizer.minimize(
                    total_loss,
                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
                        'grad_clip_thresh']))
                model.clear_gradients()

                # save checkpoint
                if local_rank == 0 and global_step % args.save_step == 0:
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    save_path = os.path.join(args.save_path,
                                             'fastspeech/%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)
        if local_rank == 0:
            writer.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train Fastspeech model")
    add_config_options_to_parser(parser)
    args = parser.parse_args()
    # Print the whole config setting.
    pprint(args)
    main(args)