diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 76f698e64657499c6251e7a70fe86fd266e236fd..e19f411cfc0b5a6be9e15b2326631fe289ad44e1 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -29,6 +29,9 @@ import paddle from paddle import jit from paddle import nn +from paddlespeech.audio.utils.tensor_utils import add_sos_eos +from paddlespeech.audio.utils.tensor_utils import pad_sequence +from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer from paddlespeech.s2t.frontend.utility import IGNORE_ID from paddlespeech.s2t.frontend.utility import load_cmvn @@ -48,9 +51,6 @@ from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank from paddlespeech.s2t.utils.log import Log -from paddlespeech.audio.utils.tensor_utils import add_sos_eos -from paddlespeech.audio.utils.tensor_utils import pad_sequence -from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import log_add from paddlespeech.s2t.utils.utility import UpdateConfig @@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer): xs: paddle.Tensor, offset: int, required_cache_size: int, - att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), - cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Export interface for c++ call, give input chunk xs, and return output from time 0 to current chunk. @@ -625,13 +625,11 @@ class U2BaseModel(ASRInterface, nn.Layer): (elayers, head, cache_t1, d_k * 2), where `head * d_k == hidden-dim` and `cache_t1 == chunk_size * num_decoding_left_chunks`. - `d_k * 2` for att key & value. Default is 0-dims Tensor, - it is used for dy2st. + `d_k * 2` for att key & value. cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer, (elayers, b=1, hidden-dim, cache_t2), where - `cache_t2 == cnn.lorder - 1`. Default is 0-dims Tensor, - it is used for dy2st. - + `cache_t2 == cnn.lorder - 1`. + Returns: paddle.Tensor: output of current input xs, with shape (b=1, chunk_size, hidden-dim). @@ -641,8 +639,8 @@ class U2BaseModel(ASRInterface, nn.Layer): paddle.Tensor: new conformer cnn cache required for next chunk, with same shape as the original cnn_cache. """ - return self.encoder.forward_chunk( - xs, offset, required_cache_size, att_cache, cnn_cache) + return self.encoder.forward_chunk(xs, offset, required_cache_size, + att_cache, cnn_cache) # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index 9e46cc54045d0edbe87847f73dbffc21ae219408..5f810dfde98d240245dfd2fa8e89cb17f4f84094 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -76,9 +76,9 @@ class TransformerEncoderLayer(nn.Layer): x: paddle.Tensor, mask: paddle.Tensor, pos_emb: paddle.Tensor, - mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool), - att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]), - cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]), + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Compute encoded features. Args: @@ -105,9 +105,7 @@ class TransformerEncoderLayer(nn.Layer): if self.normalize_before: x = self.norm1(x) - x_att, new_att_cache = self.self_attn( - x, x, x, mask, cache=att_cache - ) + x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache) if self.concat_after: x_concat = paddle.concat((x, x_att), axis=-1) @@ -124,7 +122,7 @@ class TransformerEncoderLayer(nn.Layer): if not self.normalize_before: x = self.norm2(x) - fake_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype) + fake_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype) return x, mask, new_att_cache, fake_cnn_cache @@ -195,9 +193,9 @@ class ConformerEncoderLayer(nn.Layer): x: paddle.Tensor, mask: paddle.Tensor, pos_emb: paddle.Tensor, - mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool), - att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]), - cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]), + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Compute encoded features. Args: @@ -211,7 +209,8 @@ class ConformerEncoderLayer(nn.Layer): att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. cnn_cache (paddle.Tensor): Convolution cache in conformer layer - (#batch=1, size, cache_t2) + (1, #batch=1, size, cache_t2). First dim will not be used, just + for dy2st. Returns: paddle.Tensor: Output tensor (#batch, time, size). paddle.Tensor: Mask tensor (#batch, time, time). @@ -219,6 +218,8 @@ class ConformerEncoderLayer(nn.Layer): (#batch=1, head, cache_t1 + time, d_k * 2). paddle.Tensor: cnn_cahce tensor (#batch, size, cache_t2). """ + # (1, #batch=1, size, cache_t2) -> (#batch=1, size, cache_t2) + cnn_cache = paddle.squeeze(cnn_cache, axis=0) # whether to use macaron style FFN if self.feed_forward_macaron is not None: @@ -249,8 +250,7 @@ class ConformerEncoderLayer(nn.Layer): # convolution module # Fake new cnn cache here, and then change it in conv_module - new_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype) - cnn_cache = paddle.squeeze(cnn_cache, axis=0) + new_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype) if self.conv_module is not None: residual = x if self.normalize_before: @@ -275,4 +275,4 @@ class ConformerEncoderLayer(nn.Layer): if self.conv_module is not None: x = self.norm_final(x) - return x, mask, new_att_cache, new_cnn_cache \ No newline at end of file + return x, mask, new_att_cache, new_cnn_cache