subsampling.py 9.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Subsampling layer definition."""
from typing import Tuple

import paddle
from paddle import nn

21 22
from paddlespeech.s2t.modules.embedding import PositionalEncoding
from paddlespeech.s2t.utils.log import Log
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63

# Module-level logger obtained through the project's Log helper.
logger = Log(__name__).getlog()

# Names exported by a star-import of this module.
__all__ = [
    "LinearNoSubsampling",
    "Conv2dSubsampling4",
    "Conv2dSubsampling6",
    "Conv2dSubsampling8",
]


class BaseSubsampling(nn.Layer):
    """Common base for the input (sub)sampling layers in this module.

    Holds the positional-encoding object plus two integer attributes,
    ``right_context`` and ``subsampling_rate``, which subclasses overwrite
    with their own values.
    """

    def __init__(self, pos_enc_class: nn.Layer=PositionalEncoding):
        """Store the positional encoding and set neutral defaults.

        Args:
            pos_enc_class (nn.Layer): object stored as ``self.pos_enc``.
                NOTE(review): subclasses call ``self.pos_enc(x, offset)`` and
                unpack two results, so presumably callers pass an already
                constructed layer rather than the class itself -- confirm.
        """
        super().__init__()
        # stride = subsampling_rate * chunk_size
        self.subsampling_rate = 1
        # window size = (1 + right_context) + (chunk_size -1) * subsampling_rate
        self.right_context = 0
        self.pos_enc = pos_enc_class

    def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
        """Delegate to ``self.pos_enc.position_encoding(offset, size)``.

        Args:
            offset (int): forwarded unchanged to the positional encoding.
            size (int): forwarded unchanged to the positional encoding.
        Returns:
            paddle.Tensor: whatever the positional encoding returns.
        """
        encoder = self.pos_enc
        return encoder.position_encoding(offset, size)


class LinearNoSubsampling(BaseSubsampling):
    """Linear transform the input without subsampling."""

    def __init__(self,
                 idim: int,
                 odim: int,
                 dropout_rate: float,
                 pos_enc_class: nn.Layer=PositionalEncoding):
        """Construct a linear (non-subsampling) front-end.

        Args:
            idim (int): Input dimension.
            odim (int): Output dimension.
            dropout_rate (float): Dropout rate.
            pos_enc_class (PositionalEncoding): positional encoding object,
                stored by the base class as ``self.pos_enc``.
        """
        super().__init__(pos_enc_class)
        # Projection applied to every frame: Linear -> LayerNorm -> Dropout -> ReLU.
        self.out = nn.Sequential(
            nn.Linear(idim, odim),
            nn.LayerNorm(odim, epsilon=1e-12),
            nn.Dropout(dropout_rate),
            nn.ReLU(), )
        # No convolution here, so there is no look-ahead and no rate reduction.
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Project x and add the positional encoding.

        Args:
            x (paddle.Tensor): Input tensor (#batch, time, idim).
            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
            offset (int): position encoding offset.
        Returns:
            paddle.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            paddle.Tensor: positional encoding
            paddle.Tensor: linear input mask (#batch, 1, time'),
                where time' = time . The mask is returned unchanged.
        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask

H
Hui Zhang 已提交
87

88 89 90
class Conv2dSubsampling(BaseSubsampling):
    """Shared parent of the Conv2D-based subsampling layers.

    Adds no behavior of its own on top of ``BaseSubsampling``.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to the parent constructor."""
        super(Conv2dSubsampling, self).__init__(*args, **kwargs)
91

H
Hui Zhang 已提交
92

93
class Conv2dSubsampling4(Conv2dSubsampling):
    """Convolutional 2D subsampling (to 1/4 length)."""

    def __init__(self,
                 idim: int,
                 odim: int,
                 dropout_rate: float,
                 pos_enc_class: nn.Layer=PositionalEncoding):
        """Construct an Conv2dSubsampling4 object.

        Args:
            idim (int): Input dimension.
            odim (int): Output dimension.
            dropout_rate (float): Dropout rate. Not referenced by this
                class itself; kept for a signature parallel to the other
                subsampling layers.
            pos_enc_class (PositionalEncoding): positional encoding object,
                stored by the base class as ``self.pos_enc``.
        """
        super().__init__(pos_enc_class)
        # Two unpadded kernel-3 / stride-2 convolutions over (time, feature).
        self.conv = nn.Sequential(
            nn.Conv2D(1, odim, 3, 2),
            nn.ReLU(),
            nn.Conv2D(odim, odim, 3, 2),
            nn.ReLU(), )
        # Each conv maps a length I to (I - 1) // 2, hence the feature size
        # after both convs is ((idim - 1) // 2 - 1) // 2 per channel.
        self.out = nn.Sequential(
            nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
        self.subsampling_rate = 4
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Subsample x.

        Args:
            x (paddle.Tensor): Input tensor (#batch, time, idim).
            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
            offset (int): position encoding offset.
        Returns:
            paddle.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = ((time - 1) // 2 - 1) // 2 (roughly time // 4).
            paddle.Tensor: positional encoding
            paddle.Tensor: Subsampled mask (#batch, 1, time'),
                with the same time' as above.
        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = paddle.shape(x)
        # Fold channels into the feature axis: (b, c, t, f) -> (b, t, c * f).
        x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
        x, pos_emb = self.pos_enc(x, offset)
        # One ``:-2:2`` slice per kernel-3 / stride-2 conv keeps the mask
        # length equal to the subsampled time axis.
        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]


144
class Conv2dSubsampling6(Conv2dSubsampling):
    """Convolutional 2D subsampling (to 1/6 length)."""

    def __init__(self,
                 idim: int,
                 odim: int,
                 dropout_rate: float,
                 pos_enc_class: nn.Layer=PositionalEncoding):
        """Construct an Conv2dSubsampling6 object.

        Args:
            idim (int): Input dimension.
            odim (int): Output dimension.
            dropout_rate (float): Dropout rate. Not referenced by this
                class itself; kept for a signature parallel to the other
                subsampling layers.
            pos_enc_class (PositionalEncoding): positional encoding object,
                stored by the base class as ``self.pos_enc``.
        """
        super().__init__(pos_enc_class)
        # Unpadded kernel-3 / stride-2 conv followed by kernel-5 / stride-3.
        self.conv = nn.Sequential(
            nn.Conv2D(1, odim, 3, 2),
            nn.ReLU(),
            nn.Conv2D(odim, odim, 5, 3),
            nn.ReLU(), )
        # O = (I - F + Pstart + Pend) // S + 1
        # when Padding == 0, O = (I - F) // S + 1 = (I - F + S) // S
        #   conv1 (F=3, S=2): (I - 1) // 2
        #   conv2 (F=5, S=3): (I - 2) // 3
        self.linear = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        # 10 = (3 - 1) * 1 + (5 - 1) * 2
        self.subsampling_rate = 6
        self.right_context = 10

    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Subsample x.

        Args:
            x (paddle.Tensor): Input tensor (#batch, time, idim).
            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
            offset (int): position encoding offset.
        Returns:
            paddle.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = ((time - 1) // 2 - 2) // 3 (roughly time // 6).
            paddle.Tensor: positional encoding
            paddle.Tensor: Subsampled mask (#batch, 1, time'),
                with the same time' as above.
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = paddle.shape(x)
        # Fold channels into the feature axis: (b, c, t, f) -> (b, t, c * f).
        x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
        x, pos_emb = self.pos_enc(x, offset)
        # ``:-2:2`` follows the kernel-3 / stride-2 conv and ``:-4:3`` the
        # kernel-5 / stride-3 conv, so the mask length tracks the time axis.
        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]


197
class Conv2dSubsampling8(Conv2dSubsampling):
    """Convolutional 2D subsampling (to 1/8 length)."""

    def __init__(self,
                 idim: int,
                 odim: int,
                 dropout_rate: float,
                 pos_enc_class: nn.Layer=PositionalEncoding):
        """Construct an Conv2dSubsampling8 object.

        Args:
            idim (int): Input dimension.
            odim (int): Output dimension.
            dropout_rate (float): Dropout rate. Not referenced by this
                class itself; kept for a signature parallel to the other
                subsampling layers.
            pos_enc_class (PositionalEncoding): positional encoding object,
                stored by the base class as ``self.pos_enc``.
        """
        super().__init__(pos_enc_class)
        # Three unpadded kernel-3 / stride-2 convolutions over (time, feature).
        self.conv = nn.Sequential(
            nn.Conv2D(1, odim, 3, 2),
            nn.ReLU(),
            nn.Conv2D(odim, odim, 3, 2),
            nn.ReLU(),
            nn.Conv2D(odim, odim, 3, 2),
            nn.ReLU(), )
        # Each conv maps a length I to (I - 1) // 2; applied three times to idim.
        self.linear = nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2),
                                odim)
        self.subsampling_rate = 8
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
        self.right_context = 14

    def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Subsample x.

        Args:
            x (paddle.Tensor): Input tensor (#batch, time, idim).
            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
            offset (int): position encoding offset.
        Returns:
            paddle.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = (((time - 1) // 2 - 1) // 2 - 1) // 2
                (roughly time // 8).
            paddle.Tensor: positional encoding
            paddle.Tensor: Subsampled mask (#batch, 1, time'),
                with the same time' as above.
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        # The conv output shape must be read before reshaping; without this
        # line b, c, t and f are undefined and the reshape raises NameError.
        b, c, t, f = paddle.shape(x)
        # Fold channels into the feature axis: (b, c, t, f) -> (b, t, c * f).
        x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
        x, pos_emb = self.pos_enc(x, offset)
        # One ``:-2:2`` slice per kernel-3 / stride-2 conv keeps the mask
        # length equal to the subsampled time axis.
        return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]