Unverified commit d8f03326, authored by Hui Zhang, committed by GitHub

Merge pull request #2162 from 0x45f/new_api

[asr] Support dy2st for conformer
@@ -156,13 +156,22 @@ def is_broadcastable(shp1, shp2):
     return True
 
 
+def broadcast_shape(shp1, shp2):
+    result = []
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        result.append(max(a, b))
+    return result[::-1]
+
+
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape,
-                                                            mask.shape)
-    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
-    mask = mask.broadcast_to(bshape)
+    bshape = broadcast_shape(xs.shape, mask.shape)
+    mask.stop_gradient = True
+    tmp = paddle.ones(shape=[len(bshape)], dtype='int32')
+    for index in range(len(bshape)):
+        tmp[index] = bshape[index]
+    mask = mask.broadcast_to(tmp)
     trues = paddle.ones_like(xs) * value
     xs = paddle.where(mask, trues, xs)
     return xs
...
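For context, here is a self-contained sketch of the rewritten masked_fill (the two definitions are copied from the hunk above; the scores/pad tensors are invented for illustration). Computing the broadcast shape in plain Python and handing broadcast_to an int32 shape tensor avoids paddle.broadcast_shape and the assert, which keeps the shape logic traceable when the model goes through dy2st:

    from typing import Union

    import paddle

    def broadcast_shape(shp1, shp2):
        # Walk both shapes from the trailing dim, keeping the larger extent.
        result = []
        for a, b in zip(shp1[::-1], shp2[::-1]):
            result.append(max(a, b))
        return result[::-1]

    def masked_fill(xs: paddle.Tensor,
                    mask: paddle.Tensor,
                    value: Union[float, int]):
        bshape = broadcast_shape(xs.shape, mask.shape)
        mask.stop_gradient = True
        # Materialize the target shape as an int32 tensor for broadcast_to.
        tmp = paddle.ones(shape=[len(bshape)], dtype='int32')
        for index in range(len(bshape)):
            tmp[index] = bshape[index]
        mask = mask.broadcast_to(tmp)
        trues = paddle.ones_like(xs) * value
        xs = paddle.where(mask, trues, xs)
        return xs

    # Invented example: mask the last two positions of a (2, 4) score matrix.
    scores = paddle.rand([2, 4])
    pad = paddle.to_tensor([[False, False, True, True]])  # broadcasts over dim 0
    print(masked_fill(scores, pad, -1e9))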
@@ -625,10 +625,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 (elayers, head, cache_t1, d_k * 2), where
                 `head * d_k == hidden-dim` and
                 `cache_t1 == chunk_size * num_decoding_left_chunks`.
-                `d_k * 2` for att key & value.
+                `d_k * 2` for att key & value. Default is a 0-dim Tensor,
+                used for dy2st.
             cnn_cache (paddle.Tensor): cache tensor for cnn_module in conformer,
                 (elayers, b=1, hidden-dim, cache_t2), where
-                `cache_t2 == cnn.lorder - 1`
+                `cache_t2 == cnn.lorder - 1`. Default is a 0-dim Tensor,
+                used for dy2st.
         Returns:
             paddle.Tensor: output of current input xs,
...
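The 0-dim default matters on the very first decoding chunk, when there is no history yet: dy2st needs every input to be a Tensor of a fixed type, so an empty tensor stands in for "no cache" instead of None. A hedged sketch (the tensor shapes and the commented call are assumptions based on this docstring, not the exact API):

    import paddle

    # First chunk: empty caches instead of None, so the traced static
    # graph sees Tensor inputs on every call.
    att_cache = paddle.zeros([0, 0, 0, 0])  # grows to (elayers, head, cache_t1, d_k * 2)
    cnn_cache = paddle.zeros([0, 0, 0, 0])  # grows to (elayers, b=1, hidden-dim, cache_t2)

    # ys, att_cache, cnn_cache = model.forward_encoder_chunk(
    #     xs_chunk, offset, required_cache_size, att_cache, cnn_cache)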
@@ -250,11 +250,11 @@ class BaseEncoder(nn.Layer):
         r_cnn_cache = []
         for i, layer in enumerate(self.encoders):
             # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
-            # cnn_cache[i] = (B=1, hidden-dim, cache_t2)
+            # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
             xs, _, new_att_cache, new_cnn_cache = layer(
                 xs, att_mask, pos_emb,
                 att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
-                cnn_cache=cnn_cache[i] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
+                cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
             )
             # new_att_cache = (1, head, attention_key_size, d_k*2)
             # new_cnn_cache = (B=1, hidden-dim, cache_t2)
...
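The change from cnn_cache[i] to cnn_cache[i:i+1] is a rank-preserving tweak: an integer index drops the leading layer dimension, while a one-element slice keeps it, so the argument handed to each layer has the same rank on every iteration, which is the kind of shape stability static-graph conversion expects. The layer itself now squeezes that extra dimension away (next hunk). A toy illustration with invented sizes:

    import paddle

    cache = paddle.rand([12, 1, 256, 7])  # (elayers, B=1, hidden-dim, cache_t2)
    print(cache[3].shape)    # [1, 256, 7]    - integer index drops the layer dim
    print(cache[3:4].shape)  # [1, 1, 256, 7] - slice keeps rank 4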
@@ -250,6 +250,7 @@ class ConformerEncoderLayer(nn.Layer):
         # convolution module
         # Fake new cnn cache here, and then change it in conv_module
         new_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
+        cnn_cache = paddle.squeeze(cnn_cache, axis=0)
         if self.conv_module is not None:
             residual = x
             if self.normalize_before:
...
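This squeeze is the counterpart of the cnn_cache[i:i+1] slice in the encoder loop above: it removes the extra leading dimension so the convolution module still receives the (B=1, hidden-dim, cache_t2) layout it expects. A minimal illustration with invented sizes:

    import paddle

    cnn_cache = paddle.rand([1, 1, 256, 6])        # (1, B=1, hidden-dim, cache_t2) after slicing
    cnn_cache = paddle.squeeze(cnn_cache, axis=0)  # back to (B=1, hidden-dim, cache_t2)
    print(cnn_cache.shape)  # [1, 256, 6]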