diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index abdaf5ea7d6507336e8459b288cb3b9ffb6a88fb..cf4e32fa486c8cb6617997cde2f25c67dbdea839 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -255,6 +255,7 @@ class BaseEncoder(nn.Layer): xs, att_mask, pos_emb, + mask_pad=paddle.ones([0, 0, 0], dtype=paddle.bool), att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, cnn_cache=cnn_cache[i:i + 1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, ) diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index 3972ff90afea244b3f82fa86c4166ad5c4639565..4555b535f7bd76b529b4fd397438d5f1579b2928 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -195,8 +195,7 @@ class ConformerEncoderLayer(nn.Layer): x: paddle.Tensor, mask: paddle.Tensor, pos_emb: paddle.Tensor, - mask_pad: paddle. - Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool) + mask_pad: paddle.Tensor, #paddle.ones([0, 0, 0],dtype=paddle.bool) att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index 96d4823e27dccd980288e1ae98efae0333b74295..87d88ee602bc1ec7bc1af8fc96153926bc98045c 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -476,8 +476,12 @@ class PaddleASRConnectionHanddler: # forward chunk (y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk( - chunk_xs, self.offset, required_cache_size, self.att_cache, - self.cnn_cache, paddle.ones([0, 0, 0], dtype=paddle.bool)) + chunk_xs, + self.offset, + required_cache_size, + att_cache=self.att_cache, + cnn_cache=self.cnn_cache, + att_mask=paddle.ones([0, 0, 0], dtype=paddle.bool)) outputs.append(y) # update the global offset, in decoding frame unit