提交 ef37f73a 编写于 作者: H Hui Zhang

fix cnn cache dy2st shape

上级 d8f03326
......@@ -29,6 +29,9 @@ import paddle
from paddle import jit
from paddle import nn
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
......@@ -48,9 +51,6 @@ from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.log import Log
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import pad_sequence
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import log_add
from paddlespeech.s2t.utils.utility import UpdateConfig
......@@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
......@@ -625,12 +625,10 @@ class U2BaseModel(ASRInterface, nn.Layer):
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
`d_k * 2` for att key & value. Default is 0-dims Tensor,
it is used for dy2st.
`d_k * 2` for att key & value.
cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`. Default is 0-dims Tensor,
it is used for dy2st.
`cache_t2 == cnn.lorder - 1`.
Returns:
paddle.Tensor: output of current input xs,
......@@ -641,8 +639,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
return self.encoder.forward_chunk(
xs, offset, required_cache_size, att_cache, cnn_cache)
return self.encoder.forward_chunk(xs, offset, required_cache_size,
att_cache, cnn_cache)
# @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
......
......@@ -76,9 +76,9 @@ class TransformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
......@@ -105,9 +105,7 @@ class TransformerEncoderLayer(nn.Layer):
if self.normalize_before:
x = self.norm1(x)
x_att, new_att_cache = self.self_attn(
x, x, x, mask, cache=att_cache
)
x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1)
......@@ -124,7 +122,7 @@ class TransformerEncoderLayer(nn.Layer):
if not self.normalize_before:
x = self.norm2(x)
fake_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype)
fake_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
return x, mask, new_att_cache, fake_cnn_cache
......@@ -195,9 +193,9 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
cnn_cache: paddle.Tensor=paddle.zeros([0,0,0,0]),
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
......@@ -211,7 +209,8 @@ class ConformerEncoderLayer(nn.Layer):
att_cache (paddle.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (paddle.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2)
(1, #batch=1, size, cache_t2). First dim will not be used, just
for dy2st.
Returns:
paddle.Tensor: Output tensor (#batch, time, size).
paddle.Tensor: Mask tensor (#batch, time, time).
......@@ -219,6 +218,8 @@ class ConformerEncoderLayer(nn.Layer):
(#batch=1, head, cache_t1 + time, d_k * 2).
paddle.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
"""
# (1, #batch=1, size, cache_t2) -> (#batch=1, size, cache_t2)
cnn_cache = paddle.squeeze(cnn_cache, axis=0)
# whether to use macaron style FFN
if self.feed_forward_macaron is not None:
......@@ -249,8 +250,7 @@ class ConformerEncoderLayer(nn.Layer):
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = paddle.zeros([0,0,0], dtype=x.dtype)
cnn_cache = paddle.squeeze(cnn_cache, axis=0)
new_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
if self.conv_module is not None:
residual = x
if self.normalize_before:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册