提交 927d7937 编写于 作者: C chenweihang

simplify test case

上级 c3151903
......@@ -19,16 +19,13 @@ from op_test import OpTest
# Correct: General.
class TestSqueezeOp1(OpTest):
class TestSqueezeOp(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, 2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
self.init_test_case()
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self):
self.check_output()
......@@ -36,138 +33,81 @@ class TestSqueezeOp1(OpTest):
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, 2)
self.new_shape = (3, 5)
# Correct: There is mins axis.
class TestSqueezeOp2(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, -2)
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": False}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
# Correct: There is mins axis.
class TestSqueezeOp1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, -2)
self.new_shape = (3, 5)
# Correct: No axes input.
class TestSqueezeOp3(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = ()
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestSqueezeOp2(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = ()
self.new_shape = (3, 5)
# Correct: Just part of axes be squeezed.
class TestSqueezeOp4(OpTest):
def setUp(self):
ori_shape = (3, 1, 5, 1, 4, 1)
axes = (1, -1)
new_shape = (3, 5, 1, 4)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestSqueezeOp3(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1)
self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4)
# Correct: Inplace.
class TestSqueezeOpInplace1(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, 2)
new_shape = (3, 5)
class TestSqueezeOpInplace1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, 2)
self.new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is mins axis.
class TestSqueezeOpInplace2(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = (0, -2)
new_shape = (3, 5)
class TestSqueezeOpInplace2(TestSqueezeOp):
def inti_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, -2)
self.new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. No axes input.
class TestSqueezeOpInplace3(OpTest):
def setUp(self):
ori_shape = (1, 3, 1, 5)
axes = ()
new_shape = (3, 5)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
class TestSqueezeOpInplace3(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = ()
self.new_shape = (3, 5)
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inpalce. Just part of axes be squeezed.
class TestSqueezeOpInplace4(OpTest):
def setUp(self):
ori_shape = (3, 1, 5, 1, 4, 1)
axes = (1, -1)
new_shape = (3, 5, 1, 4)
self.op_type = "squeeze"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"axes": axes, "inplace": True}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestSqueezeOpInplace4(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1)
self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册