未验证 提交 f8b2680c 编写于 作者: C chengduo 提交者: GitHub

fix test_conv2d (#14330)

test=develop
上级 c5b6573a
......@@ -225,29 +225,29 @@ class TestWithInput1x1Filter1x1(TestConv2dOp):
#----------------Conv2dCUDNN----------------
def create_test_cudnn_class(parent, cls_name):
def create_test_cudnn_class(parent):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNCase(parent):
def init_kernel_type(self):
self.use_cudnn = True
cls_name = "{0}".format(cls_name)
cls_name = "{0}_{1}".format(parent.__name__, "CUDNN")
TestCUDNNCase.__name__ = cls_name
globals()[cls_name] = TestCUDNNCase
create_test_cudnn_class(TestConv2dOp, "TestPool2DCUDNNOp")
create_test_cudnn_class(TestWithPad, "TestPool2DCUDNNOpCase1")
create_test_cudnn_class(TestWithStride, "TestPool2DCUDNNOpCase2")
create_test_cudnn_class(TestWithGroup, "TestPool2DCUDNNOpCase3")
create_test_cudnn_class(TestWith1x1, "TestPool2DCUDNNOpCase4")
create_test_cudnn_class(TestWithInput1x1Filter1x1, "TestPool2DCUDNNOpCase4")
create_test_cudnn_class(TestConv2dOp)
create_test_cudnn_class(TestWithPad)
create_test_cudnn_class(TestWithStride)
create_test_cudnn_class(TestWithGroup)
create_test_cudnn_class(TestWith1x1)
create_test_cudnn_class(TestWithInput1x1Filter1x1)
#----------------Conv2dCUDNN----------------
def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True):
def create_test_cudnn_fp16_class(parent, grad_check=True):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestConv2DCUDNNFp16(parent):
......@@ -279,23 +279,17 @@ def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True):
max_relative_error=0.02,
no_grad_set=set(['Input']))
cls_name = "{0}".format(cls_name)
cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16")
TestConv2DCUDNNFp16.__name__ = cls_name
globals()[cls_name] = TestConv2DCUDNNFp16
create_test_cudnn_fp16_class(
TestConv2dOp, "TestPool2DCUDNNFp16Op", grad_check=False)
create_test_cudnn_fp16_class(
TestWithPad, "TestPool2DCUDNNFp16OpCase1", grad_check=False)
create_test_cudnn_fp16_class(
TestWithStride, "TestPool2DCUDNNFp16OpCase2", grad_check=False)
create_test_cudnn_fp16_class(
TestWithGroup, "TestPool2DCUDNNFp16OpCase3", grad_check=False)
create_test_cudnn_fp16_class(
TestWith1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
create_test_cudnn_fp16_class(
TestWithInput1x1Filter1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
create_test_cudnn_fp16_class(TestConv2dOp, grad_check=False)
create_test_cudnn_fp16_class(TestWithPad, grad_check=False)
create_test_cudnn_fp16_class(TestWithStride, grad_check=False)
create_test_cudnn_fp16_class(TestWithGroup, grad_check=False)
create_test_cudnn_fp16_class(TestWith1x1, grad_check=False)
create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False)
# -------TestDepthwiseConv
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册