未验证 提交 8c19d7aa 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] fix test_conv2d_transpose_op (#31749)

上级 a45c8ca6
...@@ -202,7 +202,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -202,7 +202,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int iwo_groups = groups; int iwo_groups = groups;
int c_groups = 1; int c_groups = 1;
#if CUDNN_VERSION_MIN(7, 0, 1) #if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1)
iwo_groups = 1; iwo_groups = 1;
c_groups = groups; c_groups = groups;
groups = 1; groups = 1;
...@@ -452,7 +452,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -452,7 +452,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
int iwo_groups = groups; int iwo_groups = groups;
int c_groups = 1; int c_groups = 1;
#if CUDNN_VERSION_MIN(7, 0, 1) #if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1)
iwo_groups = 1; iwo_groups = 1;
c_groups = groups; c_groups = groups;
groups = 1; groups = 1;
......
...@@ -116,7 +116,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs): ...@@ -116,7 +116,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
class TestConv2DTransposeOp(OpTest): class TestConv2DTransposeOp(OpTest):
def setUp(self): def setUp(self):
# init as conv transpose # init as conv transpose
self.dtype = np.float64 self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.need_check_grad = True self.need_check_grad = True
self.is_test = False self.is_test = False
self.use_cudnn = False self.use_cudnn = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册