import argparse
import os
import sys
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import paddle
import yaml
from paddle import nn
# Add every subdirectory of the parent directory to sys.path so that local
# sibling packages can be imported.
pypath = '..'
for dir_name in os.listdir(pypath):
    dir_path = os.path.join(pypath, dir_name)
    if os.path.isdir(dir_path):
        sys.path.append(dir_path)

from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.layer_norm import LayerNorm

from yacs.config import CfgNode


# MLM -> Masked Language Model
class mySequential(nn.Sequential):
    def forward(self, *inputs):
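        """Call each sub-layer in turn, unpacking tuple arguments.

        Unlike ``nn.Sequential``, tuple inputs/outputs are expanded into
        positional arguments, so multi-input layers (e.g. ``MaskInputLayer``
        called with ``(speech, masked_pos)``) and tuple-returning layers
        (e.g. positional encodings that return ``(x, pos_emb)``) can be
        chained together.
        """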
        for module in self._sub_layers.values():
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


class MaskInputLayer(nn.Layer):
    def __init__(self, out_features: int) -> None:
        super().__init__()
        self.mask_feature = paddle.create_parameter(
            shape=(1, 1, out_features),
            dtype=paddle.float32,
            default_initializer=paddle.nn.initializer.Assign(
                paddle.normal(shape=(1, 1, out_features))))

    def forward(self, input: paddle.Tensor,
                masked_pos: paddle.Tensor=None) -> paddle.Tensor:
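        """Replace masked frames of ``input`` with a learned mask embedding.

        Positions where ``masked_pos`` is True are zeroed out and filled with
        the broadcast ``self.mask_feature`` vector; unmasked positions keep
        their original values.
        """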
        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
        return masked_input


class MLMEncoder(nn.Layer):
    """Conformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, paddle.nn.Layer]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.

    """

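    # A minimal construction sketch (hypothetical hyper-parameters; in practice
    # they come from the ``encoder_conf`` section of the YAML config consumed by
    # ``build_model`` below):
    #
    #     encoder = MLMEncoder(
    #         idim=80,
    #         vocab_size=100,
    #         attention_dim=256,
    #         attention_heads=2,
    #         linear_units=1024,
    #         num_blocks=4,
    #         input_layer="sega_mlm",
    #         pos_enc_layer_type="legacy_rel_pos",
    #         selfattention_layer_type="legacy_rel_selfattn",
    #         use_cnn_module=True,
    #         cnn_module_kernel=7)
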
    def __init__(self,
                 idim: int,
                 vocab_size: int=0,
                 pre_speech_layer: int=0,
                 attention_dim: int=256,
                 attention_heads: int=4,
                 linear_units: int=2048,
                 num_blocks: int=6,
                 dropout_rate: float=0.1,
                 positional_dropout_rate: float=0.1,
                 attention_dropout_rate: float=0.0,
                 input_layer: str="conv2d",
                 normalize_before: bool=True,
                 concat_after: bool=False,
                 positionwise_layer_type: str="linear",
                 positionwise_conv_kernel_size: int=1,
                 macaron_style: bool=False,
                 pos_enc_layer_type: str="abs_pos",
                 pos_enc_class=None,
                 selfattention_layer_type: str="selfattn",
                 activation_type: str="swish",
                 use_cnn_module: bool=False,
                 zero_triu: bool=False,
                 cnn_module_kernel: int=31,
                 padding_idx: int=-1,
                 stochastic_depth_rate: float=0.0,
                 text_masking: bool=False):
        """Construct an Encoder object."""
        super().__init__()
        self._output_size = attention_dim
        self.text_masking = text_masking
        if self.text_masking:
            self.text_masking_layer = MaskInputLayer(attention_dim)
        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            pos_enc_class = LegacyRelPositionalEncoding
            assert selfattention_layer_type == "legacy_rel_selfattn"
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = nn.Sequential(
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.Dropout(dropout_rate),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate), )
            self.conv_subsampling_factor = 4
        elif input_layer == "embed":
            self.embed = nn.Sequential(
                nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "mlm":
            self.segment_emb = None
            self.speech_embed = mySequential(
                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate))
            self.text_embed = nn.Sequential(
                nn.Embedding(
                    vocab_size, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "sega_mlm":
            self.segment_emb = nn.Embedding(
                500, attention_dim, padding_idx=padding_idx)
            self.speech_embed = mySequential(
                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate))
            self.text_embed = nn.Sequential(
                nn.Embedding(
                    vocab_size, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif isinstance(input_layer, nn.Layer):
            self.embed = nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer is None:
            self.embed = nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate))
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        # self-attention module definition
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, zero_triu, )
        else:
            raise ValueError("unknown encoder_attn_layer: " +
                             selfattention_layer_type)

        # feed-forward module definition
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (attention_dim, linear_units,
                                       dropout_rate, activation, )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (attention_dim, linear_units,
                                       positionwise_conv_kernel_size,
                                       dropout_rate, )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (attention_dim, linear_units,
                                       positionwise_conv_kernel_size,
                                       dropout_rate, )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / num_blocks, ), )
        self.pre_speech_layer = pre_speech_layer
        self.pre_speech_encoders = repeat(
            self.pre_speech_layer,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor=None,
                text_mask: paddle.Tensor=None,
                speech_seg_pos: paddle.Tensor=None,
                text_seg_pos: paddle.Tensor=None):
        """Encode input sequence.

        """
        if masked_pos is not None:
            speech = self.speech_embed(speech, masked_pos)
        else:
            speech = self.speech_embed(speech)
        if text is not None:
            text = self.text_embed(text)
        if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
            speech_seg_emb = self.segment_emb(speech_seg_pos)
            text_seg_emb = self.segment_emb(text_seg_pos)
            text = (text[0] + text_seg_emb, text[1])
            speech = (speech[0] + speech_seg_emb, speech[1])
        if self.pre_speech_encoders:
            speech, _ = self.pre_speech_encoders(speech, speech_mask)

        if text is not None:
            xs = paddle.concat([speech[0], text[0]], axis=1)
            xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
            masks = paddle.concat([speech_mask, text_mask], axis=-1)
        else:
            xs = speech[0]
            xs_pos_emb = speech[1]
            masks = speech_mask

        xs, masks = self.encoders((xs, xs_pos_emb), masks)

        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks


class MLMDecoder(MLMEncoder):
    def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
        """Encode input sequence.

        Args:
            xs (paddle.Tensor): Input tensor (#batch, time, idim).
            masks (paddle.Tensor): Mask tensor (#batch, time).

        Returns:
            paddle.Tensor: Output tensor (#batch, time, attention_dim).
            paddle.Tensor: Mask tensor (#batch, time).

        """
        xs = self.embed(xs)
        xs, masks = self.encoders(xs, masks)

        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks


# encoder and decoder are nn.Layer instances, not str
class MLM(nn.Layer):
    def __init__(self,
                 token_list: Union[Tuple[str, ...], List[str]],
                 odim: int,
                 encoder: nn.Layer,
                 decoder: Optional[nn.Layer],
                 postnet_layers: int=0,
                 postnet_chans: int=0,
                 postnet_filts: int=0,
                 text_masking: bool=False):

        super().__init__()
        self.odim = odim
        self.token_list = token_list.copy()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = encoder.text_embed[0]._num_embeddings

        if self.decoder is None or not (hasattr(self.decoder,
                                                'output_layer') and
                                        self.decoder.output_layer is not None):
            self.sfc = nn.Linear(self.encoder._output_size, odim)
        else:
            self.sfc = None
        if text_masking:
            self.text_sfc = nn.Linear(
                self.encoder.text_embed[0]._embedding_dim,
                self.vocab_size,
                weight_attr=self.encoder.text_embed[0]._weight_attr)
        else:
            self.text_sfc = None

        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=self.encoder._output_size,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=True,
            dropout_rate=0.5, ))

    def inference(
            self,
            speech: paddle.Tensor,
            text: paddle.Tensor,
            masked_pos: paddle.Tensor,
            speech_mask: paddle.Tensor,
            text_mask: paddle.Tensor,
            speech_seg_pos: paddle.Tensor,
            text_seg_pos: paddle.Tensor,
            span_bdy: List[int],
            use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
        '''
        Args:
            speech (paddle.Tensor): input speech (1, Tmax, D).
            text (paddle.Tensor): input text (1, Tmax2).
            masked_pos (paddle.Tensor): masked position of input speech (1, Tmax)
            speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax).
            text_mask (paddle.Tensor): mask of text (1, 1, Tmax2).
            speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax).
            text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2).
            span_bdy (List[int]): masked mel boundary of input speech (2,)
            use_teacher_forcing (bool): whether to use teacher forcing
        Returns:
            List[Tensor]:
                eg:
                [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
        '''

        z_cache = None
        if use_teacher_forcing:
            before_outs, zs, *_ = self.forward(
                speech=speech,
                text=text,
                masked_pos=masked_pos,
                speech_mask=speech_mask,
                text_mask=text_mask,
                speech_seg_pos=speech_seg_pos,
                text_seg_pos=text_seg_pos)
            if zs is None:
                zs = before_outs

            speech = speech.squeeze(0)
            outs = [speech[:span_bdy[0]]]
            outs += [zs[0][span_bdy[0]:span_bdy[1]]]
            outs += [speech[span_bdy[1]:]]
            return outs
        return None


class MLMEncAsDecoder(MLM):
    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor,
                text_mask: paddle.Tensor,
                speech_seg_pos: paddle.Tensor,
                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, h_masks = self.encoder(
            speech=speech,
            text=text,
            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_seg_pos=speech_seg_pos,
            text_seg_pos=text_seg_pos)
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
        else:
            before_outs = speech_hidden_states
        if self.postnet is not None:
            after_outs = before_outs + paddle.transpose(
                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, None


class MLMDualMaksing(MLM):
    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor,
                text_mask: paddle.Tensor,
                speech_seg_pos: paddle.Tensor,
                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, h_masks = self.encoder(
            speech=speech,
            text=text,
            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_seg_pos=speech_seg_pos,
            text_seg_pos=text_seg_pos)
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        text_outs = None
        if self.text_sfc is not None:
            text_hidden_states = zs[:, paddle.shape(speech)[1]:, :]
            text_outs = paddle.reshape(
                self.text_sfc(text_hidden_states),
                (paddle.shape(text_hidden_states)[0], -1, self.vocab_size))
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
        else:
            before_outs = speech_hidden_states
        if self.postnet is not None:
            after_outs = before_outs + paddle.transpose(
                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, text_outs


def build_model_from_file(config_file, model_file):

    state_dict = paddle.load(model_file)
    model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
        else MLMEncAsDecoder

    # build the model
    with open(config_file) as f:
        conf = CfgNode(yaml.safe_load(f))
    model = build_model(conf, model_class)
    model.set_state_dict(state_dict)
    return model, conf


# select encoder and decoder here
def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]

        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")

    vocab_size = len(token_list)
    odim = 80

    pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding

    if "conformer" == args.encoder:
        conformer_self_attn_layer_type = args.encoder_conf[
            'selfattention_layer_type']
        conformer_pos_enc_layer_type = args.encoder_conf['pos_enc_layer_type']
        conformer_rel_pos_type = "legacy"
        if conformer_rel_pos_type == "legacy":
            if conformer_pos_enc_layer_type == "rel_pos":
                conformer_pos_enc_layer_type = "legacy_rel_pos"
            if conformer_self_attn_layer_type == "rel_selfattn":
                conformer_self_attn_layer_type = "legacy_rel_selfattn"
        elif conformer_rel_pos_type == "latest":
            assert conformer_pos_enc_layer_type != "legacy_rel_pos"
            assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
        else:
            raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}")
        args.encoder_conf[
            'selfattention_layer_type'] = conformer_self_attn_layer_type
        args.encoder_conf['pos_enc_layer_type'] = conformer_pos_enc_layer_type
        if "conformer" == args.decoder:
            args.decoder_conf[
                'selfattention_layer_type'] = conformer_self_attn_layer_type
            args.decoder_conf[
                'pos_enc_layer_type'] = conformer_pos_enc_layer_type

    # Encoder
    encoder_class = MLMEncoder

    if 'text_masking' in args.model_conf.keys() and args.model_conf[
            'text_masking']:
        args.encoder_conf['text_masking'] = True
    else:
        args.encoder_conf['text_masking'] = False

    encoder = encoder_class(
        args.input_size,
        vocab_size=vocab_size,
        pos_enc_class=pos_enc_class,
        **args.encoder_conf)

    # Decoder
    if args.decoder != 'no_decoder':
        decoder_class = MLMDecoder
        decoder = decoder_class(
            idim=0,
            input_layer=None,
            **args.decoder_conf, )
    else:
        decoder = None

    # Build model
    model = model_class(
        odim=odim,
        encoder=encoder,
        decoder=decoder,
        token_list=token_list,
        **args.model_conf, )

    # Initialize
    if args.init is not None:
        initialize(model, args.init)

    return model
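

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training/inference scripts):
# load a pretrained MLM model from a YAML config plus a .pdparams checkpoint.
# The default paths below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load an MLM model from file.")
    # hypothetical paths; point these at a real config/checkpoint pair
    parser.add_argument("--config", default="conf/default.yaml")
    parser.add_argument("--checkpoint", default="exp/checkpoints/model.pdparams")
    cli_args = parser.parse_args()

    mlm_model, mlm_conf = build_model_from_file(
        config_file=cli_args.config, model_file=cli_args.checkpoint)
    mlm_model.eval()
    print(mlm_model)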