提交 7cf2c05f 编写于 作者: C chengduoZH

add unit test for input's size is 1x1

上级 b5c92092
......@@ -210,6 +210,19 @@ class TestWithDilation(TestConv2dOp):
self.groups = 3
class TestWithInput1x1Filter1x1(TestConv2dOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 1, 1] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 1, 1]
def init_group(self):
self.groups = 3
#----------------Conv2dCUDNN----------------
class TestCUDNN(TestConv2dOp):
def init_op_type(self):
......@@ -241,6 +254,12 @@ class TestCUDNNWith1x1(TestWith1x1):
self.op_type = "conv2d"
class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d"
class TestDepthwiseConv(TestConv2dOp):
def init_test_case(self):
self.pad = [1, 1]
......@@ -265,7 +284,8 @@ class TestDepthwiseConv2(TestConv2dOp):
self.op_type = "depthwise_conv2d"
# cudnn v5 does not support dilation conv.
# Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
# def init_op_type(self):
# self.op_type = "conv_cudnn"
......
......@@ -200,7 +200,8 @@ class TestCUDNNWithStride(TestWithStride):
self.op_type = "conv2d_transpose"
# #cudnn v5 does not support dilation conv.
# Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
# def init_test_case(self):
# self.pad = [1, 1]
......
......@@ -200,6 +200,22 @@ class TestWith1x1(TestConv3dOp):
self.groups = 3
class TestWithInput1x1Filter1x1(TestConv3dOp):
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 1, 1, 1] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 1, 1, 1]
def init_dilation(self):
self.dilations = [1, 1, 1]
def init_group(self):
self.groups = 3
class TestWithDilation(TestConv3dOp):
def init_test_case(self):
self.pad = [0, 0, 0]
......@@ -240,6 +256,12 @@ class TestWith1x1CUDNN(TestWith1x1):
self.op_type = "conv3d"
class TestWithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d"
# FIXME(typhoonzero): find a way to determine if
# using cudnn > 6 in python
# class TestWithDilationCUDNN(TestWithDilation):
......
......@@ -207,7 +207,8 @@ class TestCUDNNWithStride(TestWithStride):
self.op_type = "conv3d_transpose"
# #cudnn v5 does not support dilation conv.
# Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
# def init_test_case(self):
# self.pad = [1, 1, 1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册