# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import time

import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from scipy.io.wavfile import write

from parakeet.utils import io
from parakeet.modules import weight_norm
from parakeet.models.waveflow import WaveFlowLoss, WaveFlowModule
from data import LJSpeech
import utils


class WaveFlow():
    """Wrapper class of WaveFlow model that supports multiple APIs.

    This module provides APIs for model building, training, validation,
    inference, benchmarking, and saving.

    Args:
        config (obj): config info.
        checkpoint_dir (str): path for checkpointing.
        parallel (bool, optional): whether to use multiple GPUs for training.
            Defaults to False.
        rank (int, optional): the rank of the process in a multi-process
            scenario. Defaults to 0.
        nranks (int, optional): the total number of processes. Defaults to 1.
        vdl_logger (obj, optional): logger to visualize metrics.
            Defaults to None.

    Returns:
        WaveFlow
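
    Example:
        A minimal training sketch; ``config`` is assumed to provide the
        fields this class reads (learning_rate, sigma, use_fp16, iteration,
        checkpoint, etc.), and the step counts below are placeholders::

            from visualdl import LogWriter

            writer = LogWriter("runs/waveflow")
            model = WaveFlow(config, "checkpoints", vdl_logger=writer)
            last_iter = model.build(training=True)
            for step in range(last_iter + 1, last_iter + 10001):
                model.train_step(step)
                if step % 1000 == 0:
                    model.valid_step(step)
                    model.save(step)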
    """

    def __init__(self,
                 config,
                 checkpoint_dir,
                 parallel=False,
                 rank=0,
                 nranks=1,
                 vdl_logger=None):
        self.config = config
        self.checkpoint_dir = checkpoint_dir
        self.parallel = parallel
        self.rank = rank
        self.nranks = nranks
        self.vdl_logger = vdl_logger
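        # Use half precision when configured; fp16 reduces memory use and
        # can speed up synthesis on GPUs with fast fp16 arithmetic.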
        self.dtype = "float16" if config.use_fp16 else "float32"

    def build(self, training=True):
        """Initialize the model.

        Args:
            training (bool, optional): Whether the model is built for training or inference.
                Defaults to True.

        Returns:
            int: iteration number of the loaded checkpoint.
        """
        config = self.config
        dataset = LJSpeech(config, self.nranks, self.rank)
        self.trainloader = dataset.trainloader
        self.validloader = dataset.validloader

        waveflow = WaveFlowModule(config)

        if training:
            optimizer = fluid.optimizer.AdamOptimizer(
                learning_rate=config.learning_rate,
                parameter_list=waveflow.parameters())

            # Load parameters.
            iteration = io.load_parameters(
                model=waveflow,
                optimizer=optimizer,
                checkpoint_dir=self.checkpoint_dir,
                iteration=config.iteration,
                checkpoint_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            # Data parallelism.
            if self.parallel:
                strategy = dg.parallel.prepare_context()
                waveflow = dg.parallel.DataParallel(waveflow, strategy)

            self.waveflow = waveflow
            self.optimizer = optimizer
            self.criterion = WaveFlowLoss(config.sigma)

        else:
            # Load parameters.
            iteration = io.load_parameters(
                model=waveflow,
                checkpoint_dir=self.checkpoint_dir,
                iteration=config.iteration,
                checkpoint_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

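            # Fold weight normalization back into the plain weights; the
            # reparameterization only helps optimization during training, and
            # inference runs faster without the wrapper layers.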
            for layer in waveflow.sublayers():
                if isinstance(layer, weight_norm.WeightNormWrapper):
                    layer.remove_weight_norm()

            self.waveflow = waveflow

        return iteration

    def train_step(self, iteration):
        """Train the model for one step.

        Args:
            iteration (int): current iteration number.

        Returns:
            None
        """
        self.waveflow.train()

        start_time = time.time()
        audios, mels = next(self.trainloader)
        load_time = time.time()

        outputs = self.waveflow(audios, mels)
        loss = self.criterion(outputs)

        if self.parallel:
            # loss = loss / num_trainers
            loss = self.waveflow.scale_loss(loss)
            loss.backward()
            self.waveflow.apply_collective_grads()
        else:
            loss.backward()

        self.optimizer.minimize(
            loss, parameter_list=self.waveflow.parameters())
        self.waveflow.clear_gradients()

        graph_time = time.time()

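        # load_time - start_time measures data loading; graph_time - load_time
        # covers the forward pass, backward pass, and optimizer update.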
        if self.rank == 0:
            loss_val = float(loss.numpy()) * self.nranks
            log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \
                  "Time: {:.3f}/{:.3f}".format(
                  self.rank, iteration, loss_val,
                  load_time - start_time, graph_time - load_time)
            print(log)

            vdl_writer = self.vdl_logger
            vdl_writer.add_scalar("Train-Loss-Rank-0", loss_val, iteration)

    @dg.no_grad
    def valid_step(self, iteration):
        """Run the model on the validation dataset.

        Args:
            iteration (int): current iteration number.

        Returns:
            None
        """
        self.waveflow.eval()
        vdl_writer = self.vdl_logger

        total_loss = []
        sample_audios = []
        start_time = time.time()

        for i, batch in enumerate(self.validloader()):
            audios, mels = batch
            valid_outputs = self.waveflow(audios, mels)
            valid_z, valid_log_s_list = valid_outputs

            # Visualize latent z and scale log_s.
            if self.rank == 0 and i == 0:
                vdl_writer.add_histogram("Valid-Latent_z", valid_z.numpy(),
                                         iteration)
                for j, valid_log_s in enumerate(valid_log_s_list):
                    hist_name = "Valid-{}th-Flow-Log_s".format(j)
                    vdl_writer.add_histogram(hist_name, valid_log_s.numpy(),
                                             iteration)

            valid_loss = self.criterion(valid_outputs)
            total_loss.append(float(valid_loss.numpy()))

        total_time = time.time() - start_time
        if self.rank == 0:
            loss_val = np.mean(total_loss)
            log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
                self.rank, loss_val, total_time)
            print(log)
            vdl_writer.add_scalar("Valid-Avg-Loss", loss_val, iteration)

    @dg.no_grad
    def infer(self, iteration):
        """Run the model to synthesize audios.

        Args:
            iteration (int): iteration number of the loaded checkpoint.

        Returns:
            None
        """
        self.waveflow.eval()

        config = self.config
        sample = config.sample

        output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
        if not os.path.exists(output):
            os.makedirs(output)

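        # config.sample, if set, selects a single validation batch to
        # synthesize; otherwise every validation batch is synthesized and
        # the output files are numbered from 0.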
        mels_list = [mels for _, mels in self.validloader()]
        if sample is not None:
            mels_list = [mels_list[sample]]
        else:
            sample = 0

        for idx, mel in enumerate(mels_list):
            abs_idx = sample + idx
            filename = "{}/valid_{}.wav".format(output, abs_idx)
            print("Synthesize sample {}, save as {}".format(abs_idx, filename))

            start_time = time.time()
            audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
            syn_time = time.time() - start_time

            audio = audio[0]
            audio_time = audio.shape[0] / self.config.sample_rate
            print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time,
                                                                    syn_time))

            # Denormalize audio from [-1, 1] to the 16-bit PCM range.
            audio = audio.numpy().astype("float32") * 32768.0
            # Clip before casting: a sample at exactly 1.0 would otherwise
            # overflow and wrap around to -32768 in int16.
            audio = np.clip(audio, -32768, 32767).astype("int16")
            write(filename, config.sample_rate, audio)

    @dg.no_grad
    def benchmark(self):
        """Run the model to benchmark synthesis speed.

        Args:
            None

        Returns:
            None
        """
        self.waveflow.eval()

        mels_list = [mels for _, mels in self.validloader()]
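        # Build a fixed workload: concatenate all validation mels along the
        # time axis, truncate to 864 frames, and tile to a batch of 8 so every
        # timing run synthesizes identical input.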
        mel = fluid.layers.concat(mels_list, axis=2)
        mel = mel[:, :, :864]
        batch_size = 8
        mel = fluid.layers.expand(mel, [batch_size, 1, 1])

        for i in range(10):
            start_time = time.time()
            audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
            print("audio.shape = ", audio.shape)
            syn_time = time.time() - start_time

            audio_time = audio.shape[1] * batch_size / self.config.sample_rate
            print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time,
                                                                    syn_time))
            print("{} X real-time".format(audio_time / syn_time))

    def save(self, iteration):
        """Save model checkpoint.

        Args:
            iteration (int): iteration number of the model to be saved.

        Returns:
            None
        """
        io.save_parameters(self.checkpoint_dir, iteration, self.waveflow,
                           self.optimizer)