提交 7cf2c05f 编写于 作者: C chengduoZH

add unit test for input's size is 1x1

上级 b5c92092
...@@ -210,6 +210,19 @@ class TestWithDilation(TestConv2dOp): ...@@ -210,6 +210,19 @@ class TestWithDilation(TestConv2dOp):
self.groups = 3 self.groups = 3
class TestWithInput1x1Filter1x1(TestConv2dOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 1, 1] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 1, 1]
def init_group(self):
self.groups = 3
#----------------Conv2dCUDNN---------------- #----------------Conv2dCUDNN----------------
class TestCUDNN(TestConv2dOp): class TestCUDNN(TestConv2dOp):
def init_op_type(self): def init_op_type(self):
...@@ -241,6 +254,12 @@ class TestCUDNNWith1x1(TestWith1x1): ...@@ -241,6 +254,12 @@ class TestCUDNNWith1x1(TestWith1x1):
self.op_type = "conv2d" self.op_type = "conv2d"
class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv2d"
class TestDepthwiseConv(TestConv2dOp): class TestDepthwiseConv(TestConv2dOp):
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1] self.pad = [1, 1]
...@@ -265,7 +284,8 @@ class TestDepthwiseConv2(TestConv2dOp): ...@@ -265,7 +284,8 @@ class TestDepthwiseConv2(TestConv2dOp):
self.op_type = "depthwise_conv2d" self.op_type = "depthwise_conv2d"
# cudnn v5 does not support dilation conv. # Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation): # class TestCUDNNWithDilation(TestWithDilation):
# def init_op_type(self): # def init_op_type(self):
# self.op_type = "conv_cudnn" # self.op_type = "conv_cudnn"
......
...@@ -200,7 +200,8 @@ class TestCUDNNWithStride(TestWithStride): ...@@ -200,7 +200,8 @@ class TestCUDNNWithStride(TestWithStride):
self.op_type = "conv2d_transpose" self.op_type = "conv2d_transpose"
# #cudnn v5 does not support dilation conv. # Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation): # class TestCUDNNWithDilation(TestWithDilation):
# def init_test_case(self): # def init_test_case(self):
# self.pad = [1, 1] # self.pad = [1, 1]
......
...@@ -200,6 +200,22 @@ class TestWith1x1(TestConv3dOp): ...@@ -200,6 +200,22 @@ class TestWith1x1(TestConv3dOp):
self.groups = 3 self.groups = 3
class TestWithInput1x1Filter1x1(TestConv3dOp):
def init_test_case(self):
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 1, 1, 1] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 1, 1, 1]
def init_dilation(self):
self.dilations = [1, 1, 1]
def init_group(self):
self.groups = 3
class TestWithDilation(TestConv3dOp): class TestWithDilation(TestConv3dOp):
def init_test_case(self): def init_test_case(self):
self.pad = [0, 0, 0] self.pad = [0, 0, 0]
...@@ -240,6 +256,12 @@ class TestWith1x1CUDNN(TestWith1x1): ...@@ -240,6 +256,12 @@ class TestWith1x1CUDNN(TestWith1x1):
self.op_type = "conv3d" self.op_type = "conv3d"
class TestWithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
def init_op_type(self):
self.use_cudnn = True
self.op_type = "conv3d"
# FIXME(typhoonzero): find a way to determine if # FIXME(typhoonzero): find a way to determine if
# using cudnn > 6 in python # using cudnn > 6 in python
# class TestWithDilationCUDNN(TestWithDilation): # class TestWithDilationCUDNN(TestWithDilation):
......
...@@ -207,7 +207,8 @@ class TestCUDNNWithStride(TestWithStride): ...@@ -207,7 +207,8 @@ class TestCUDNNWithStride(TestWithStride):
self.op_type = "conv3d_transpose" self.op_type = "conv3d_transpose"
# #cudnn v5 does not support dilation conv. # Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation): # class TestCUDNNWithDilation(TestWithDilation):
# def init_test_case(self): # def init_test_case(self):
# self.pad = [1, 1, 1] # self.pad = [1, 1, 1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册