提交 8b3ef2a5 编写于 作者: Z zhupengyang 提交者: Tao Luo

all cases use large shape (#22091)

enhanced ops: unsqueeze, squeeze2, strided_slice, unsqueeze,
unsqueeze2, var_conv_2d, spectral_norm, slice, match_matrix_tensor,
nce, pad, pad_constant_like, filter_by_instag
上级 85ba5275
...@@ -87,12 +87,8 @@ class TestFilterByInstagOp(OpTest): ...@@ -87,12 +87,8 @@ class TestFilterByInstagOp(OpTest):
class TestFilterByInstagOp2(OpTest): class TestFilterByInstagOp2(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'filter_by_instag' self.op_type = 'filter_by_instag'
batch_size = 4
x1_embed_size = 4
fc_cnt = 2
x1 = np.array([[10, 13, 12, 1], [1, 1, 1, 1], [1, 1, 1, 1], x1 = np.random.random((4, 36)).astype('double')
[1, 1, 1, 1]]).astype('double')
x1_lod = [[1, 1, 1, 1]] x1_lod = [[1, 1, 1, 1]]
x2 = np.array([[2], [1], [2], [1]]).astype('int64') x2 = np.array([[2], [1], [2], [1]]).astype('int64')
...@@ -100,7 +96,9 @@ class TestFilterByInstagOp2(OpTest): ...@@ -100,7 +96,9 @@ class TestFilterByInstagOp2(OpTest):
x3 = np.array([1]).astype('int64') x3 = np.array([1]).astype('int64')
out = np.array([[1, 1, 1, 1], [1, 1, 1, 1]]).astype('double') out = np.zeros([2, 36]).astype('double')
out[0] = x1[1]
out[1] = x1[3]
out_lod = [[1, 1]] out_lod = [[1, 1]]
mmap = np.array([[0, 1, 1], [1, 3, 1]]).astype('int64') mmap = np.array([[0, 1, 1], [1, 3, 1]]).astype('int64')
...@@ -134,12 +132,8 @@ class TestFilterByInstagOp2(OpTest): ...@@ -134,12 +132,8 @@ class TestFilterByInstagOp2(OpTest):
class TestFilterByInstagOp3(OpTest): class TestFilterByInstagOp3(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'filter_by_instag' self.op_type = 'filter_by_instag'
batch_size = 4
x1_embed_size = 4
fc_cnt = 2
x1 = np.array([[10, 13, 12, 1], [1, 1, 1, 1], [1, 1, 1, 1], x1 = np.random.random((4, 36)).astype('double')
[1, 1, 1, 1]]).astype('double')
x1_lod = [[1, 1, 1, 1]] x1_lod = [[1, 1, 1, 1]]
x2 = np.array([[2], [1], [2], [1]]).astype('int64') x2 = np.array([[2], [1], [2], [1]]).astype('int64')
...@@ -147,7 +141,7 @@ class TestFilterByInstagOp3(OpTest): ...@@ -147,7 +141,7 @@ class TestFilterByInstagOp3(OpTest):
x3 = np.array([3]).astype('int64') x3 = np.array([3]).astype('int64')
out = np.array([[0, 0, 0, 0]]).astype('double') out = np.zeros((1, 36)).astype('double')
out_lod = [[1]] out_lod = [[1]]
mmap = np.array([[0, 1, 1]]).astype('int64') mmap = np.array([[0, 1, 1]]).astype('int64')
...@@ -180,19 +174,15 @@ class TestFilterByInstagOp3(OpTest): ...@@ -180,19 +174,15 @@ class TestFilterByInstagOp3(OpTest):
class TestFilterByInstagOp4(OpTest): class TestFilterByInstagOp4(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'filter_by_instag' self.op_type = 'filter_by_instag'
batch_size = 4
x1_embed_size = 4
fc_cnt = 2
x1 = np.array([[10, 13, 12, 1], [1, 1, 1, 1], [1, 1, 1, 1], x1 = np.random.random((4, 36)).astype('double')
[1, 1, 1, 1]]).astype('double')
x2 = np.array([[2], [1], [2], [1]]).astype('int64') x2 = np.array([[2], [1], [2], [1]]).astype('int64')
x2_lod = [[1, 1, 1, 1]] x2_lod = [[1, 1, 1, 1]]
x3 = np.array([3]).astype('int64') x3 = np.array([3]).astype('int64')
out = np.array([[0, 0, 0, 0]]).astype('double') out = np.zeros((1, 36)).astype('double')
out_lod = [[1]] out_lod = [[1]]
mmap = np.array([[0, 1, 1]]).astype('int64') mmap = np.array([[0, 1, 1]]).astype('int64')
......
...@@ -87,9 +87,9 @@ class TestMatchMatrixTensorOpCase1(TestMatchMatrixTensorOp): ...@@ -87,9 +87,9 @@ class TestMatchMatrixTensorOpCase1(TestMatchMatrixTensorOp):
class TestMatchMatrixTensorOpCase2(TestMatchMatrixTensorOp): class TestMatchMatrixTensorOpCase2(TestMatchMatrixTensorOp):
def set_data(self): def set_data(self):
ix, iy, h, dim_t = [7, 8, 1, 4] ix, iy, h, dim_t = [105, 120, 1, 4]
x_lod = [[2, 3, 2]] x_lod = [[30, 45, 30]]
y_lod = [[3, 1, 4]] y_lod = [[45, 15, 60]]
self.init_data(ix, x_lod, iy, y_lod, h, dim_t) self.init_data(ix, x_lod, iy, y_lod, h, dim_t)
......
...@@ -88,7 +88,7 @@ class TestNCE(OpTest): ...@@ -88,7 +88,7 @@ class TestNCE(OpTest):
} }
def set_data(self): def set_data(self):
self.generate_data(5, 25, 4, 1, 2, False) self.generate_data(5, 25, 100, 1, 2, False)
def compute(self): def compute(self):
out = nce(self.inputs['Input'], self.inputs['Weight'], out = nce(self.inputs['Input'], self.inputs['Weight'],
...@@ -116,7 +116,7 @@ class TestNCE(OpTest): ...@@ -116,7 +116,7 @@ class TestNCE(OpTest):
class TestNCECase1Tensor(TestNCE): class TestNCECase1Tensor(TestNCE):
def set_data(self): def set_data(self):
self.generate_data(10, 20, 10, 2, 5, False) self.generate_data(10, 20, 100, 2, 5, False)
class TestNCECase1SelectedRows(unittest.TestCase): class TestNCECase1SelectedRows(unittest.TestCase):
......
...@@ -43,8 +43,8 @@ class TestPadOp(OpTest): ...@@ -43,8 +43,8 @@ class TestPadOp(OpTest):
self.check_grad(['Y'], 'Out') self.check_grad(['Y'], 'Out')
def initTestCase(self): def initTestCase(self):
self.x_shape = (16, 16) self.x_shape = (16, 40)
self.y_shape = (3, 16) self.y_shape = (3, 40)
self.pad_value = 0.1 self.pad_value = 0.1
self.paddings = [(0, 13), (0, 0)] self.paddings = [(0, 13), (0, 0)]
...@@ -59,8 +59,8 @@ class TestCase1(TestPadOp): ...@@ -59,8 +59,8 @@ class TestCase1(TestPadOp):
class TestCase2(TestPadOp): class TestCase2(TestPadOp):
def initTestCase(self): def initTestCase(self):
self.x_shape = (4, 3, 4, 4) self.x_shape = (4, 3, 4, 10)
self.y_shape = (2, 3, 2, 4) self.y_shape = (2, 3, 2, 10)
self.paddings = [(0, 2), (0, 0), (0, 2), (0, 0)] self.paddings = [(0, 2), (0, 0), (0, 2), (0, 0)]
self.pad_value = 0.5 self.pad_value = 0.5
......
...@@ -60,14 +60,14 @@ class TestCase1(TestPadOp): ...@@ -60,14 +60,14 @@ class TestCase1(TestPadOp):
class TestCase2(TestPadOp): class TestCase2(TestPadOp):
def initTestCase(self): def initTestCase(self):
self.shape = (2, 2, 2) self.shape = (5, 5, 5)
self.paddings = [(0, 0), (0, 0), (1, 2)] self.paddings = [(0, 0), (0, 0), (1, 2)]
self.pad_value = 1.0 self.pad_value = 1.0
class TestCase3(TestPadOp): class TestCase3(TestPadOp):
def initTestCase(self): def initTestCase(self):
self.shape = (8) self.shape = (100)
self.paddings = [(0, 1)] self.paddings = [(0, 1)]
self.pad_value = 0.9 self.pad_value = 0.9
......
...@@ -458,7 +458,7 @@ class TestFP16_2(OpTest): ...@@ -458,7 +458,7 @@ class TestFP16_2(OpTest):
def config(self): def config(self):
self.dtype = "float16" self.dtype = "float16"
self.input = np.random.random([3, 4, 5]).astype(self.dtype) self.input = np.random.random([3, 4, 10]).astype(self.dtype)
self.starts = [0] self.starts = [0]
self.ends = [1] self.ends = [1]
self.axes = [1] self.axes = [1]
......
...@@ -117,8 +117,8 @@ class TestSpectralNormOp(TestSpectralNormOpNoGrad): ...@@ -117,8 +117,8 @@ class TestSpectralNormOp(TestSpectralNormOpNoGrad):
class TestSpectralNormOp2(TestSpectralNormOp): class TestSpectralNormOp2(TestSpectralNormOp):
def initTestCase(self): def initTestCase(self):
self.weight_shape = (2, 3, 3, 3) self.weight_shape = (2, 6, 3, 3)
self.u_shape = (3, ) self.u_shape = (6, )
self.v_shape = (18, ) self.v_shape = (18, )
self.dim = 1 self.dim = 1
self.power_iters = 0 self.power_iters = 0
......
...@@ -50,25 +50,25 @@ class TestSqueezeOp(OpTest): ...@@ -50,25 +50,25 @@ class TestSqueezeOp(OpTest):
# Correct: There is mins axis. # Correct: There is mins axis.
class TestSqueezeOp1(TestSqueezeOp): class TestSqueezeOp1(TestSqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (1, 3, 1, 5) self.ori_shape = (1, 20, 1, 5)
self.axes = (0, -2) self.axes = (0, -2)
self.new_shape = (3, 5) self.new_shape = (20, 5)
# Correct: No axes input. # Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp): class TestSqueezeOp2(TestSqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (1, 3, 1, 5) self.ori_shape = (1, 20, 1, 5)
self.axes = () self.axes = ()
self.new_shape = (3, 5) self.new_shape = (20, 5)
# Correct: Just part of axes be squeezed. # Correct: Just part of axes be squeezed.
class TestSqueezeOp3(TestSqueezeOp): class TestSqueezeOp3(TestSqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1) self.ori_shape = (6, 1, 5, 1, 4, 1)
self.axes = (1, -1) self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4) self.new_shape = (6, 5, 1, 4)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -57,17 +57,17 @@ class TestSqueezeOp1(TestSqueezeOp): ...@@ -57,17 +57,17 @@ class TestSqueezeOp1(TestSqueezeOp):
# Correct: No axes input. # Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp): class TestSqueezeOp2(TestSqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (1, 3, 1, 5) self.ori_shape = (1, 20, 1, 5)
self.axes = () self.axes = ()
self.new_shape = (3, 5) self.new_shape = (20, 5)
# Correct: Just part of axes be squeezed. # Correct: Just part of axes be squeezed.
class TestSqueezeOp3(TestSqueezeOp): class TestSqueezeOp3(TestSqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1) self.ori_shape = (6, 1, 5, 1, 4, 1)
self.axes = (1, -1) self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4) self.new_shape = (6, 5, 1, 4)
class TestSqueezeOpError(unittest.TestCase): class TestSqueezeOpError(unittest.TestCase):
......
...@@ -85,7 +85,7 @@ class TestStrideSliceOp(OpTest): ...@@ -85,7 +85,7 @@ class TestStrideSliceOp(OpTest):
class TestStrideSliceOp1(TestStrideSliceOp): class TestStrideSliceOp1(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(6) self.input = np.random.rand(100)
self.axes = [0] self.axes = [0]
self.starts = [3] self.starts = [3]
self.ends = [8] self.ends = [8]
...@@ -95,7 +95,7 @@ class TestStrideSliceOp1(TestStrideSliceOp): ...@@ -95,7 +95,7 @@ class TestStrideSliceOp1(TestStrideSliceOp):
class TestStrideSliceOp2(TestStrideSliceOp): class TestStrideSliceOp2(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(6) self.input = np.random.rand(100)
self.axes = [0] self.axes = [0]
self.starts = [5] self.starts = [5]
self.ends = [0] self.ends = [0]
...@@ -105,7 +105,7 @@ class TestStrideSliceOp2(TestStrideSliceOp): ...@@ -105,7 +105,7 @@ class TestStrideSliceOp2(TestStrideSliceOp):
class TestStrideSliceOp3(TestStrideSliceOp): class TestStrideSliceOp3(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(6) self.input = np.random.rand(100)
self.axes = [0] self.axes = [0]
self.starts = [-1] self.starts = [-1]
self.ends = [-3] self.ends = [-3]
...@@ -115,7 +115,7 @@ class TestStrideSliceOp3(TestStrideSliceOp): ...@@ -115,7 +115,7 @@ class TestStrideSliceOp3(TestStrideSliceOp):
class TestStrideSliceOp4(TestStrideSliceOp): class TestStrideSliceOp4(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(3, 4, 6) self.input = np.random.rand(3, 4, 10)
self.axes = [0, 1, 2] self.axes = [0, 1, 2]
self.starts = [0, -1, 0] self.starts = [0, -1, 0]
self.ends = [2, -3, 5] self.ends = [2, -3, 5]
...@@ -125,7 +125,7 @@ class TestStrideSliceOp4(TestStrideSliceOp): ...@@ -125,7 +125,7 @@ class TestStrideSliceOp4(TestStrideSliceOp):
class TestStrideSliceOp5(TestStrideSliceOp): class TestStrideSliceOp5(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(3, 3, 3) self.input = np.random.rand(5, 5, 5)
self.axes = [0, 1, 2] self.axes = [0, 1, 2]
self.starts = [1, 0, 0] self.starts = [1, 0, 0]
self.ends = [2, 1, 3] self.ends = [2, 1, 3]
...@@ -135,7 +135,7 @@ class TestStrideSliceOp5(TestStrideSliceOp): ...@@ -135,7 +135,7 @@ class TestStrideSliceOp5(TestStrideSliceOp):
class TestStrideSliceOp6(TestStrideSliceOp): class TestStrideSliceOp6(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(3, 3, 3) self.input = np.random.rand(5, 5, 5)
self.axes = [0, 1, 2] self.axes = [0, 1, 2]
self.starts = [1, -1, 0] self.starts = [1, -1, 0]
self.ends = [2, -3, 3] self.ends = [2, -3, 3]
...@@ -145,7 +145,7 @@ class TestStrideSliceOp6(TestStrideSliceOp): ...@@ -145,7 +145,7 @@ class TestStrideSliceOp6(TestStrideSliceOp):
class TestStrideSliceOp7(TestStrideSliceOp): class TestStrideSliceOp7(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(3, 3, 3) self.input = np.random.rand(5, 5, 5)
self.axes = [0, 1, 2] self.axes = [0, 1, 2]
self.starts = [1, 0, 0] self.starts = [1, 0, 0]
self.ends = [2, 2, 3] self.ends = [2, 2, 3]
...@@ -155,7 +155,7 @@ class TestStrideSliceOp7(TestStrideSliceOp): ...@@ -155,7 +155,7 @@ class TestStrideSliceOp7(TestStrideSliceOp):
class TestStrideSliceOp8(TestStrideSliceOp): class TestStrideSliceOp8(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(1, 3, 1) self.input = np.random.rand(1, 100, 1)
self.axes = [1] self.axes = [1]
self.starts = [1] self.starts = [1]
self.ends = [2] self.ends = [2]
...@@ -165,7 +165,7 @@ class TestStrideSliceOp8(TestStrideSliceOp): ...@@ -165,7 +165,7 @@ class TestStrideSliceOp8(TestStrideSliceOp):
class TestStrideSliceOp9(TestStrideSliceOp): class TestStrideSliceOp9(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(1, 3, 1) self.input = np.random.rand(1, 100, 1)
self.axes = [1] self.axes = [1]
self.starts = [-1] self.starts = [-1]
self.ends = [-2] self.ends = [-2]
...@@ -175,7 +175,7 @@ class TestStrideSliceOp9(TestStrideSliceOp): ...@@ -175,7 +175,7 @@ class TestStrideSliceOp9(TestStrideSliceOp):
class TestStrideSliceOp10(TestStrideSliceOp): class TestStrideSliceOp10(TestStrideSliceOp):
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(3, 3) self.input = np.random.rand(10, 10)
self.axes = [0, 1] self.axes = [0, 1]
self.starts = [1, 0] self.starts = [1, 0]
self.ends = [2, 2] self.ends = [2, 2]
......
...@@ -50,33 +50,33 @@ class TestUnsqueezeOp(OpTest): ...@@ -50,33 +50,33 @@ class TestUnsqueezeOp(OpTest):
# Correct: Single input index. # Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp): class TestUnsqueezeOp1(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (-1, ) self.axes = (-1, )
self.new_shape = (3, 5, 1) self.new_shape = (20, 5, 1)
# Correct: Mixed input axis. # Correct: Mixed input axis.
class TestUnsqueezeOp2(TestUnsqueezeOp): class TestUnsqueezeOp2(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (0, -1) self.axes = (0, -1)
self.new_shape = (1, 3, 5, 1) self.new_shape = (1, 20, 5, 1)
# Correct: There is duplicated axis. # Correct: There is duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp): class TestUnsqueezeOp3(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 2, 5) self.ori_shape = (10, 2, 5)
self.axes = (0, 3, 3) self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5) self.new_shape = (1, 10, 2, 1, 1, 5)
# Correct: Reversed axes. # Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp): class TestUnsqueezeOp4(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 2, 5) self.ori_shape = (10, 2, 5)
self.axes = (3, 1, 1) self.axes = (3, 1, 1)
self.new_shape = (3, 1, 1, 2, 5, 1) self.new_shape = (10, 1, 1, 2, 5, 1)
# axes is a list(with tensor) # axes is a list(with tensor)
...@@ -107,9 +107,9 @@ class TestUnsqueezeOp_AxesTensorList(OpTest): ...@@ -107,9 +107,9 @@ class TestUnsqueezeOp_AxesTensorList(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (1, 2) self.axes = (1, 2)
self.new_shape = (3, 1, 1, 5) self.new_shape = (20, 1, 1, 5)
def init_attrs(self): def init_attrs(self):
self.attrs = {} self.attrs = {}
...@@ -117,30 +117,30 @@ class TestUnsqueezeOp_AxesTensorList(OpTest): ...@@ -117,30 +117,30 @@ class TestUnsqueezeOp_AxesTensorList(OpTest):
class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList): class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (-1, ) self.axes = (-1, )
self.new_shape = (3, 5, 1) self.new_shape = (20, 5, 1)
class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList): class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (0, -1) self.axes = (0, -1)
self.new_shape = (1, 3, 5, 1) self.new_shape = (1, 20, 5, 1)
class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList): class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 2, 5) self.ori_shape = (10, 2, 5)
self.axes = (0, 3, 3) self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5) self.new_shape = (1, 10, 2, 1, 1, 5)
class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList): class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 2, 5) self.ori_shape = (10, 2, 5)
self.axes = (3, 1, 1) self.axes = (3, 1, 1)
self.new_shape = (3, 1, 1, 2, 5, 1) self.new_shape = (10, 1, 1, 2, 5, 1)
# axes is a Tensor # axes is a Tensor
...@@ -166,9 +166,9 @@ class TestUnsqueezeOp_AxesTensor(OpTest): ...@@ -166,9 +166,9 @@ class TestUnsqueezeOp_AxesTensor(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (1, 2) self.axes = (1, 2)
self.new_shape = (3, 1, 1, 5) self.new_shape = (20, 1, 1, 5)
def init_attrs(self): def init_attrs(self):
self.attrs = {} self.attrs = {}
...@@ -176,30 +176,30 @@ class TestUnsqueezeOp_AxesTensor(OpTest): ...@@ -176,30 +176,30 @@ class TestUnsqueezeOp_AxesTensor(OpTest):
class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor): class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (-1, ) self.axes = (-1, )
self.new_shape = (3, 5, 1) self.new_shape = (20, 5, 1)
class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor): class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (0, -1) self.axes = (0, -1)
self.new_shape = (1, 3, 5, 1) self.new_shape = (1, 20, 5, 1)
class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor): class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 2, 5) self.ori_shape = (10, 2, 5)
self.axes = (0, 3, 3) self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5) self.new_shape = (1, 10, 2, 1, 1, 5)
class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor): class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 2, 5) self.ori_shape = (10, 2, 5)
self.axes = (3, 1, 1) self.axes = (3, 1, 1)
self.new_shape = (3, 1, 1, 2, 5, 1) self.new_shape = (10, 1, 1, 2, 5, 1)
# test api # test api
......
...@@ -47,33 +47,33 @@ class TestUnsqueezeOp(OpTest): ...@@ -47,33 +47,33 @@ class TestUnsqueezeOp(OpTest):
# Correct: Single input index. # Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp): class TestUnsqueezeOp1(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (-1, ) self.axes = (-1, )
self.new_shape = (3, 5, 1) self.new_shape = (20, 5, 1)
# Correct: Mixed input axis. # Correct: Mixed input axis.
class TestUnsqueezeOp2(TestUnsqueezeOp): class TestUnsqueezeOp2(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 5) self.ori_shape = (20, 5)
self.axes = (0, -1) self.axes = (0, -1)
self.new_shape = (1, 3, 5, 1) self.new_shape = (1, 20, 5, 1)
# Correct: There is duplicated axis. # Correct: There is duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp): class TestUnsqueezeOp3(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 2, 5) self.ori_shape = (10, 2, 5)
self.axes = (0, 3, 3) self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5) self.new_shape = (1, 10, 2, 1, 1, 5)
# Correct: Reversed axes. # Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp): class TestUnsqueezeOp4(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (3, 2, 5) self.ori_shape = (10, 2, 5)
self.axes = (3, 1, 1) self.axes = (3, 1, 1)
self.new_shape = (3, 1, 1, 2, 5, 1) self.new_shape = (10, 1, 1, 2, 5, 1)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -186,8 +186,8 @@ class TestVarConv2dOpCase1(TestVarConv2dOp): ...@@ -186,8 +186,8 @@ class TestVarConv2dOpCase1(TestVarConv2dOp):
output_channel = 2 output_channel = 2
filter_size = [2, 3] filter_size = [2, 3]
stride = [1, 1] stride = [1, 1]
row = [1, 4] row = [1, 10]
col = [3, 2] col = [40, 6]
self.init_data(input_channel, output_channel, filter_size, stride, row, self.init_data(input_channel, output_channel, filter_size, stride, row,
col) col)
...@@ -199,8 +199,8 @@ class TestVarConv2dOpCase2(TestVarConv2dOp): ...@@ -199,8 +199,8 @@ class TestVarConv2dOpCase2(TestVarConv2dOp):
output_channel = 1 output_channel = 1
filter_size = [3, 3] filter_size = [3, 3]
stride = [2, 2] stride = [2, 2]
row = [4, 7] row = [6, 7]
col = [5, 2] col = [8, 2]
self.init_data(input_channel, output_channel, filter_size, stride, row, self.init_data(input_channel, output_channel, filter_size, stride, row,
col) col)
...@@ -212,8 +212,8 @@ class TestVarConv2dOpCase3(TestVarConv2dOp): ...@@ -212,8 +212,8 @@ class TestVarConv2dOpCase3(TestVarConv2dOp):
output_channel = 1 output_channel = 1
filter_size = [3, 3] filter_size = [3, 3]
stride = [2, 2] stride = [2, 2]
row = [7] row = [14]
col = [2] col = [4]
self.init_data(input_channel, output_channel, filter_size, stride, row, self.init_data(input_channel, output_channel, filter_size, stride, row,
col) col)
...@@ -234,7 +234,7 @@ class TestVarConv2dOpCase4(TestVarConv2dOp): ...@@ -234,7 +234,7 @@ class TestVarConv2dOpCase4(TestVarConv2dOp):
class TestVarConv2dOpCase5(TestVarConv2dOp): class TestVarConv2dOpCase5(TestVarConv2dOp):
def set_data(self): def set_data(self):
# set input very small # set input very small
input_channel = 5 input_channel = 50
output_channel = 3 output_channel = 3
filter_size = [3, 3] filter_size = [3, 3]
stride = [1, 1] stride = [1, 1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册