未验证 提交 51aa02f9 编写于 作者: W wanghuancoder 提交者: GitHub

paddle support stride, fix dy2st check (#10498)

上级 9e911d4a
...@@ -99,6 +99,7 @@ class DotProductAttentionLayer(nn.Layer): ...@@ -99,6 +99,7 @@ class DotProductAttentionLayer(nn.Layer):
logits = paddle.reshape(logits, [n, c, h, w]) logits = paddle.reshape(logits, [n, c, h, w])
if valid_ratios is not None: if valid_ratios is not None:
# cal mask of attention weight # cal mask of attention weight
with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
for i, valid_ratio in enumerate(valid_ratios): for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, int(w * valid_ratio + 0.5)) valid_width = min(w, int(w * valid_ratio + 0.5))
if valid_width < w: if valid_width < w:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册