提交 cdcb1a53 编写于 作者: T tianhao zhang

s2t: fix encoder.py

上级 ed2819d7
...@@ -255,6 +255,7 @@ class BaseEncoder(nn.Layer): ...@@ -255,6 +255,7 @@ class BaseEncoder(nn.Layer):
xs, xs,
att_mask, att_mask,
pos_emb, pos_emb,
mask_pad=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i:i + 1] cnn_cache=cnn_cache[i:i + 1]
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, ) if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册