Unverified commit 31a4562a, authored by 夜雨飘零, committed by GitHub

[ASR]add squeezeformer model (#2755)

* add squeezeformer model

* change CodeStyle, test=asr

* change CodeStyle, test=asr

* fix subsample rate error, test=asr

* merge classes as required, test=asr

* change CodeStyle, test=asr

* fix missing code, test=asr

* split code to new file, test=asr

* remove rel_shift, test=asr
Parent 9bf54716
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: squeezeformer
encoder_conf:
    encoder_dim: 256    # dimension of attention
    output_size: 256    # dimension of output
    attention_heads: 4
    num_blocks: 12      # the number of encoder blocks
    reduce_idx: 5
    recover_idx: 11
    feed_forward_expansion_factor: 8
    input_dropout_rate: 0.1
    feed_forward_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    adaptive_scale: true
    cnn_module_kernel: 31
    normalize_before: false
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    time_reduction_layer_type: 'stream'
    causal: true
    use_dynamic_chunk: true
    use_dynamic_left_chunk: false
# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1    # sublayer output dropout
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    init_type: 'kaiming_uniform'    # !Warning: needed for convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
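
The encoder_conf block above maps one-to-one onto the SqueezeformerEncoder constructor added later in this diff. A minimal sketch of building and running the streaming encoder standalone, assuming an 80-dim fbank front end (batch shapes and lengths are illustrative, not from the recipe):

import paddle
from paddlespeech.s2t.modules.encoder import SqueezeformerEncoder

encoder = SqueezeformerEncoder(
    80,                                   # input_size = feat_dim
    encoder_dim=256, output_size=256,
    attention_heads=4, num_blocks=12,
    reduce_idx=5, recover_idx=11,
    feed_forward_expansion_factor=8,
    input_dropout_rate=0.1,
    feed_forward_dropout_rate=0.1,
    attention_dropout_rate=0.1,
    adaptive_scale=True,
    cnn_module_kernel=31,
    normalize_before=False,
    activation_type='swish',
    pos_enc_layer_type='rel_pos',
    time_reduction_layer_type='stream',
    causal=True,
    use_dynamic_chunk=True,
    use_dynamic_left_chunk=False)

xs = paddle.randn([2, 200, 80])                        # (B, T, feat_dim)
xs_lens = paddle.to_tensor([200, 180], dtype='int64')
out, masks = encoder(xs, xs_lens)                      # out: (B, ~T/4, 256)
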
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: squeezeformer
encoder_conf:
    encoder_dim: 256    # dimension of attention
    output_size: 256    # dimension of output
    attention_heads: 4
    num_blocks: 12      # the number of encoder blocks
    reduce_idx: 5
    recover_idx: 11
    feed_forward_expansion_factor: 8
    input_dropout_rate: 0.1
    feed_forward_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    adaptive_scale: true
    cnn_module_kernel: 31
    normalize_before: false
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    time_reduction_layer_type: 'conv1d'
# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    init_type: 'kaiming_uniform'    # !Warning: needed for convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 150
accum_grad: 8
global_grad_clip: 5.0
dist_sampler: False
optim: adam
optim_conf:
    lr: 0.002
    weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000
    lr_decay: 1.0
log_interval: 100
checkpoint:
    kbest_n: 50
    latest_n: 5
......@@ -43,6 +43,7 @@ from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import BiTransformerDecoder
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import SqueezeformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
from paddlespeech.s2t.modules.loss import LabelSmoothingLoss
......@@ -905,6 +906,9 @@ class U2Model(U2DecodeModel):
elif encoder_type == 'conformer':
encoder = ConformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
elif encoder_type == 'squeezeformer':
encoder = SqueezeformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
else:
raise ValueError(f"not support encoder type:{encoder_type}")
......
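
The branch above is all it takes to route a recipe onto the new encoder. A hedged sketch of building just the encoder from one of the YAML files above, assuming PyYAML and an illustrative config path:

import yaml
from paddlespeech.s2t.modules.encoder import SqueezeformerEncoder

with open('conf/chunk_squeezeformer.yaml') as f:   # hypothetical path for the streaming config above
    configs = yaml.safe_load(f)

assert configs['encoder'] == 'squeezeformer'
# Mirrors the U2Model branch: input_dim is the feature dimension, global_cmvn is optional.
encoder = SqueezeformerEncoder(
    configs['feat_dim'], global_cmvn=None, **configs['encoder_conf'])
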
......@@ -200,7 +200,12 @@ class MultiHeadedAttention(nn.Layer):
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding."""
def __init__(self, n_head, n_feat, dropout_rate):
def __init__(self,
n_head,
n_feat,
dropout_rate,
adaptive_scale=False,
init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object.
Paper: https://arxiv.org/abs/1901.02860
Args:
......@@ -223,6 +228,39 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
pos_bias_v = self.create_parameter(
(self.h, self.d_k), default_initializer=I.XavierUniform())
self.add_parameter('pos_bias_v', pos_bias_v)
self.adaptive_scale = adaptive_scale
if self.adaptive_scale:
ada_scale = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(1.0))
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, n_feat], default_initializer=I.Constant(0.0))
self.add_parameter('ada_bias', ada_bias)
if init_weights:
self.init_weights()
def init_weights(self):
input_max = (self.h * self.d_k)**-0.5
self.linear_q._param_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
self.linear_q._bias_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
self.linear_k._param_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
self.linear_k._bias_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
self.linear_v._param_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
self.linear_v._bias_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
self.linear_pos._param_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
self.linear_pos._bias_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
self.linear_out._param_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
self.linear_out._bias_attr = paddle.nn.initializer.Uniform(
low=-input_max, high=input_max)
def rel_shift(self, x, zero_triu: bool=False):
"""Compute relative positinal encoding.
......@@ -273,6 +311,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
if self.adaptive_scale:
query = self.ada_scale * query + self.ada_bias
key = self.ada_scale * key + self.ada_bias
value = self.ada_scale * value + self.ada_bias
q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
......
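
The adaptive_scale branch added above is a learnable per-feature affine applied to the query, key, and value inputs before the usual attention path. A standalone sketch of the same idea in Paddle (the class name is illustrative, not part of this PR):

import paddle
from paddle import nn
from paddle.nn import initializer as I


class AdaptiveScale(nn.Layer):
    """Learnable per-channel scale and bias, applied before a Squeezeformer sub-module."""

    def __init__(self, n_feat: int):
        super().__init__()
        ada_scale = self.create_parameter(
            [1, 1, n_feat], default_initializer=I.Constant(1.0))
        self.add_parameter('ada_scale', ada_scale)
        ada_bias = self.create_parameter(
            [1, 1, n_feat], default_initializer=I.Constant(0.0))
        self.add_parameter('ada_bias', ada_bias)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # x: (batch, time, n_feat)
        return self.ada_scale * x + self.ada_bias
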
......@@ -18,6 +18,7 @@ from typing import Tuple
import paddle
from paddle import nn
from paddle.nn import initializer as I
from typeguard import check_argument_types
from paddlespeech.s2t.modules.align import BatchNorm1D
......@@ -39,7 +40,9 @@ class ConvolutionModule(nn.Layer):
activation: nn.Layer=nn.ReLU(),
norm: str="batch_norm",
causal: bool=False,
bias: bool=True):
bias: bool=True,
adaptive_scale: bool=False,
init_weights: bool=False):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
......@@ -51,6 +54,18 @@ class ConvolutionModule(nn.Layer):
"""
assert check_argument_types()
super().__init__()
self.bias = bias
self.channels = channels
self.kernel_size = kernel_size
self.adaptive_scale = adaptive_scale
if self.adaptive_scale:
ada_scale = self.create_parameter(
[1, 1, channels], default_initializer=I.Constant(1.0))
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, channels], default_initializer=I.Constant(0.0))
self.add_parameter('ada_bias', ada_bias)
self.pointwise_conv1 = Conv1D(
channels,
2 * channels,
......@@ -105,6 +120,28 @@ class ConvolutionModule(nn.Layer):
)
self.activation = activation
if init_weights:
self.init_weights()
def init_weights(self):
pw_max = self.channels**-0.5
dw_max = self.kernel_size**-0.5
self.pointwise_conv1._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
if self.bias:
self.pointwise_conv1._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
self.depthwise_conv._param_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
if self.bias:
self.depthwise_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.pointwise_conv2._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
if self.bias:
self.pointwise_conv2._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
def forward(
self,
x: paddle.Tensor,
......@@ -123,6 +160,9 @@ class ConvolutionModule(nn.Layer):
paddle.Tensor: Output tensor (#batch, time, channels).
paddle.Tensor: Output cache tensor (#batch, channels, time')
"""
if self.adaptive_scale:
x = self.ada_scale * x + self.ada_bias
# exchange the temporal dimension and the feature dimension
x = x.transpose([0, 2, 1]) # [B, C, T]
......
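
The init bounds above follow the usual uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) rule: for the two pointwise convs the fan-in is the channel count, for the depthwise conv it is just the kernel size. With the values used in the configs above (illustrative arithmetic only):

channels, kernel_size = 256, 31
pw_max = channels ** -0.5      # ~0.0625, bound for pointwise_conv1 / pointwise_conv2
dw_max = kernel_size ** -0.5   # ~0.1796, bound for depthwise_conv
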
from typing import Optional
from typing import Union
import paddle
import paddle.nn.functional as F
from paddle.nn.layer.conv import _ConvNd
__all__ = ['Conv2DValid']
class Conv2DValid(_ConvNd):
"""
Conv2d operator for VALID mode padding.
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int=1,
padding: Union[str, int]=0,
dilation: int=1,
groups: int=1,
padding_mode: str='zeros',
weight_attr=None,
bias_attr=None,
data_format="NCHW",
valid_trigx: bool=False,
valid_trigy: bool=False) -> None:
super(Conv2DValid, self).__init__(
in_channels,
out_channels,
kernel_size,
False,
2,
stride=stride,
padding=padding,
padding_mode=padding_mode,
dilation=dilation,
groups=groups,
weight_attr=weight_attr,
bias_attr=bias_attr,
data_format=data_format)
self.valid_trigx = valid_trigx
self.valid_trigy = valid_trigy
def _conv_forward(self,
input: paddle.Tensor,
weight: paddle.Tensor,
bias: Optional[paddle.Tensor]):
validx, validy = 0, 0
if self.valid_trigx:
validx = (input.shape[-2] *
(self._stride[-2] - 1) - 1 + self._kernel_size[-2]) // 2
if self.valid_trigy:
validy = (input.shape[-1] *
(self._stride[-1] - 1) - 1 + self._kernel_size[-1]) // 2
return F.conv2d(input, weight, bias, self._stride, (validx, validy),
self._dilation, self._groups)
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
return self._conv_forward(input, self.weight, self.bias)
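
Conv2DValid is consumed by TimeReductionLayer2D further down, which feeds it tensors in (B, C, T, 1) layout after padding the time axis. A hedged standalone usage sketch (shapes are illustrative):

import paddle
from paddlespeech.s2t.modules.conv2d import Conv2DValid

conv = Conv2DValid(
    in_channels=256, out_channels=256,
    kernel_size=(5, 1), stride=2, valid_trigy=True)
x = paddle.randn([1, 256, 32, 1])   # (B, C, T, 1)
y = conv(x)                         # (1, 256, 14, 1): (T - 5) // 2 + 1 frames; the caller pads T first to keep ~T/2
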
......@@ -14,7 +14,10 @@
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
......@@ -22,6 +25,7 @@ from typeguard import check_argument_types
from paddlespeech.s2t.modules.activation import get_activation
from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.attention import MultiHeadedAttention
from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention
from paddlespeech.s2t.modules.conformer_convolution import ConvolutionModule
......@@ -29,6 +33,7 @@ from paddlespeech.s2t.modules.embedding import NoPositionalEncoding
from paddlespeech.s2t.modules.embedding import PositionalEncoding
from paddlespeech.s2t.modules.embedding import RelPositionalEncoding
from paddlespeech.s2t.modules.encoder_layer import ConformerEncoderLayer
from paddlespeech.s2t.modules.encoder_layer import SqueezeformerEncoderLayer
from paddlespeech.s2t.modules.encoder_layer import TransformerEncoderLayer
from paddlespeech.s2t.modules.mask import add_optional_chunk_mask
from paddlespeech.s2t.modules.mask import make_non_pad_mask
......@@ -36,12 +41,19 @@ from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedF
from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling4
from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling6
from paddlespeech.s2t.modules.subsampling import Conv2dSubsampling8
from paddlespeech.s2t.modules.subsampling import DepthwiseConv2DSubsampling4
from paddlespeech.s2t.modules.subsampling import LinearNoSubsampling
from paddlespeech.s2t.modules.time_reduction import TimeReductionLayer1D
from paddlespeech.s2t.modules.time_reduction import TimeReductionLayer2D
from paddlespeech.s2t.modules.time_reduction import TimeReductionLayerStream
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder"]
__all__ = [
"BaseEncoder", 'TransformerEncoder', "ConformerEncoder",
"SqueezeformerEncoder"
]
class BaseEncoder(nn.Layer):
......@@ -487,3 +499,366 @@ class ConformerEncoder(BaseEncoder):
normalize_before=normalize_before,
concat_after=concat_after) for _ in range(num_blocks)
])
class SqueezeformerEncoder(nn.Layer):
def __init__(self,
input_size: int,
encoder_dim: int=256,
output_size: int=256,
attention_heads: int=4,
num_blocks: int=12,
reduce_idx: Optional[Union[int, List[int]]]=5,
recover_idx: Optional[Union[int, List[int]]]=11,
feed_forward_expansion_factor: int=4,
dw_stride: bool=False,
input_dropout_rate: float=0.1,
pos_enc_layer_type: str="rel_pos",
time_reduction_layer_type: str="conv1d",
feed_forward_dropout_rate: float=0.1,
attention_dropout_rate: float=0.1,
cnn_module_kernel: int=31,
cnn_norm_type: str="layer_norm",
dropout: float=0.1,
causal: bool=False,
adaptive_scale: bool=True,
activation_type: str="swish",
init_weights: bool=True,
global_cmvn: paddle.nn.Layer=None,
normalize_before: bool=False,
use_dynamic_chunk: bool=False,
concat_after: bool=False,
static_chunk_size: int=0,
use_dynamic_left_chunk: bool=False):
"""Construct SqueezeformerEncoder
Args:
input_size to use_dynamic_chunk, see in Transformer BaseEncoder.
encoder_dim (int): The hidden dimension of encoder layer.
output_size (int): The output dimension of final projection layer.
attention_heads (int): Num of attention head in attention module.
num_blocks (int): Num of encoder layers.
reduce_idx Optional[Union[int, List[int]]]:
reduce layer index, from 40ms to 80ms per frame.
recover_idx Optional[Union[int, List[int]]]:
recover layer index, from 80ms to 40ms per frame.
feed_forward_expansion_factor (int): Enlarge coefficient of FFN.
dw_stride (bool): Whether to do depthwise convolution
in the subsampling module.
input_dropout_rate (float): Dropout rate of input projection layer.
pos_enc_layer_type (str): Self attention type.
time_reduction_layer_type (str): Conv1d or Conv2d reduction layer.
cnn_module_kernel (int): Kernel size of the convolution module.
activation_type (str): Encoder activation function type.
adaptive_scale (bool): Whether to use adaptive scale.
init_weights (bool): Whether to initialize weights.
causal (bool): whether to use causal convolution or not.
"""
assert check_argument_types()
super().__init__()
self.global_cmvn = global_cmvn
self.reduce_idx: Optional[Union[int, List[int]]] = [reduce_idx] \
if type(reduce_idx) == int else reduce_idx
self.recover_idx: Optional[Union[int, List[int]]] = [recover_idx] \
if type(recover_idx) == int else recover_idx
self.check_ascending_list()
if reduce_idx is None:
self.time_reduce = None
else:
if recover_idx is None:
self.time_reduce = 'normal' # no recovery at the end
else:
self.time_reduce = 'recover' # recovery at the end
assert len(self.reduce_idx) == len(self.recover_idx)
self.reduce_stride = 2
self._output_size = output_size
self.normalize_before = normalize_before
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
activation = get_activation(activation_type)
# self-attention module definition
if pos_enc_layer_type != "rel_pos":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, output_size,
attention_dropout_rate)
else:
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate,
adaptive_scale, init_weights)
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
encoder_dim, encoder_dim * feed_forward_expansion_factor,
feed_forward_dropout_rate, activation, adaptive_scale, init_weights)
# convolution module definition
convolution_layer = ConvolutionModule
convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
cnn_norm_type, causal, True, adaptive_scale,
init_weights)
self.embed = DepthwiseConv2DSubsampling4(
1, encoder_dim,
RelPositionalEncoding(encoder_dim, dropout_rate=0.1), dw_stride,
input_size, input_dropout_rate, init_weights)
self.preln = LayerNorm(encoder_dim)
self.encoders = paddle.nn.LayerList([
SqueezeformerEncoderLayer(
encoder_dim,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
convolution_layer(*convolution_layer_args),
positionwise_layer(*positionwise_layer_args), normalize_before,
dropout, concat_after) for _ in range(num_blocks)
])
if time_reduction_layer_type == 'conv1d':
time_reduction_layer = TimeReductionLayer1D
time_reduction_layer_args = {
'channel': encoder_dim,
'out_dim': encoder_dim,
}
elif time_reduction_layer_type == 'stream':
time_reduction_layer = TimeReductionLayerStream
time_reduction_layer_args = {
'channel': encoder_dim,
'out_dim': encoder_dim,
}
else:
time_reduction_layer = TimeReductionLayer2D
time_reduction_layer_args = {'encoder_dim': encoder_dim}
self.time_reduction_layer = time_reduction_layer(
**time_reduction_layer_args)
self.time_recover_layer = Linear(encoder_dim, encoder_dim)
self.final_proj = None
if output_size != encoder_dim:
self.final_proj = Linear(encoder_dim, output_size)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs: paddle.Tensor,
xs_lens: paddle.Tensor,
decoding_chunk_size: int=0,
num_decoding_left_chunks: int=-1,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, L, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor, lens and mask
"""
masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks
chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks)
xs_lens = chunk_masks.squeeze(1).sum(1)
xs = self.preln(xs)
recover_activations: \
List[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]] = []
index = 0
for i, layer in enumerate(self.encoders):
if self.reduce_idx is not None:
if self.time_reduce is not None and i in self.reduce_idx:
recover_activations.append(
(xs, chunk_masks, pos_emb, mask_pad))
xs, xs_lens, chunk_masks, mask_pad = self.time_reduction_layer(
xs, xs_lens, chunk_masks, mask_pad)
pos_emb = pos_emb[:, ::2, :]
index += 1
if self.recover_idx is not None:
if self.time_reduce == 'recover' and i in self.recover_idx:
index -= 1
recover_tensor, recover_chunk_masks, recover_pos_emb, recover_mask_pad = recover_activations[
index]
# recover output length for ctc decode
xs = paddle.repeat_interleave(xs, repeats=2, axis=1)
xs = self.time_recover_layer(xs)
recoverd_t = recover_tensor.shape[1]
xs = recover_tensor + xs[:, :recoverd_t, :]
chunk_masks = recover_chunk_masks
pos_emb = recover_pos_emb
mask_pad = recover_mask_pad
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.final_proj is not None:
xs = self.final_proj(xs)
return xs, masks
def check_ascending_list(self):
if self.reduce_idx is not None:
assert self.reduce_idx == sorted(self.reduce_idx), \
"reduce_idx should be int or ascending list"
if self.recover_idx is not None:
assert self.recover_idx == sorted(self.recover_idx), \
"recover_idx should be int or ascending list"
def calculate_downsampling_factor(self, i: int) -> int:
if self.reduce_idx is None:
return 1
else:
reduce_exp, recover_exp = 0, 0
for exp, rd_idx in enumerate(self.reduce_idx):
if i >= rd_idx:
reduce_exp = exp + 1
if self.recover_idx is not None:
for exp, rc_idx in enumerate(self.recover_idx):
if i >= rc_idx:
recover_exp = exp + 1
return int(2**(reduce_exp - recover_exp))
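# With the default reduce_idx=5 / recover_idx=11 this evaluates to 1 for layers 0-4,
# 2 for layers 5-10, and 1 again from layer 11 on. A compact equivalent for ascending
# index lists (hypothetical helper, shown for illustration only):
def downsampling_factor(i, reduce_idx=(5, ), recover_idx=(11, )):
    reduce_exp = sum(1 for r in reduce_idx if i >= r)
    recover_exp = sum(1 for r in recover_idx if i >= r)
    return 2**(reduce_exp - recover_exp)

assert [downsampling_factor(i) for i in range(12)] == [1] * 5 + [2] * 6 + [1]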
def forward_chunk(
self,
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Forward just one chunk
Args:
xs (paddle.Tensor): chunk input, with shape (b=1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
computation
>=0: actual cache size
<0: means all history cache is required
att_cache (paddle.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
Returns:
paddle.Tensor: output of current input xs,
with shape (b=1, chunk_size, hidden-dim).
paddle.Tensor: new attention cache required for next chunk, with
dynamic shape (elayers, head, ?, d_k * 2)
depending on required_cache_size.
paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
assert xs.shape[0] == 1 # batch size must be one
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
# tmp_masks is just for interface compatibility, [B=1, C=1, T]
tmp_masks = paddle.ones([1, 1, xs.shape[1]], dtype=paddle.bool)
# before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
elayers, cache_t1 = att_cache.shape[0], att_cache.shape[2]
chunk_size = xs.shape[1]
attention_key_size = cache_t1 + chunk_size
pos_emb = self.embed.position_encoding(
offset=offset - cache_t1, size=attention_key_size)
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = attention_key_size
else:
next_cache_start = max(attention_key_size - required_cache_size, 0)
r_att_cache = []
r_cnn_cache = []
mask_pad = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
mask_pad = mask_pad.unsqueeze(1)
max_att_len: int = 0
recover_activations: \
List[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]] = []
index = 0
xs_lens = paddle.to_tensor([xs.shape[1]], dtype=paddle.int32)
xs = self.preln(xs)
for i, layer in enumerate(self.encoders):
# NOTE(xcsong): Before layer.forward
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
if self.reduce_idx is not None:
if self.time_reduce is not None and i in self.reduce_idx:
recover_activations.append(
(xs, att_mask, pos_emb, mask_pad))
xs, xs_lens, att_mask, mask_pad = self.time_reduction_layer(
xs, xs_lens, att_mask, mask_pad)
pos_emb = pos_emb[:, ::2, :]
index += 1
if self.recover_idx is not None:
if self.time_reduce == 'recover' and i in self.recover_idx:
index -= 1
recover_tensor, recover_att_mask, recover_pos_emb, recover_mask_pad = recover_activations[
index]
# recover output length for ctc decode
xs = paddle.repeat_interleave(xs, repeats=2, axis=1)
xs = self.time_recover_layer(xs)
recoverd_t = recover_tensor.shape[1]
xs = recover_tensor + xs[:, :recoverd_t, :]
att_mask = recover_att_mask
pos_emb = recover_pos_emb
mask_pad = recover_mask_pad
factor = self.calculate_downsampling_factor(i)
att_cache1 = att_cache[
i:i + 1][:, :, ::factor, :][:, :, :pos_emb.shape[1] - xs.shape[
1], :]
cnn_cache1 = cnn_cache[i] if cnn_cache.shape[0] > 0 else cnn_cache
xs, _, new_att_cache, new_cnn_cache = layer(
xs,
att_mask,
pos_emb,
att_cache=att_cache1,
cnn_cache=cnn_cache1)
# NOTE(xcsong): After layer.forward
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
cached_att = new_att_cache[:, :, next_cache_start // factor:, :]
cached_cnn = new_cnn_cache.unsqueeze(0)
cached_att = cached_att.repeat_interleave(repeats=factor, axis=2)
if i == 0:
# record length for the first block as max length
max_att_len = cached_att.shape[2]
r_att_cache.append(cached_att[:, :, :max_att_len, :])
r_cnn_cache.append(cached_cnn)
# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
# ? may be larger than cache_t1, it depends on required_cache_size
r_att_cache = paddle.concat(r_att_cache, axis=0)
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
if self.final_proj is not None:
xs = self.final_proj(xs)
return xs, r_att_cache, r_cnn_cache
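
A hedged sketch of driving forward_chunk during streaming recognition, carrying the attention and CNN caches between calls; `encoder` is a built SqueezeformerEncoder and `feature_chunks` an iterable of fbank chunks, both assumed, and the cache size is illustrative:

import paddle

att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache = paddle.zeros([0, 0, 0, 0])
offset, outputs = 0, []
for chunk in feature_chunks:            # each chunk: (1, chunk_time, feat_dim)
    ys, att_cache, cnn_cache = encoder.forward_chunk(
        chunk, offset, required_cache_size=16,
        att_cache=att_cache, cnn_cache=cnn_cache)
    offset += ys.shape[1]
    outputs.append(ys)
ys_full = paddle.concat(outputs, axis=1)
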
......@@ -26,7 +26,10 @@ from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["TransformerEncoderLayer", "ConformerEncoderLayer"]
__all__ = [
"TransformerEncoderLayer", "ConformerEncoderLayer",
"SqueezeformerEncoderLayer"
]
class TransformerEncoderLayer(nn.Layer):
......@@ -276,3 +279,125 @@ class ConformerEncoderLayer(nn.Layer):
x = self.norm_final(x)
return x, mask, new_att_cache, new_cnn_cache
class SqueezeformerEncoderLayer(nn.Layer):
"""Encoder layer module."""
def __init__(self,
size: int,
self_attn: paddle.nn.Layer,
feed_forward1: Optional[nn.Layer]=None,
conv_module: Optional[nn.Layer]=None,
feed_forward2: Optional[nn.Layer]=None,
normalize_before: bool=False,
dropout_rate: float=0.1,
concat_after: bool=False):
"""Construct an EncoderLayer object.
Args:
size (int): Input dimension.
self_attn (paddle.nn.Layer): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward1 (paddle.nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (paddle.nn.Layer): Convolution module instance.
`ConvolutionModule` instance can be used as the argument.
feed_forward2 (paddle.nn.Layer): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
"""
super().__init__()
self.size = size
self.self_attn = self_attn
self.layer_norm1 = LayerNorm(size)
self.ffn1 = feed_forward1
self.layer_norm2 = LayerNorm(size)
self.conv_module = conv_module
self.layer_norm3 = LayerNorm(size)
self.ffn2 = feed_forward2
self.layer_norm4 = LayerNorm(size)
self.normalize_before = normalize_before
self.dropout = nn.Dropout(dropout_rate)
self.concat_after = concat_after
if concat_after:
self.concat_linear = Linear(size + size, size)
else:
self.concat_linear = nn.Identity()
def forward(
self,
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
x (paddle.Tensor): Input tensor (#batch, time, size).
mask (paddle.Tensor): Mask tensor for the input (#batch, time, time).
(0,0,0) means fake mask.
pos_emb (paddle.Tensor): positional encoding, must not be None
for SqueezeformerEncoderLayer
mask_pad (paddle.Tensor): batch padding mask used for conv module.
(#batch, 1,time), (0, 0, 0) means fake mask.
att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (paddle.Tensor): Convolution cache in conformer layer
(1, #batch=1, size, cache_t2). First dim will not be used, just
for dy2st.
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time, time).
paddle.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
paddle.Tensor: cnn_cache tensor (#batch, size, cache_t2).
"""
# self attention module
residual = x
if self.normalize_before:
x = self.layer_norm1(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.layer_norm1(x)
# ffn module
residual = x
if self.normalize_before:
x = self.layer_norm2(x)
x = self.ffn1(x)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.layer_norm2(x)
# conv module
residual = x
if self.normalize_before:
x = self.layer_norm3(x)
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.layer_norm3(x)
# ffn module
residual = x
if self.normalize_before:
x = self.layer_norm4(x)
x = self.ffn2(x)
# we do not use dropout here since it is inside feed forward function
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.layer_norm4(x)
return x, mask, new_att_cache, new_cnn_cache
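
With normalize_before=False (the Squeezeformer default in the configs above), every sub-module is post-normalized, so the forward pass above reduces to the following dataflow (comment-only summary; attention and conv caches omitted):

# x = layer_norm1(x + dropout(self_attn(x, x, x, mask, pos_emb)))
# x = layer_norm2(x + dropout(ffn1(x)))
# x = layer_norm3(x + dropout(conv_module(x, mask_pad)))
# x = layer_norm4(x + dropout(ffn2(x)))
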
......@@ -16,6 +16,7 @@
"""Positionwise feed forward layer definition."""
import paddle
from paddle import nn
from paddle.nn import initializer as I
from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.utils.log import Log
......@@ -32,7 +33,9 @@ class PositionwiseFeedForward(nn.Layer):
idim: int,
hidden_units: int,
dropout_rate: float,
activation: nn.Layer=nn.ReLU()):
activation: nn.Layer=nn.ReLU(),
adaptive_scale: bool=False,
init_weights: bool=False):
"""Construct a PositionwiseFeedForward object.
FeedForward is applied at each position of the sequence.
......@@ -45,10 +48,35 @@ class PositionwiseFeedForward(nn.Layer):
activation (paddle.nn.Layer): Activation function
"""
super().__init__()
self.idim = idim
self.hidden_units = hidden_units
self.w_1 = Linear(idim, hidden_units)
self.activation = activation
self.dropout = nn.Dropout(dropout_rate)
self.w_2 = Linear(hidden_units, idim)
self.adaptive_scale = adaptive_scale
if self.adaptive_scale:
ada_scale = self.create_parameter(
[1, 1, idim], default_initializer=I.XavierUniform())
self.add_parameter('ada_scale', ada_scale)
ada_bias = self.create_parameter(
[1, 1, idim], default_initializer=I.XavierUniform())
self.add_parameter('ada_bias', ada_bias)
if init_weights:
self.init_weights()
def init_weights(self):
ffn1_max = self.idim**-0.5
ffn2_max = self.hidden_units**-0.5
self.w_1._param_attr = paddle.nn.initializer.Uniform(
low=-ffn1_max, high=ffn1_max)
self.w_1._bias_attr = paddle.nn.initializer.Uniform(
low=-ffn1_max, high=ffn1_max)
self.w_2._param_attr = paddle.nn.initializer.Uniform(
low=-ffn2_max, high=ffn2_max)
self.w_2._bias_attr = paddle.nn.initializer.Uniform(
low=-ffn2_max, high=ffn2_max)
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
"""Forward function.
......@@ -57,4 +85,6 @@ class PositionwiseFeedForward(nn.Layer):
Returns:
output tensor, (B, Lmax, D)
"""
if self.adaptive_scale:
xs = self.ada_scale * xs + self.ada_bias
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
......@@ -29,7 +29,7 @@ logger = Log(__name__).getlog()
__all__ = [
"LinearNoSubsampling", "Conv2dSubsampling4", "Conv2dSubsampling6",
"Conv2dSubsampling8"
"Conv2dSubsampling8", "DepthwiseConv2DSubsampling4"
]
......@@ -249,3 +249,67 @@ class Conv2dSubsampling8(Conv2dSubsampling):
x = self.linear(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f]))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
class DepthwiseConv2DSubsampling4(BaseSubsampling):
"""Depthwise Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
pos_enc_class (nn.Layer): position encoding class.
dw_stride (bool): Whether to do depthwise convolution.
input_size (int): filter bank dimension.
"""
def __init__(self,
idim: int,
odim: int,
pos_enc_class: nn.Layer,
dw_stride: bool=False,
input_size: int=80,
input_dropout_rate: float=0.1,
init_weights: bool=True):
super(DepthwiseConv2DSubsampling4, self).__init__()
self.idim = idim
self.odim = odim
self.pw_conv = Conv2D(
in_channels=idim, out_channels=odim, kernel_size=3, stride=2)
self.act1 = nn.ReLU()
self.dw_conv = Conv2D(
in_channels=odim,
out_channels=odim,
kernel_size=3,
stride=2,
groups=odim if dw_stride else 1)
self.act2 = nn.ReLU()
self.pos_enc = pos_enc_class
self.input_proj = nn.Sequential(
Linear(odim * (((input_size - 1) // 2 - 1) // 2), odim),
nn.Dropout(p=input_dropout_rate))
if init_weights:
linear_max = (odim * input_size / 4)**-0.5
self.input_proj.state_dict()[
'0.weight'] = paddle.nn.initializer.Uniform(
low=-linear_max, high=linear_max)
self.input_proj.state_dict()[
'0.bias'] = paddle.nn.initializer.Uniform(
low=-linear_max, high=linear_max)
self.subsampling_rate = 4
# 6 = (3 - 1) * 1 + (3 - 1) * 2
self.right_context = 6
def forward(self, x: paddle.Tensor, x_mask: paddle.Tensor, offset: int=0
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.pw_conv(x)
x = self.act1(x)
x = self.dw_conv(x)
x = self.act2(x)
b, c, t, f = x.shape
x = x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f])
x, pos_emb = self.pos_enc(x, offset)
x = self.input_proj(x)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
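
The Linear inside input_proj expects the flattened channels-times-frequency size left after two stride-2, kernel-3 convolutions with no padding. A quick check with the defaults used in this PR (values are illustrative):

T_in, F_in, odim = 200, 80, 256
T_out = ((T_in - 1) // 2 - 1) // 2            # 49 output frames, i.e. a 1/4 frame rate
flat = odim * (((F_in - 1) // 2 - 1) // 2)    # 256 * 19 = 4864, the Linear input size above
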
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Subsampling layer definition."""
from typing import Tuple
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.s2t import masked_fill
from paddlespeech.s2t.modules.align import Conv1D
from paddlespeech.s2t.modules.conv2d import Conv2DValid
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = [
"TimeReductionLayerStream", "TimeReductionLayer1D", "TimeReductionLayer2D"
]
class TimeReductionLayer1D(nn.Layer):
"""
Modified from NeMo.
Squeezeformer Time Reduction procedure.
Downsamples the audio by `stride` in the time dimension.
Args:
channel (int): input dimension of
MultiheadAttentionMechanism and PositionwiseFeedForward
out_dim (int): Output dimension of the module.
kernel_size (int): Conv kernel size for
depthwise convolution in convolution module
stride (int): Downsampling factor in time dimension.
"""
def __init__(self,
channel: int,
out_dim: int,
kernel_size: int=5,
stride: int=2):
super(TimeReductionLayer1D, self).__init__()
self.channel = channel
self.out_dim = out_dim
self.kernel_size = kernel_size
self.stride = stride
self.padding = max(0, self.kernel_size - self.stride)
self.dw_conv = Conv1D(
in_channels=channel,
out_channels=channel,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
groups=channel, )
self.pw_conv = Conv1D(
in_channels=channel,
out_channels=out_dim,
kernel_size=1,
stride=1,
padding=0,
groups=1, )
self.init_weights()
def init_weights(self):
dw_max = self.kernel_size**-0.5
pw_max = self.channel**-0.5
self.dw_conv._param_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.pw_conv._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
def forward(
self,
xs,
xs_lens: paddle.Tensor,
mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
mask_pad: paddle.Tensor=paddle.ones((0, 0, 0),
dtype=paddle.bool), ):
xs = xs.transpose([0, 2, 1]) # [B, C, T]
xs = masked_fill(xs, mask_pad.equal(0), 0.0)
xs = self.dw_conv(xs)
xs = self.pw_conv(xs)
xs = xs.transpose([0, 2, 1]) # [B, T, C]
B, T, D = xs.shape
mask = mask[:, ::self.stride, ::self.stride]
mask_pad = mask_pad[:, :, ::self.stride]
L = mask_pad.shape[-1]
# For JIT exporting, we remove F.pad operator.
if L - T < 0:
xs = xs[:, :L - T, :]
else:
dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32)
xs = paddle.concat([xs, dummy_pad], axis=1)
xs_lens = (xs_lens + 1) // 2
return xs, xs_lens, mask, mask_pad
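# With the defaults kernel_size=5, stride=2, the depthwise conv above is padded by
# kernel_size - stride = 3, so its output can be one frame longer than the halved mask;
# the tail trim / dummy pad in forward() reconciles the two lengths. Illustrative check:
T, k, s = 100, 5, 2
pad = max(0, k - s)                     # 3
conv_out = (T + 2 * pad - k) // s + 1   # 51
mask_len = len(range(0, T, s))          # 50: conv output is trimmed to this length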
class TimeReductionLayer2D(nn.Layer):
def __init__(self, kernel_size: int=5, stride: int=2, encoder_dim: int=256):
super(TimeReductionLayer2D, self).__init__()
self.encoder_dim = encoder_dim
self.kernel_size = kernel_size
self.dw_conv = Conv2DValid(
in_channels=encoder_dim,
out_channels=encoder_dim,
kernel_size=(kernel_size, 1),
stride=stride,
valid_trigy=True)
self.pw_conv = Conv2DValid(
in_channels=encoder_dim,
out_channels=encoder_dim,
kernel_size=1,
stride=1,
valid_trigx=False,
valid_trigy=False)
self.kernel_size = kernel_size
self.stride = stride
self.init_weights()
def init_weights(self):
dw_max = self.kernel_size**-0.5
pw_max = self.encoder_dim**-0.5
self.dw_conv._param_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.pw_conv._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
def forward(
self,
xs: paddle.Tensor,
xs_lens: paddle.Tensor,
mask: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
mask_pad: paddle.Tensor=paddle.ones((0, 0, 0), dtype=paddle.bool),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
xs = masked_fill(xs, mask_pad.transpose([0, 2, 1]).equal(0), 0.0)
xs = xs.unsqueeze(1)
padding1 = self.kernel_size - self.stride
xs = F.pad(
xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.)
xs = self.dw_conv(xs.transpose([0, 3, 2, 1]))
xs = self.pw_conv(xs).transpose([0, 3, 2, 1]).squeeze(1)
tmp_length = xs.shape[1]
xs_lens = (xs_lens + 1) // 2
padding2 = max(0, (xs_lens.max() - tmp_length).item())
batch_size, hidden = xs.shape[0], xs.shape[-1]
dummy_pad = paddle.zeros(
[batch_size, padding2, hidden], dtype=paddle.float32)
xs = paddle.concat([xs, dummy_pad], axis=1)
mask = mask[:, ::2, ::2]
mask_pad = mask_pad[:, :, ::2]
return xs, xs_lens, mask, mask_pad
class TimeReductionLayerStream(nn.Layer):
"""
Squeezeformer Time Reduction procedure.
Downsamples the audio by `stride` in the time dimension.
Args:
channel (int): input dimension of
MultiheadAttentionMechanism and PositionwiseFeedForward
out_dim (int): Output dimension of the module.
kernel_size (int): Conv kernel size for
depthwise convolution in convolution module
stride (int): Downsampling factor in time dimension.
"""
def __init__(self,
channel: int,
out_dim: int,
kernel_size: int=1,
stride: int=2):
super(TimeReductionLayerStream, self).__init__()
self.channel = channel
self.out_dim = out_dim
self.kernel_size = kernel_size
self.stride = stride
self.dw_conv = Conv1D(
in_channels=channel,
out_channels=channel,
kernel_size=kernel_size,
stride=stride,
padding=0,
groups=channel)
self.pw_conv = Conv1D(
in_channels=channel,
out_channels=out_dim,
kernel_size=1,
stride=1,
padding=0,
groups=1)
self.init_weights()
def init_weights(self):
dw_max = self.kernel_size**-0.5
pw_max = self.channel**-0.5
self.dw_conv._param_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.dw_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-dw_max, high=dw_max)
self.pw_conv._param_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
self.pw_conv._bias_attr = paddle.nn.initializer.Uniform(
low=-pw_max, high=pw_max)
def forward(
self,
xs,
xs_lens: paddle.Tensor,
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool)):
xs = xs.transpose([0, 2, 1]) # [B, C, T]
xs = masked_fill(xs, mask_pad.equal(0), 0.0)
xs = self.dw_conv(xs)
xs = self.pw_conv(xs)
xs = xs.transpose([0, 2, 1]) # [B, T, C]
B, T, D = xs.shape
mask = mask[:, ::self.stride, ::self.stride]
mask_pad = mask_pad[:, :, ::self.stride]
L = mask_pad.shape[-1]
# For JIT exporting, we remove F.pad operator.
if L - T < 0:
xs = xs[:, :L - T, :]
else:
dummy_pad = paddle.zeros([B, L - T, D], dtype=paddle.float32)
xs = paddle.concat([xs, dummy_pad], axis=1)
xs_lens = (xs_lens + 1) // 2
return xs, xs_lens, mask, mask_pad
......@@ -130,8 +130,11 @@ def get_subsample(config):
Returns:
int: subsample rate.
"""
input_layer = config["encoder_conf"]["input_layer"]
assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
if config['encoder'] == 'squeezeformer':
return 4
else:
input_layer = config["encoder_conf"]["input_layer"]
assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
if input_layer == "conv2d":
return 4
elif input_layer == "conv2d6":
......
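
The new squeezeformer branch above makes the subsample rate independent of input_layer, since DepthwiseConv2DSubsampling4 is fixed at a 1/4 frame rate; a minimal illustration (the config dict is hypothetical):

config = {'encoder': 'squeezeformer', 'encoder_conf': {}}
assert get_subsample(config) == 4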