diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index c7750184866850e2ba2eb0bd27a2557a5e212b6e..76f698e64657499c6251e7a70fe86fd266e236fd 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -625,10 +625,12 @@ class U2BaseModel(ASRInterface, nn.Layer): (elayers, head, cache_t1, d_k * 2), where `head * d_k == hidden-dim` and `cache_t1 == chunk_size * num_decoding_left_chunks`. - `d_k * 2` for att key & value. + `d_k * 2` for att key & value. Default is 0-dims Tensor, + it is used for dy2st. cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer, (elayers, b=1, hidden-dim, cache_t2), where - `cache_t2 == cnn.lorder - 1` + `cache_t2 == cnn.lorder - 1`. Default is 0-dims Tensor, + it is used for dy2st. Returns: paddle.Tensor: output of current input xs, diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index ad73f5e99f178bf77f5962f6ff1ea96a7cfc22b8..bff2d69bb339e0f9eb7db6607217d309e15aca0f 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -250,11 +250,11 @@ class BaseEncoder(nn.Layer): r_cnn_cache = [] for i, layer in enumerate(self.encoders): # att_cache[i:i+1] = (1, head, cache_t1, d_k*2) - # cnn_cache[i] = (B=1, hidden-dim, cache_t2) + # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2) xs, _, new_att_cache, new_cnn_cache = layer( xs, att_mask, pos_emb, - att_cache=att_cache if elayers == 0 else att_cache[i:i+1], - cnn_cache=cnn_cache if paddle.shape(cnn_cache)[0] == 0 else cnn_cache[i], + att_cache=att_cache[i:i+1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, ) # new_att_cache = (1, head, attention_key_size, d_k*2) # new_cnn_cache = (B=1, hidden-dim, cache_t2) diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index d91e3f6ef622fcfb0288f0c4d18bbcccaf0b1ede..9e46cc54045d0edbe87847f73dbffc21ae219408 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -250,6 +250,7 @@ class ConformerEncoderLayer(nn.Layer): # convolution module # Fake new cnn cache here, and then change it in conv_module new_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype) + cnn_cache = paddle.squeeze(cnn_cache, axis=0) if self.conv_module is not None: residual = x if self.normalize_before: