# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from parakeet.modules import weight_norm


def get_param_attr(layer_type, filter_size, c_in=1):
    if layer_type == "weight_norm":
        k = np.sqrt(1.0 / (c_in * np.prod(filter_size)))
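        # For example, a (3, 32) transposed-conv filter with c_in=1 gives
        # k = sqrt(1 / 96) ~= 0.102, i.e. weights drawn from U(-0.102, 0.102).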
        weight_init = fluid.initializer.UniformInitializer(low=-k, high=k)
        bias_init = fluid.initializer.UniformInitializer(low=-k, high=k)
    elif layer_type == "common":
        weight_init = fluid.initializer.ConstantInitializer(0.0)
        bias_init = fluid.initializer.ConstantInitializer(0.0)
    else:
        raise TypeError("Unsupported layer type.")

    param_attr = fluid.ParamAttr(initializer=weight_init)
    bias_attr = fluid.ParamAttr(initializer=bias_init)
    return param_attr, bias_attr


def unfold(x, n_group):
    length = x.shape[-1]
    new_shape = x.shape[:-1] + [length // n_group, n_group]
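    # For example, x of shape [4, 80, 1600] with n_group=8 is reshaped to
    # [4, 80, 200, 8]: the last axis is folded into (length/n_group, n_group).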
    return fluid.layers.reshape(x, new_shape)


class WaveFlowLoss:
    def __init__(self, sigma=1.0):
        self.sigma = sigma

    def __call__(self, model_output):
        z, log_s_list = model_output
        for i, log_s in enumerate(log_s_list):
            if i == 0:
                log_s_total = fluid.layers.reduce_sum(log_s)
            else:
                log_s_total = log_s_total + fluid.layers.reduce_sum(log_s)

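        # Mean negative log-likelihood under the flow: z is modeled as
        # N(0, sigma^2), so per element
        #   loss = [sum(z^2) / (2 * sigma^2) - sum(log_s)] / numel(z)
        # plus the Gaussian normalization const = 0.5 * log(2 * pi) + log(sigma).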
        loss = fluid.layers.reduce_sum(z * z) / (2 * self.sigma * self.sigma) \
            - log_s_total
        loss = loss / np.prod(z.shape)
        const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)

        return loss + const


class Conditioner(dg.Layer):
    def __init__(self, dtype):
        super(Conditioner, self).__init__()
        upsample_factors = [16, 16]
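        # The two transposed convs upsample the time axis by 16 * 16 = 256x,
        # which should match the hop length used to extract the mel
        # spectrograms (an assumption about the data pipeline, not enforced here).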

        self.upsample_conv2d = []
        for s in upsample_factors:
            in_channel = 1
            param_attr, bias_attr = get_param_attr(
                "weight_norm", (3, 2 * s), c_in=in_channel)
            conv_trans2d = weight_norm.Conv2DTranspose(
                num_channels=in_channel,
                num_filters=1,
                filter_size=(3, 2 * s),
                padding=(1, s // 2),
                stride=(1, s),
                param_attr=param_attr,
                bias_attr=bias_attr,
                dtype=dtype)
            self.upsample_conv2d.append(conv_trans2d)

        for i, layer in enumerate(self.upsample_conv2d):
            self.add_sublayer("conv2d_transpose_{}".format(i), layer)

    def forward(self, x):
        x = fluid.layers.unsqueeze(x, 1)
        for layer in self.upsample_conv2d:
            x = layer(x)
            x = fluid.layers.leaky_relu(x, alpha=0.4)

        return fluid.layers.squeeze(x, [1])

    def infer(self, x):
        x = fluid.layers.unsqueeze(x, 1)
        for layer in self.upsample_conv2d:
            x = layer(x)
            # Trim conv artifacts.
            time_cutoff = layer._filter_size[1] - layer._stride[1]
            x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4)

        return fluid.layers.squeeze(x, [1])


class Flow(dg.Layer):
    def __init__(self, config):
        super(Flow, self).__init__()
        self.n_layers = config.n_layers
        self.n_channels = config.n_channels
        self.kernel_h = config.kernel_h
        self.kernel_w = config.kernel_w
        self.dtype = "float16" if config.use_fp16 else "float32"

        # Transform audio: [batch, 1, n_group, time/n_group] 
        # => [batch, n_channels, n_group, time/n_group]
        param_attr, bias_attr = get_param_attr("weight_norm", (1, 1), c_in=1)
        self.start = weight_norm.Conv2D(
            num_channels=1,
            num_filters=self.n_channels,
            filter_size=(1, 1),
            param_attr=param_attr,
            bias_attr=bias_attr,
            dtype=self.dtype)

        # Initializing the last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability.
        # output shape: [batch, 2, n_group, time/n_group]
        param_attr, bias_attr = get_param_attr(
            "common", (1, 1), c_in=self.n_channels)
        self.end = dg.Conv2D(
            num_channels=self.n_channels,
            num_filters=2,
            filter_size=(1, 1),
            param_attr=param_attr,
            bias_attr=bias_attr,
            dtype=self.dtype)

        # Receptive field: (kernel - 1) * sum(dilations) + 1 >= squeeze (n_group)
        dilation_dict = {
            8: [1, 1, 1, 1, 1, 1, 1, 1],
            16: [1, 1, 1, 1, 1, 1, 1, 1],
            32: [1, 2, 4, 1, 2, 4, 1, 2],
            64: [1, 2, 4, 8, 16, 1, 2, 4],
            128: [1, 2, 4, 8, 16, 32, 64, 1]
        }
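        # Sanity check of the inequality above, assuming kernel_h=3: for
        # n_group=32 the height dilations sum to 17, so
        # (3 - 1) * 17 + 1 = 35 >= 32 holds.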
        self.dilation_h_list = dilation_dict[config.n_group]

        self.in_layers = []
        self.cond_layers = []
        self.res_skip_layers = []
        for i in range(self.n_layers):
            dilation_h = self.dilation_h_list[i]
            dilation_w = 2**i

            param_attr, bias_attr = get_param_attr(
                "weight_norm", (self.kernel_h, self.kernel_w),
                c_in=self.n_channels)
            in_layer = weight_norm.Conv2D(
                num_channels=self.n_channels,
                num_filters=2 * self.n_channels,
                filter_size=(self.kernel_h, self.kernel_w),
                dilation=(dilation_h, dilation_w),
                param_attr=param_attr,
                bias_attr=bias_attr,
                dtype=self.dtype)
            self.in_layers.append(in_layer)

            param_attr, bias_attr = get_param_attr(
                "weight_norm", (1, 1), c_in=config.mel_bands)
            cond_layer = weight_norm.Conv2D(
                num_channels=config.mel_bands,
                num_filters=2 * self.n_channels,
                filter_size=(1, 1),
                param_attr=param_attr,
                bias_attr=bias_attr,
                dtype=self.dtype)
            self.cond_layers.append(cond_layer)

            if i < self.n_layers - 1:
                res_skip_channels = 2 * self.n_channels
            else:
                res_skip_channels = self.n_channels
            param_attr, bias_attr = get_param_attr(
                "weight_norm", (1, 1), c_in=self.n_channels)
            res_skip_layer = weight_norm.Conv2D(
                num_channels=self.n_channels,
                num_filters=res_skip_channels,
                filter_size=(1, 1),
                param_attr=param_attr,
                bias_attr=bias_attr,
                dtype=self.dtype)
            self.res_skip_layers.append(res_skip_layer)

            self.add_sublayer("in_layer_{}".format(i), in_layer)
            self.add_sublayer("cond_layer_{}".format(i), cond_layer)
            self.add_sublayer("res_skip_layer_{}".format(i), res_skip_layer)

    def forward(self, audio, mel):
        # audio: [bs, 1, n_group, time/n_group]
        # mel: [bs, mel_bands, n_group, time/n_group]
        audio = self.start(audio)

        for i in range(self.n_layers):
            dilation_h = self.dilation_h_list[i]
            dilation_w = 2**i

            # Pad height dim (n_group): causal convolution
            # Pad width dim (time): dilated non-causal convolution
            pad_top, pad_bottom = (self.kernel_h - 1) * dilation_h, 0
            pad_left = pad_right = int((self.kernel_w - 1) * dilation_w / 2)
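            # For example, assuming a 3x3 kernel with dilation_h=2 and
            # dilation_w=4: pad_top=4 (causal in height) and
            # pad_left=pad_right=4 (centered in time).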
            # Using pad2d is a bit faster than using padding in Conv2D directly 
            audio_pad = fluid.layers.pad2d(
                audio, paddings=[pad_top, pad_bottom, pad_left, pad_right])
            hidden = self.in_layers[i](audio_pad)
            cond_hidden = self.cond_layers[i](mel)
            in_acts = hidden + cond_hidden
            out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \
                fluid.layers.sigmoid(in_acts[:, self.n_channels:, :])
            res_skip_acts = self.res_skip_layers[i](out_acts)

            if i < self.n_layers - 1:
                audio += res_skip_acts[:, :self.n_channels, :, :]
                skip_acts = res_skip_acts[:, self.n_channels:, :, :]
            else:
                skip_acts = res_skip_acts

            if i == 0:
                output = skip_acts
            else:
                output += skip_acts

        return self.end(output)

    def infer(self, audio, mel, queues):
        audio = self.start(audio)

        for i in range(self.n_layers):
            dilation_h = self.dilation_h_list[i]
            dilation_w = 2**i

            state_size = dilation_h * (self.kernel_h - 1)
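            # The queue caches the last state_size input rows of this layer,
            # so each new height step only convolves over (state_size + 1)
            # rows instead of recomputing the whole causal history.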
            queue = queues[i]

            if len(queue) == 0:
                for j in range(state_size):
                    queue.append(fluid.layers.zeros_like(audio))

            state = queue[0:state_size]
            state = fluid.layers.concat(state + [audio], axis=2)

            queue.pop(0)
            queue.append(audio)

            # Pad height dim (n_group): causal convolution
            # Pad width dim (time): dilated non-causal convolution
            pad_top, pad_bottom = 0, 0
            pad_left = int((self.kernel_w - 1) * dilation_w / 2)
            pad_right = int((self.kernel_w - 1) * dilation_w / 2)
            state = fluid.layers.pad2d(
                state, paddings=[pad_top, pad_bottom, pad_left, pad_right])
            hidden = self.in_layers[i](state)
            cond_hidden = self.cond_layers[i](mel)
            in_acts = hidden + cond_hidden
            out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \
                fluid.layers.sigmoid(in_acts[:, self.n_channels:, :])
            res_skip_acts = self.res_skip_layers[i](out_acts)

            if i < self.n_layers - 1:
                audio += res_skip_acts[:, :self.n_channels, :, :]
                skip_acts = res_skip_acts[:, self.n_channels:, :, :]
            else:
                skip_acts = res_skip_acts

            if i == 0:
                output = skip_acts
            else:
                output += skip_acts

        return self.end(output)


class WaveFlowModule(dg.Layer):
    """WaveFlow model implementation.

    Args:
        config (obj): model configuration parameters.

    Returns:
        WaveFlowModule
    """

    def __init__(self, config):
        super(WaveFlowModule, self).__init__()
        self.n_flows = config.n_flows
        self.n_group = config.n_group
        self.n_layers = config.n_layers
        assert self.n_group % 2 == 0
        assert self.n_flows % 2 == 0

        self.dtype = "float16" if config.use_fp16 else "float32"
        self.conditioner = Conditioner(self.dtype)
        self.flows = []
        for i in range(self.n_flows):
            flow = Flow(config)
            self.flows.append(flow)
            self.add_sublayer("flow_{}".format(i), flow)

        self.perms = []
        half = self.n_group // 2
        for i in range(self.n_flows):
            perm = list(range(self.n_group))
            if i < self.n_flows // 2:
                perm = perm[::-1]
            else:
                perm[:half] = reversed(perm[:half])
                perm[half:] = reversed(perm[half:])
            self.perms.append(perm)
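        # For example, with n_group=8 the first n_flows/2 flows reverse the
        # whole height axis, [7, 6, 5, 4, 3, 2, 1, 0], and the remaining
        # flows reverse each half separately, [3, 2, 1, 0, 7, 6, 5, 4].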

    def forward(self, audio, mel):
        """Training forward pass.

        Use a conditioner to upsample mel spectrograms into hidden states.
        These hidden states along with the audio are passed to a stack of Flow
        modules to obtain the final latent variable z and a list of log scaling
        variables, which are then passed to the WaveFlowLoss module to calculate
        the negative log likelihood.

        Args:
            audio (obj): audio samples.
            mel (obj): mel spectrograms.

        Returns:
            z (obj): latent variable.
            log_s_list (list): list of log scaling variables.
        """
        mel = self.conditioner(mel)
        assert mel.shape[2] >= audio.shape[1]
        # Prune out the tail of audio/mel so that time % n_group == 0.
        pruned_len = int(audio.shape[1] // self.n_group * self.n_group)
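        # For example, audio of length 65000 with n_group=16 is pruned to
        # 64992 samples (65000 // 16 * 16).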

        if audio.shape[1] > pruned_len:
            audio = audio[:, :pruned_len]
        if mel.shape[2] > pruned_len:
            mel = mel[:, :, :pruned_len]

        # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
        mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
        # From [bs, time] to [bs, n_group, time/n_group]
        audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1])
        # [bs, 1, n_group, time/n_group] 
        audio = fluid.layers.unsqueeze(audio, 1)
        log_s_list = []
        for i in range(self.n_flows):
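            # Shift by one row over the squeezed height: the scale/bias for
            # audio row h are predicted from audio rows < h (via the causal
            # conv) and the mel row aligned at h, so the flow stays
            # autoregressive over the height dimension.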
            inputs = audio[:, :, :-1, :]
            conds = mel[:, :, 1:, :]
            outputs = self.flows[i](inputs, conds)
            log_s = outputs[:, :1, :, :]
            b = outputs[:, 1:, :, :]
            log_s_list.append(log_s)

            audio_0 = audio[:, :, :1, :]
            audio_out = audio[:, :, 1:, :] * fluid.layers.exp(log_s) + b
            audio = fluid.layers.concat([audio_0, audio_out], axis=2)

            # Permute over the height dim.
            audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
            audio = fluid.layers.stack(audio_slices, axis=2)
            mel_slices = [mel[:, :, j, :] for j in self.perms[i]]
            mel = fluid.layers.stack(mel_slices, axis=2)

        z = fluid.layers.squeeze(audio, [1])
        return z, log_s_list

    def synthesize(self, mel, sigma=1.0):
        """Use model to synthesize waveform.

        Use a conditioner to upsample mel spectrograms into hidden states.
        These hidden states, along with an initial random Gaussian latent variable,
        are passed to a stack of Flow modules to obtain the audio output.

        Note that we use a convolutional queue (https://arxiv.org/abs/1611.09482)
        to cache intermediate hidden states, which speeds up the autoregressive
        inference over the height dimension. The current implementation only
        supports a height dimension (self.n_group) of 8 or 16, i.e., cases
        where there is no dilation on the height dimension.

        Args:
            mel (obj): mel spectrograms.
            sigma (float, optional): standard deviation of the Gaussian latent
                variable. Defaults to 1.0.

        Returns:
            audio (obj): synthesized audio.
        """
        if self.dtype == "float16":
            mel = fluid.layers.cast(mel, self.dtype)
        mel = self.conditioner.infer(mel)
        # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
        mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])

        audio = fluid.layers.gaussian_random(
            shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma)
        if self.dtype == "float16":
            audio = fluid.layers.cast(audio, self.dtype)
        for i in reversed(range(self.n_flows)):
            # Permute over the height dimension.
            audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
            audio = fluid.layers.stack(audio_slices, axis=2)
            mel_slices = [mel[:, :, j, :] for j in self.perms[i]]
            mel = fluid.layers.stack(mel_slices, axis=2)

            audio_list = []
            audio_0 = audio[:, :, 0:1, :]
            audio_list.append(audio_0)
            audio_h = audio_0
            queues = [[] for _ in range(self.n_layers)]

            for h in range(1, self.n_group):
                inputs = audio_h
                conds = mel[:, :, h:(h + 1), :]
                outputs = self.flows[i].infer(inputs, conds, queues)

                log_s = outputs[:, 0:1, :, :]
                b = outputs[:, 1:, :, :]
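                # Invert the training-time affine coupling
                # y = x * exp(log_s) + b, i.e. x = (y - b) / exp(log_s).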
                audio_h = (audio[:, :, h:(h + 1), :] - b) / \
                    fluid.layers.exp(log_s)
                audio_list.append(audio_h)

            audio = fluid.layers.concat(audio_list, axis=2)

        # audio: [bs, n_group, time/n_group]
        audio = fluid.layers.squeeze(audio, [1])
        # audio: [bs, time]
        audio = fluid.layers.reshape(
            fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1])
        return audio
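

if __name__ == "__main__":
    # A minimal smoke test, kept as a sketch: the config values below are
    # illustrative assumptions, not the project's shipped defaults.
    from types import SimpleNamespace

    config = SimpleNamespace(
        n_flows=2,
        n_group=8,
        n_layers=8,
        n_channels=16,
        kernel_h=3,
        kernel_w=3,
        use_fp16=False,
        mel_bands=80)
    with dg.guard():
        model = WaveFlowModule(config)
        # 10 mel frames upsample by 16 * 16 = 256x to 2560 conditioner steps,
        # matching 2560 audio samples (so time % n_group == 0).
        mel = dg.to_variable(np.random.randn(1, 80, 10).astype("float32"))
        audio = dg.to_variable(np.random.randn(1, 2560).astype("float32"))
        z, log_s_list = model(audio, mel)
        print("z:", z.shape, "num log_s:", len(log_s_list))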