提交 902fffcc 编写于 作者: A andyjpaddle

fix slice in sar head

上级 f4b3d49b
...@@ -235,7 +235,8 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -235,7 +235,8 @@ class ParallelSARDecoder(BaseDecoder):
# cal mask of attention weight # cal mask of attention weight
for i, valid_ratio in enumerate(valid_ratios): for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio)) valid_width = min(w, math.ceil(w * valid_ratio))
attn_weight[i, :, :, valid_width:, :] = float('-inf') if valid_width < w:
attn_weight[i, :, :, valid_width:, :] = float('-inf')
attn_weight = paddle.reshape(attn_weight, [bsz, T, -1]) attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
attn_weight = F.softmax(attn_weight, axis=-1) attn_weight = F.softmax(attn_weight, axis=-1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册