From 78af100c94888ad695e2edbd49f8fc683a2bca87 Mon Sep 17 00:00:00 2001 From: zhulei <563755780@qq.com> Date: Tue, 6 Apr 2021 15:51:35 +0800 Subject: [PATCH] fix test of affine_grid with rocm (#32047) * fix test of affine_grid with rocm * fix test of affine_grid with rocm --- python/paddle/fluid/layers/nn.py | 3 +++ python/paddle/fluid/tests/unittests/test_affine_grid_op.py | 2 ++ python/paddle/nn/functional/vision.py | 2 ++ 3 files changed, 7 insertions(+) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6bc69ffd5cd..34dc1e9b346 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9260,6 +9260,9 @@ def affine_grid(theta, out_shape, name=None): 'affine_grid') else: attrs['output_shape'] = out_shape + if core.is_compiled_with_rocm(): + # ROCM platform do not have MIOPEN kernel for affine_grid + attrs['use_cudnn'] = False helper.append_op( type='affine_grid', diff --git a/python/paddle/fluid/tests/unittests/test_affine_grid_op.py b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py index e4336ab05d5..8277256009e 100644 --- a/python/paddle/fluid/tests/unittests/test_affine_grid_op.py +++ b/python/paddle/fluid/tests/unittests/test_affine_grid_op.py @@ -83,6 +83,8 @@ class TestAffineGridOpCase1(TestAffineGridOp): self.output_shape = np.array([20, 2, 5, 7]).astype("int32") self.dynamic_shape = True self.use_cudnn = True + if paddle.fluid.core.is_compiled_with_rocm(): + self.use_cudnn = False # ROCM platform do not have MIOPEN kernel for affine_grid self.align_corners = True diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 9e04095e7b7..032d5b47eda 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -119,6 +119,8 @@ def affine_grid(theta, out_shape, align_corners=True, name=None): use_cudnn = True else: use_cudnn = False + if core.is_compiled_with_rocm(): + use_cudnn = False # ROCM platform do not have MIOPEN kernel for affine_grid if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \ isinstance(out_shape, Variable)): -- GitLab