未验证 提交 f8b2680c 编写于 作者: C chengduo 提交者: GitHub

fix test_conv2d (#14330)

test=develop
上级 c5b6573a
...@@ -225,29 +225,29 @@ class TestWithInput1x1Filter1x1(TestConv2dOp): ...@@ -225,29 +225,29 @@ class TestWithInput1x1Filter1x1(TestConv2dOp):
#----------------Conv2dCUDNN---------------- #----------------Conv2dCUDNN----------------
def create_test_cudnn_class(parent, cls_name): def create_test_cudnn_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestCUDNNCase(parent): class TestCUDNNCase(parent):
def init_kernel_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
cls_name = "{0}".format(cls_name) cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
TestCUDNNCase.__name__ = cls_name TestCUDNNCase.__name__ = cls_name
globals()[cls_name] = TestCUDNNCase globals()[cls_name] = TestCUDNNCase
create_test_cudnn_class(TestConv2dOp, "TestPool2DCUDNNOp") create_test_cudnn_class(TestConv2dOp)
create_test_cudnn_class(TestWithPad, "TestPool2DCUDNNOpCase1") create_test_cudnn_class(TestWithPad)
create_test_cudnn_class(TestWithStride, "TestPool2DCUDNNOpCase2") create_test_cudnn_class(TestWithStride)
create_test_cudnn_class(TestWithGroup, "TestPool2DCUDNNOpCase3") create_test_cudnn_class(TestWithGroup)
create_test_cudnn_class(TestWith1x1, "TestPool2DCUDNNOpCase4") create_test_cudnn_class(TestWith1x1)
create_test_cudnn_class(TestWithInput1x1Filter1x1, "TestPool2DCUDNNOpCase4") create_test_cudnn_class(TestWithInput1x1Filter1x1)
#----------------Conv2dCUDNN---------------- #----------------Conv2dCUDNN----------------
def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True): def create_test_cudnn_fp16_class(parent, grad_check=True):
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") "core is not compiled with CUDA")
class TestConv2DCUDNNFp16(parent): class TestConv2DCUDNNFp16(parent):
...@@ -279,23 +279,17 @@ def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True): ...@@ -279,23 +279,17 @@ def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True):
max_relative_error=0.02, max_relative_error=0.02,
no_grad_set=set(['Input'])) no_grad_set=set(['Input']))
cls_name = "{0}".format(cls_name) cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16")
TestConv2DCUDNNFp16.__name__ = cls_name TestConv2DCUDNNFp16.__name__ = cls_name
globals()[cls_name] = TestConv2DCUDNNFp16 globals()[cls_name] = TestConv2DCUDNNFp16
create_test_cudnn_fp16_class( create_test_cudnn_fp16_class(TestConv2dOp, grad_check=False)
TestConv2dOp, "TestPool2DCUDNNFp16Op", grad_check=False) create_test_cudnn_fp16_class(TestWithPad, grad_check=False)
create_test_cudnn_fp16_class( create_test_cudnn_fp16_class(TestWithStride, grad_check=False)
TestWithPad, "TestPool2DCUDNNFp16OpCase1", grad_check=False) create_test_cudnn_fp16_class(TestWithGroup, grad_check=False)
create_test_cudnn_fp16_class( create_test_cudnn_fp16_class(TestWith1x1, grad_check=False)
TestWithStride, "TestPool2DCUDNNFp16OpCase2", grad_check=False) create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False)
create_test_cudnn_fp16_class(
TestWithGroup, "TestPool2DCUDNNFp16OpCase3", grad_check=False)
create_test_cudnn_fp16_class(
TestWith1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
create_test_cudnn_fp16_class(
TestWithInput1x1Filter1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
# -------TestDepthwiseConv # -------TestDepthwiseConv
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册