未验证 提交 7de2db4a 编写于 作者: W whs 提交者: GitHub

Fix grid_sample in cudnn mode (#29124)

上级 7a15e640
...@@ -302,6 +302,9 @@ def grid_sample(x, ...@@ -302,6 +302,9 @@ def grid_sample(x,
if (cudnn_version is not None if (cudnn_version is not None
) and align_corners and mode == 'bilinear' and padding_mode == 'zeros': ) and align_corners and mode == 'bilinear' and padding_mode == 'zeros':
use_cudnn = True use_cudnn = True
# CUDNN always computes gradients for all inputs
x.stop_gradient = False
grid.stop_gradient = False
ipts = {'X': x, 'Grid': grid} ipts = {'X': x, 'Grid': grid}
attrs = { attrs = {
'mode': mode, 'mode': mode,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册