未验证 提交 8c19d7aa 编写于 作者: R ronnywang 提交者: GitHub

[ROCM] fix test_conv2d_transpose_op (#31749)

上级 a45c8ca6
......@@ -202,7 +202,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int iwo_groups = groups;
int c_groups = 1;
#if CUDNN_VERSION_MIN(7, 0, 1)
#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1)
iwo_groups = 1;
c_groups = groups;
groups = 1;
......@@ -452,7 +452,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
int iwo_groups = groups;
int c_groups = 1;
#if CUDNN_VERSION_MIN(7, 0, 1)
#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 0, 1)
iwo_groups = 1;
c_groups = groups;
groups = 1;
......
......@@ -116,7 +116,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
class TestConv2DTransposeOp(OpTest):
def setUp(self):
# init as conv transpose
self.dtype = np.float64
self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
self.need_check_grad = True
self.is_test = False
self.use_cudnn = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册