diff --git a/examples/aishell/asr1/conf/chunk_conformer.yaml b/examples/aishell/asr1/conf/chunk_conformer.yaml index 68e852ba77770cd0de9b4c33e93ee3ed777fe674..1ad77f97e983e840b7941e12fd6ccdc5330b7ba5 100644 --- a/examples/aishell/asr1/conf/chunk_conformer.yaml +++ b/examples/aishell/asr1/conf/chunk_conformer.yaml @@ -70,7 +70,7 @@ batch_bins: 0 batch_frames_in: 0 batch_frames_out: 0 batch_frames_inout: 0 -num_workers: 0 +num_workers: 2 subsampling_factor: 1 num_encs: 1 @@ -80,6 +80,7 @@ num_encs: 1 n_epoch: 240 accum_grad: 2 global_grad_clip: 5.0 +dist_sampler: True optim: adam optim_conf: lr: 0.002 diff --git a/examples/aishell/asr1/conf/conformer.yaml b/examples/aishell/asr1/conf/conformer.yaml index 775a4527d49925e6f0aaf73a2d9b6f7bc37657da..a150a04d55671edf25e5871b1695bcad14710367 100644 --- a/examples/aishell/asr1/conf/conformer.yaml +++ b/examples/aishell/asr1/conf/conformer.yaml @@ -37,6 +37,7 @@ model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option length_normalized_loss: false + init_type: 'kaiming_uniform' ########################################### # Data # @@ -75,6 +76,7 @@ num_encs: 1 n_epoch: 240 accum_grad: 2 global_grad_clip: 5.0 +dist_sampler: True optim: adam optim_conf: lr: 0.002 diff --git a/examples/aishell/asr1/conf/transformer.yaml b/examples/aishell/asr1/conf/transformer.yaml index 9d2946537b44ed55f59dbebc09de2ef7571324bf..9e08ea0ec79168fb969cb3b13a54be60e94157af 100644 --- a/examples/aishell/asr1/conf/transformer.yaml +++ b/examples/aishell/asr1/conf/transformer.yaml @@ -61,16 +61,17 @@ batch_frames_in: 0 batch_frames_out: 0 batch_frames_inout: 0 preprocess_config: conf/preprocess.yaml -num_workers: 0 +num_workers: 2 subsampling_factor: 1 num_encs: 1 ########################################### # Training # ########################################### -n_epoch: 240 +n_epoch: 30 accum_grad: 2 global_grad_clip: 5.0 +dist_sampler: False optim: adam optim_conf: lr: 0.002 diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index d7bee6d7fe753554916d6b32e38756004507a49f..efcc9629fdbf63981cfdc4cc5b91693e5f3a85ee 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -239,7 +239,7 @@ class U2Trainer(Trainer): n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1, - dist_sampler=False, + dist_sampler=config.get('dist_sampler', False), shortest_first=False) self.valid_loader = BatchDataLoader( @@ -260,7 +260,7 @@ class U2Trainer(Trainer): n_iter_processes=config.num_workers, subsampling_factor=1, num_encs=1, - dist_sampler=False, + dist_sampler=config.get('dist_sampler', False), shortest_first=False) logger.info("Setup train/valid Dataloader!") else: diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 910798127ee5c8c7c00893b603ec3ef95dc5be26..51388586f97edc74ac8c91f264d4146b15a5125c 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -36,6 +36,7 @@ from paddlespeech.s2t.modules.ctc import CTCDecoderBase from paddlespeech.s2t.modules.decoder import TransformerDecoder from paddlespeech.s2t.modules.encoder import ConformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder +from paddlespeech.s2t.modules.initializer import DefaultInitializerContext from paddlespeech.s2t.modules.loss import LabelSmoothingLoss from paddlespeech.s2t.modules.mask import make_pad_mask from paddlespeech.s2t.modules.mask import mask_finished_preds @@ -72,6 +73,7 @@ class U2BaseModel(ASRInterface, nn.Layer): assert 0.0 <= 
ctc_weight <= 1.0, ctc_weight nn.Layer.__init__(self) + # note that eos is the same as sos (equivalent ID) self.sos = vocab_size - 1 self.eos = vocab_size - 1 @@ -780,9 +782,12 @@ class U2DecodeModel(U2BaseModel): class U2Model(U2DecodeModel): def __init__(self, configs: dict): - vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) - model_conf = configs.get('model_conf', dict()) + init_type = model_conf.get("init_type", None) + with DefaultInitializerContext(init_type): + vocab_size, encoder, decoder, ctc = U2Model._init_from_config( + configs) + super().__init__( vocab_size=vocab_size, encoder=encoder, diff --git a/paddlespeech/s2t/modules/activation.py b/paddlespeech/s2t/modules/activation.py index 4081f7f81a5ca9a0b8594ff01cff23ef6d3eac94..2f387b0d99b68ed5d37cb05a13a030ad49aaa381 100644 --- a/paddlespeech/s2t/modules/activation.py +++ b/paddlespeech/s2t/modules/activation.py @@ -17,6 +17,8 @@ import paddle from paddle import nn from paddle.nn import functional as F +from paddlespeech.s2t.modules.align import Conv2D +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -51,7 +53,7 @@ class LinearGLUBlock(nn.Layer): idim (int): input and output dimension """ super().__init__() - self.fc = nn.Linear(idim, idim * 2) + self.fc = Linear(idim, idim * 2) def forward(self, xs): return glu(self.fc(xs), dim=-1) @@ -75,7 +77,7 @@ class ConvGLUBlock(nn.Layer): self.conv_residual = None if in_ch != out_ch: self.conv_residual = nn.utils.weight_norm( - nn.Conv2D( + Conv2D( in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)), name='weight', dim=0) @@ -86,7 +88,7 @@ class ConvGLUBlock(nn.Layer): layers = OrderedDict() if bottlececk_dim == 0: layers['conv'] = nn.utils.weight_norm( - nn.Conv2D( + Conv2D( in_channels=in_ch, out_channels=out_ch * 2, kernel_size=(kernel_size, 1)), @@ -106,7 +108,7 @@ class ConvGLUBlock(nn.Layer): dim=0) layers['dropout_in'] = nn.Dropout(p=dropout) layers['conv_bottleneck'] = nn.utils.weight_norm( - nn.Conv2D( + Conv2D( in_channels=bottlececk_dim, out_channels=bottlececk_dim, kernel_size=(kernel_size, 1)), @@ -115,7 +117,7 @@ class ConvGLUBlock(nn.Layer): layers['dropout'] = nn.Dropout(p=dropout) layers['glu'] = GLU() layers['conv_out'] = nn.utils.weight_norm( - nn.Conv2D( + Conv2D( in_channels=bottlececk_dim, out_channels=out_ch * 2, kernel_size=(1, 1)), diff --git a/paddlespeech/s2t/modules/align.py b/paddlespeech/s2t/modules/align.py new file mode 100644 index 0000000000000000000000000000000000000000..f889167936115ccc7267037d9046765f83b403bd --- /dev/null +++ b/paddlespeech/s2t/modules/align.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +from paddle import nn + +from paddlespeech.s2t.modules.initializer import KaimingUniform +""" + To align the initializer between paddle and torch, + the API below are set defalut initializer with priority higger than global initializer. 
+""" +global_init_type = None + + +class LayerNorm(nn.LayerNorm): + def __init__(self, + normalized_shape, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + name=None): + if weight_attr is None: + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)) + if bias_attr is None: + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0)) + super(LayerNorm, self).__init__(normalized_shape, epsilon, weight_attr, + bias_attr, name) + + +class BatchNorm1D(nn.BatchNorm1D): + def __init__(self, + num_features, + momentum=0.9, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_format='NCL', + name=None): + if weight_attr is None: + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(1.0)) + if bias_attr is None: + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(0.0)) + super(BatchNorm1D, + self).__init__(num_features, momentum, epsilon, weight_attr, + bias_attr, data_format, name) + + +class Embedding(nn.Embedding): + def __init__(self, + num_embeddings, + embedding_dim, + padding_idx=None, + sparse=False, + weight_attr=None, + name=None): + if weight_attr is None: + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal()) + super(Embedding, self).__init__(num_embeddings, embedding_dim, + padding_idx, sparse, weight_attr, name) + + +class Linear(nn.Linear): + def __init__(self, + in_features, + out_features, + weight_attr=None, + bias_attr=None, + name=None): + if weight_attr is None: + if global_init_type == "kaiming_uniform": + weight_attr = paddle.ParamAttr(initializer=KaimingUniform()) + if bias_attr is None: + if global_init_type == "kaiming_uniform": + bias_attr = paddle.ParamAttr(initializer=KaimingUniform()) + super(Linear, self).__init__(in_features, out_features, weight_attr, + bias_attr, name) + + +class Conv1D(nn.Conv1D): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode='zeros', + weight_attr=None, + bias_attr=None, + data_format='NCL'): + if weight_attr is None: + if global_init_type == "kaiming_uniform": + print("set kaiming_uniform") + weight_attr = paddle.ParamAttr(initializer=KaimingUniform()) + if bias_attr is None: + if global_init_type == "kaiming_uniform": + bias_attr = paddle.ParamAttr(initializer=KaimingUniform()) + super(Conv1D, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, padding_mode, weight_attr, bias_attr, data_format) + + +class Conv2D(nn.Conv2D): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode='zeros', + weight_attr=None, + bias_attr=None, + data_format='NCHW'): + if weight_attr is None: + if global_init_type == "kaiming_uniform": + weight_attr = paddle.ParamAttr(initializer=KaimingUniform()) + if bias_attr is None: + if global_init_type == "kaiming_uniform": + bias_attr = paddle.ParamAttr(initializer=KaimingUniform()) + super(Conv2D, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, padding_mode, weight_attr, bias_attr, data_format) diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index 3d5f8cd1d3aaff3841a8b519bb7b3af178c700ef..438efd2a14151904cb75ff6c72f7be01663bff09 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -22,6 +22,7 @@ import paddle from paddle import nn from paddle.nn import initializer as I +from 
paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -48,10 +49,10 @@ class MultiHeadedAttention(nn.Layer): # We assume d_v always equals d_k self.d_k = n_feat // n_head self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) + self.linear_q = Linear(n_feat, n_feat) + self.linear_k = Linear(n_feat, n_feat) + self.linear_v = Linear(n_feat, n_feat) + self.linear_out = Linear(n_feat, n_feat) self.dropout = nn.Dropout(p=dropout_rate) def forward_qkv(self, @@ -95,7 +96,7 @@ class MultiHeadedAttention(nn.Layer): mask (paddle.Tensor): Mask, size (#batch, 1, time2) or (#batch, time1, time2). Returns: - paddle.Tensor: Transformed value weighted + paddle.Tensor: Transformed value weighted by the attention score, (#batch, time1, d_model). """ n_batch = value.shape[0] @@ -150,7 +151,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): """ super().__init__(n_head, n_feat, dropout_rate) # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False) + self.linear_pos = Linear(n_feat, n_feat, bias_attr=False) # these two learnable bias are used in matrix c and matrix d # as described in https://arxiv.org/abs/1901.02860 Section 3.3 #self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index 7ec92554eec73b8889335b3a16fd1a34692bb021..89e6526885a2679b8ab09a4e4e4423a15e51ac08 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -21,6 +21,9 @@ import paddle from paddle import nn from typeguard import check_argument_types +from paddlespeech.s2t.modules.align import BatchNorm1D +from paddlespeech.s2t.modules.align import Conv1D +from paddlespeech.s2t.modules.align import LayerNorm from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -49,7 +52,7 @@ class ConvolutionModule(nn.Layer): """ assert check_argument_types() super().__init__() - self.pointwise_conv1 = nn.Conv1D( + self.pointwise_conv1 = Conv1D( channels, 2 * channels, kernel_size=1, @@ -60,8 +63,8 @@ class ConvolutionModule(nn.Layer): ) # self.lorder is used to distinguish if it's a causal convolution, - # if self.lorder > 0: - # it's a causal convolution, the input will be padded with + # if self.lorder > 0: + # it's a causal convolution, the input will be padded with # `self.lorder` frames on the left in forward (causal conv impl). 
# else: it's a symmetrical convolution if causal: @@ -73,7 +76,7 @@ class ConvolutionModule(nn.Layer): padding = (kernel_size - 1) // 2 self.lorder = 0 - self.depthwise_conv = nn.Conv1D( + self.depthwise_conv = Conv1D( channels, channels, kernel_size, @@ -87,12 +90,12 @@ class ConvolutionModule(nn.Layer): assert norm in ['batch_norm', 'layer_norm'] if norm == "batch_norm": self.use_layer_norm = False - self.norm = nn.BatchNorm1D(channels) + self.norm = BatchNorm1D(channels) else: self.use_layer_norm = True - self.norm = nn.LayerNorm(channels) + self.norm = LayerNorm(channels) - self.pointwise_conv2 = nn.Conv1D( + self.pointwise_conv2 = Conv1D( channels, channels, kernel_size=1, diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 2094182af1a6d31068288d865654bace577b5975..33ad472defba0a86bc945582f386acb406e4c35e 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -18,6 +18,7 @@ from paddle import nn from paddle.nn import functional as F from typeguard import check_argument_types +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.loss import CTCLoss from paddlespeech.s2t.utils import ctc_utils from paddlespeech.s2t.utils.log import Log @@ -69,7 +70,7 @@ class CTCDecoderBase(nn.Layer): self.blank_id = blank_id self.odim = odim self.dropout = nn.Dropout(dropout_rate) - self.ctc_lo = nn.Linear(enc_n_units, self.odim) + self.ctc_lo = Linear(enc_n_units, self.odim) reduction_type = "sum" if reduction else "none" self.criterion = CTCLoss( blank=self.blank_id, diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index 6b4d959123b19cc23cd42bdcf68491ac6e5f61de..3a851ec62c35f633ce07fd0b4380d92b31d67b3b 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -24,6 +24,9 @@ from paddle import nn from typeguard import check_argument_types from paddlespeech.s2t.decoders.scorers.scorer_interface import BatchScorerInterface +from paddlespeech.s2t.modules.align import Embedding +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.attention import MultiHeadedAttention from paddlespeech.s2t.modules.decoder_layer import DecoderLayer from paddlespeech.s2t.modules.embedding import PositionalEncoding @@ -76,21 +79,22 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): concat_after: bool=False, ): assert check_argument_types() + nn.Layer.__init__(self) self.selfattention_layer_type = 'selfattn' attention_dim = encoder_output_size if input_layer == "embed": self.embed = nn.Sequential( - nn.Embedding(vocab_size, attention_dim), + Embedding(vocab_size, attention_dim), PositionalEncoding(attention_dim, positional_dropout_rate), ) else: raise ValueError(f"only 'embed' is supported: {input_layer}") self.normalize_before = normalize_before - self.after_norm = nn.LayerNorm(attention_dim, epsilon=1e-12) + self.after_norm = LayerNorm(attention_dim, epsilon=1e-12) self.use_output_layer = use_output_layer - self.output_layer = nn.Linear(attention_dim, vocab_size) + self.output_layer = Linear(attention_dim, vocab_size) self.decoders = nn.LayerList([ DecoderLayer( diff --git a/paddlespeech/s2t/modules/decoder_layer.py b/paddlespeech/s2t/modules/decoder_layer.py index 520b18dea17928b6fe95bbda804bd89ef28aa904..b7f8694c12623ce82eb6849bcd9438483f513502 100644 --- a/paddlespeech/s2t/modules/decoder_layer.py +++ b/paddlespeech/s2t/modules/decoder_layer.py @@ -20,6 +20,8 @@ from 
typing import Tuple import paddle from paddle import nn +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -62,14 +64,14 @@ class DecoderLayer(nn.Layer): self.self_attn = self_attn self.src_attn = src_attn self.feed_forward = feed_forward - self.norm1 = nn.LayerNorm(size, epsilon=1e-12) - self.norm2 = nn.LayerNorm(size, epsilon=1e-12) - self.norm3 = nn.LayerNorm(size, epsilon=1e-12) + self.norm1 = LayerNorm(size, epsilon=1e-12) + self.norm2 = LayerNorm(size, epsilon=1e-12) + self.norm3 = LayerNorm(size, epsilon=1e-12) self.dropout = nn.Dropout(dropout_rate) self.normalize_before = normalize_before self.concat_after = concat_after - self.concat_linear1 = nn.Linear(size + size, size) - self.concat_linear2 = nn.Linear(size + size, size) + self.concat_linear1 = Linear(size + size, size) + self.concat_linear2 = Linear(size + size, size) def forward( self, diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 5c8ba0810d00db66a3c96238cf5d243802eb9d7b..c843c0e207054b20a5d3850334198ef6bcb6888c 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -23,6 +23,7 @@ from paddle import nn from typeguard import check_argument_types from paddlespeech.s2t.modules.activation import get_activation +from paddlespeech.s2t.modules.align import LayerNorm from paddlespeech.s2t.modules.attention import MultiHeadedAttention from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule @@ -129,7 +130,7 @@ class BaseEncoder(nn.Layer): d_model=output_size, dropout_rate=positional_dropout_rate), ) self.normalize_before = normalize_before - self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12) + self.after_norm = LayerNorm(output_size, epsilon=1e-12) self.static_chunk_size = static_chunk_size self.use_dynamic_chunk = use_dynamic_chunk self.use_dynamic_left_chunk = use_dynamic_left_chunk @@ -457,6 +458,7 @@ class ConformerEncoder(BaseEncoder): cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm'] """ assert check_argument_types() + super().__init__(input_size, output_size, attention_heads, linear_units, num_blocks, dropout_rate, positional_dropout_rate, attention_dropout_rate, input_layer, diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index d39c0695a044cd9cdc5969b547be911565015672..e80a298d621ac87db8ad9f76e48041f05ec18f64 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -20,6 +20,8 @@ from typing import Tuple import paddle from paddle import nn +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -39,7 +41,7 @@ class TransformerEncoderLayer(nn.Layer): normalize_before: bool=True, concat_after: bool=False, ): """Construct an EncoderLayer object. - + Args: size (int): Input dimension. self_attn (nn.Layer): Self-attention module instance. 
@@ -59,15 +61,15 @@ class TransformerEncoderLayer(nn.Layer): super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward - self.norm1 = nn.LayerNorm(size, epsilon=1e-12) - self.norm2 = nn.LayerNorm(size, epsilon=1e-12) + self.norm1 = LayerNorm(size, epsilon=1e-12) + self.norm2 = LayerNorm(size, epsilon=1e-12) self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before self.concat_after = concat_after # concat_linear may be not used in forward fuction, # but will be saved in the *.pt - self.concat_linear = nn.Linear(size + size, size) + self.concat_linear = Linear(size + size, size) def forward( self, @@ -147,7 +149,7 @@ class ConformerEncoderLayer(nn.Layer): normalize_before: bool=True, concat_after: bool=False, ): """Construct an EncoderLayer object. - + Args: size (int): Input dimension. self_attn (nn.Layer): Self-attention module instance. @@ -174,23 +176,23 @@ class ConformerEncoderLayer(nn.Layer): self.feed_forward = feed_forward self.feed_forward_macaron = feed_forward_macaron self.conv_module = conv_module - self.norm_ff = nn.LayerNorm(size, epsilon=1e-12) # for the FNN module - self.norm_mha = nn.LayerNorm(size, epsilon=1e-12) # for the MHA module + self.norm_ff = LayerNorm(size, epsilon=1e-12) # for the FNN module + self.norm_mha = LayerNorm(size, epsilon=1e-12) # for the MHA module if feed_forward_macaron is not None: - self.norm_ff_macaron = nn.LayerNorm(size, epsilon=1e-12) + self.norm_ff_macaron = LayerNorm(size, epsilon=1e-12) self.ff_scale = 0.5 else: self.ff_scale = 1.0 if self.conv_module is not None: - self.norm_conv = nn.LayerNorm( + self.norm_conv = LayerNorm( size, epsilon=1e-12) # for the CNN module - self.norm_final = nn.LayerNorm( + self.norm_final = LayerNorm( size, epsilon=1e-12) # for the final output of the block self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before self.concat_after = concat_after - self.concat_linear = nn.Linear(size + size, size) + self.concat_linear = Linear(size + size, size) def forward( self, diff --git a/paddlespeech/s2t/modules/initializer.py b/paddlespeech/s2t/modules/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..30a04e44fb2965d03be8c6346ef16448ed257bbc --- /dev/null +++ b/paddlespeech/s2t/modules/initializer.py @@ -0,0 +1,172 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from paddle.fluid import framework +from paddle.fluid import unique_name +from paddle.fluid.core import VarDesc +from paddle.fluid.initializer import MSRAInitializer + +__all__ = ['KaimingUniform'] + + +class KaimingUniform(MSRAInitializer): + r"""Implements the Kaiming Uniform initializer + + This class implements the weight initialization from the paper + `Delving Deep into Rectifiers: Surpassing Human-Level Performance on + ImageNet Classification `_ + by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. 
This is a + robust initialization method that particularly considers the rectifier + nonlinearities. + + In case of Uniform distribution, the range is [-x, x], where + + .. math:: + + x = \sqrt{\frac{1.0}{fan\_in}} + + In case of Normal distribution, the mean is 0 and the standard deviation + is + + .. math:: + + \sqrt{\\frac{2.0}{fan\_in}} + + Args: + fan_in (float32|None): fan_in for Kaiming uniform Initializer. If None, it is\ + inferred from the variable. default is None. + + Note: + It is recommended to set fan_in to None for most cases. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + + linear = nn.Linear(2, + 4, + weight_attr=nn.initializer.KaimingUniform()) + data = paddle.rand([30, 10, 2], dtype='float32') + res = linear(data) + + """ + + def __init__(self, fan_in=None): + super(KaimingUniform, self).__init__( + uniform=True, fan_in=fan_in, seed=0) + + def __call__(self, var, block=None): + """Initialize the input tensor with MSRA initialization. + + Args: + var(Tensor): Tensor that needs to be initialized. + block(Block, optional): The block in which initialization ops + should be added. Used in static graph only, default None. + + Returns: + The initialization op + """ + block = self._check_block(block) + + assert isinstance(var, framework.Variable) + assert isinstance(block, framework.Block) + f_in, f_out = self._compute_fans(var) + + # If fan_in is passed, use it + fan_in = f_in if self._fan_in is None else self._fan_in + + if self._seed == 0: + self._seed = block.program.random_seed + + # to be compatible of fp16 initalizers + if var.dtype == VarDesc.VarType.FP16 or ( + var.dtype == VarDesc.VarType.BF16 and not self._uniform): + out_dtype = VarDesc.VarType.FP32 + out_var = block.create_var( + name=unique_name.generate( + ".".join(['masra_init', var.name, 'tmp'])), + shape=var.shape, + dtype=out_dtype, + type=VarDesc.VarType.LOD_TENSOR, + persistable=False) + else: + out_dtype = var.dtype + out_var = var + + if self._uniform: + limit = np.sqrt(1.0 / float(fan_in)) + op = block.append_op( + type="uniform_random", + inputs={}, + outputs={"Out": out_var}, + attrs={ + "shape": out_var.shape, + "dtype": int(out_dtype), + "min": -limit, + "max": limit, + "seed": self._seed + }, + stop_gradient=True) + + else: + std = np.sqrt(2.0 / float(fan_in)) + op = block.append_op( + type="gaussian_random", + outputs={"Out": out_var}, + attrs={ + "shape": out_var.shape, + "dtype": int(out_dtype), + "mean": 0.0, + "std": std, + "seed": self._seed + }, + stop_gradient=True) + + if var.dtype == VarDesc.VarType.FP16 or ( + var.dtype == VarDesc.VarType.BF16 and not self._uniform): + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) + + if not framework.in_dygraph_mode(): + var.op = op + return op + + +class DefaultInitializerContext(object): + """ + egs: + with DefaultInitializerContext("kaiming_uniform"): + code for setup_model + """ + + def __init__(self, init_type=None): + self.init_type = init_type + + def __enter__(self): + if self.init_type is None: + return + else: + from paddlespeech.s2t.modules import align + align.global_init_type = self.init_type + return + + def __exit__(self, exc_type, exc_val, exc_tb): + from paddlespeech.s2t.modules import align + align.global_init_type = None diff --git a/paddlespeech/s2t/modules/positionwise_feed_forward.py b/paddlespeech/s2t/modules/positionwise_feed_forward.py index 
e2619cd49dc15ef7d9ddb1fbbb991f3fe3eb1c35..c2725dc5cc4aac28d04e44333e185082d7300d44 100644 --- a/paddlespeech/s2t/modules/positionwise_feed_forward.py +++ b/paddlespeech/s2t/modules/positionwise_feed_forward.py @@ -17,6 +17,7 @@ import paddle from paddle import nn +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -44,10 +45,10 @@ class PositionwiseFeedForward(nn.Layer): activation (paddle.nn.Layer): Activation function """ super().__init__() - self.w_1 = nn.Linear(idim, hidden_units) + self.w_1 = Linear(idim, hidden_units) self.activation = activation self.dropout = nn.Dropout(dropout_rate) - self.w_2 = nn.Linear(hidden_units, idim) + self.w_2 = Linear(hidden_units, idim) def forward(self, xs: paddle.Tensor) -> paddle.Tensor: """Forward function. diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py index 99a8300f246149e924fe741f53934259d404e4e8..88451ddd77f6f89f8597238ddb1236acaa1945d7 100644 --- a/paddlespeech/s2t/modules/subsampling.py +++ b/paddlespeech/s2t/modules/subsampling.py @@ -19,6 +19,9 @@ from typing import Tuple import paddle from paddle import nn +from paddlespeech.s2t.modules.align import Conv2D +from paddlespeech.s2t.modules.align import LayerNorm +from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.embedding import PositionalEncoding from paddlespeech.s2t.utils.log import Log @@ -60,8 +63,8 @@ class LinearNoSubsampling(BaseSubsampling): """ super().__init__(pos_enc_class) self.out = nn.Sequential( - nn.Linear(idim, odim), - nn.LayerNorm(odim, epsilon=1e-12), + Linear(idim, odim), + LayerNorm(odim, epsilon=1e-12), nn.Dropout(dropout_rate), nn.ReLU(), ) self.right_context = 0 @@ -108,12 +111,12 @@ class Conv2dSubsampling4(Conv2dSubsampling): """ super().__init__(pos_enc_class) self.conv = nn.Sequential( - nn.Conv2D(1, odim, 3, 2), + Conv2D(1, odim, 3, 2), nn.ReLU(), - nn.Conv2D(odim, odim, 3, 2), + Conv2D(odim, odim, 3, 2), nn.ReLU(), ) self.out = nn.Sequential( - nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) self.subsampling_rate = 4 # The right context for every conv layer is computed by: # (kernel_size - 1) * frame_rate_of_this_layer @@ -160,13 +163,13 @@ class Conv2dSubsampling6(Conv2dSubsampling): """ super().__init__(pos_enc_class) self.conv = nn.Sequential( - nn.Conv2D(1, odim, 3, 2), + Conv2D(1, odim, 3, 2), nn.ReLU(), - nn.Conv2D(odim, odim, 5, 3), + Conv2D(odim, odim, 5, 3), nn.ReLU(), ) # O = (I - F + Pstart + Pend) // S + 1 # when Padding == 0, O = (I - F - S) // S - self.linear = nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim) + self.linear = Linear(odim * (((idim - 1) // 2 - 2) // 3), odim) # The right context for every conv layer is computed by: # (kernel_size - 1) * frame_rate_of_this_layer # 10 = (3 - 1) * 1 + (5 - 1) * 2 @@ -212,14 +215,14 @@ class Conv2dSubsampling8(Conv2dSubsampling): """ super().__init__(pos_enc_class) self.conv = nn.Sequential( - nn.Conv2D(1, odim, 3, 2), + Conv2D(1, odim, 3, 2), nn.ReLU(), - nn.Conv2D(odim, odim, 3, 2), + Conv2D(odim, odim, 3, 2), nn.ReLU(), - nn.Conv2D(odim, odim, 3, 2), + Conv2D(odim, odim, 3, 2), nn.ReLU(), ) - self.linear = nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), - odim) + self.linear = Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), + odim) self.subsampling_rate = 8 # The right context for every conv layer is computed by: # (kernel_size - 1) * frame_rate_of_this_layer
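
Note (illustrative sketch, not part of the patch): the changes above wire a new model_conf option, init_type: 'kaiming_uniform', from the YAML config into U2Model.__init__, which wraps model construction in DefaultInitializerContext. The context manager sets align.global_init_type, so the aligned layer wrappers added in paddlespeech/s2t/modules/align.py (Linear, Conv1D, Conv2D, ...) pick a KaimingUniform initializer while the model is being built, with the intent of matching torch's default layer initialization; on exit the flag is reset. The snippet below is a minimal, assumed usage example (the config dict and layer sizes are made up), not code from the patch.

    # Illustrative only: how init_type is expected to flow through
    # DefaultInitializerContext and the aligned layer wrappers.
    import paddle

    from paddlespeech.s2t.modules import align
    from paddlespeech.s2t.modules.align import Linear
    from paddlespeech.s2t.modules.initializer import DefaultInitializerContext

    # Mirrors the new model_conf entry in conformer.yaml (values here are made up).
    configs = {"model_conf": {"init_type": "kaiming_uniform"}}
    init_type = configs["model_conf"].get("init_type", None)

    with DefaultInitializerContext(init_type):
        # Inside the context, align.global_init_type == "kaiming_uniform",
        # so wrapper layers built here receive KaimingUniform weight/bias
        # ParamAttrs instead of Paddle's built-in defaults.
        assert align.global_init_type == "kaiming_uniform"
        proj = Linear(256, 256)

    # Outside the context the flag is reset, so later layers keep Paddle defaults.
    assert align.global_init_type is None

    x = paddle.rand([8, 256])
    y = proj(x)
    print(y.shape)  # [8, 256]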