# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import math
from typing import Optional
from typing import Tuple

import paddle
import paddle.nn.functional as F
from paddle import nn


class ResidualBlock(nn.Layer):
    """Residual block module in WaveNet."""

    def __init__(
            self,
            kernel_size: int=3,
            residual_channels: int=64,
            gate_channels: int=128,
            skip_channels: int=64,
            aux_channels: int=80,
            global_channels: int=-1,
            dropout_rate: float=0.0,
            dilation: int=1,
            bias: bool=True,
            scale_residual: bool=False, ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            residual_channels (int): Number of channels for residual connection.
            gate_channels (int): Number of channels for gated activation (split into two halves).
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Number of local conditioning channels.
            global_channels (int): Number of global conditioning channels (disabled if <= 0).
            dropout_rate (float): Dropout probability.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            scale_residual (bool): Whether to scale the residual outputs.

        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.scale_residual = scale_residual

        # check
        assert (
            kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
        assert gate_channels % 2 == 0, "gate_channels must be even."

        # dilation conv
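        # this padding keeps the output length equal to the input length
        # for an odd kernel size with the given dilation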
        padding = (kernel_size - 1) // 2 * dilation
        self.conv = nn.Conv1D(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias_attr=bias, )

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = nn.Conv1D(
                aux_channels, gate_channels, kernel_size=1, bias_attr=False)
        else:
            self.conv1x1_aux = None

        # global conditioning
        if global_channels > 0:
            self.conv1x1_glo = nn.Conv1D(
                global_channels, gate_channels, kernel_size=1, bias_attr=False)
        else:
            self.conv1x1_glo = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2

        # NOTE: two convs are merged into a single conv for efficiency
        #   (integrate res 1x1 + skip 1x1 convs)
        self.conv1x1_out = nn.Conv1D(
            gate_out_channels,
            residual_channels + skip_channels,
            kernel_size=1,
            bias_attr=bias)

    def forward(
            self,
            x: paddle.Tensor,
            x_mask: Optional[paddle.Tensor]=None,
            c: Optional[paddle.Tensor]=None,
            g: Optional[paddle.Tensor]=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, residual_channels, T).
            x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
            c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).

        Returns:
            Tensor: Output tensor for residual connection (B, residual_channels, T).
            Tensor: Output tensor for skip connection (B, skip_channels, T).

        """
        residual = x
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.conv(x)

        # split into two parts for gated activation
        splitdim = 1
        xa, xb = paddle.split(x, 2, axis=splitdim)

        # local conditioning
        if c is not None:
            c = self.conv1x1_aux(c)
            ca, cb = paddle.split(c, 2, axis=splitdim)
            xa, xb = xa + ca, xb + cb

        # global conditioning
        if g is not None:
            g = self.conv1x1_glo(g)
            ga, gb = paddle.split(g, 2, axis=splitdim)
            xa, xb = xa + ga, xb + gb

        # gated activation unit: tanh branch modulated by a sigmoid gate (as in WaveNet)
        x = paddle.tanh(xa) * F.sigmoid(xb)

        # residual + skip 1x1 conv
        x = self.conv1x1_out(x)
        if x_mask is not None:
            x = x * x_mask

        # split integrated conv results
        x, s = paddle.split(
            x, [self.residual_channels, self.skip_channels], axis=1)

        # for residual connection
        x = x + residual
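        # scale by 1/sqrt(2) to keep the variance of the sum close to that of the inputs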
        if self.scale_residual:
            x = x * math.sqrt(0.5)

        return x, s
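

# A minimal usage sketch (not part of the original module): builds a ResidualBlock
# with the default sizes and runs a random batch through it. The batch size and
# sequence length below are illustrative assumptions, not values from any recipe.
if __name__ == "__main__":
    block = ResidualBlock(
        kernel_size=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80, )
    # (B, residual_channels, T) input and (B, aux_channels, T) local conditioning
    x = paddle.randn([2, 64, 100])
    c = paddle.randn([2, 80, 100])
    res, skip = block(x, c=c)
    # residual output keeps (B, residual_channels, T); skip output is (B, skip_channels, T)
    print(res.shape, skip.shape)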