From 927d793746d1dc2e63ebcb7f6ce93845190ec1aa Mon Sep 17 00:00:00 2001 From: chenweihang Date: Thu, 5 Jul 2018 11:17:37 +0000 Subject: [PATCH] simplify test case --- .../fluid/tests/unittests/test_squeeze_op.py | 172 ++++++------------ 1 file changed, 56 insertions(+), 116 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py b/python/paddle/fluid/tests/unittests/test_squeeze_op.py index 6ef5204b7..bca6af2fd 100644 --- a/python/paddle/fluid/tests/unittests/test_squeeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py @@ -19,16 +19,13 @@ from op_test import OpTest # Correct: General. -class TestSqueezeOp1(OpTest): +class TestSqueezeOp(OpTest): def setUp(self): - ori_shape = (1, 3, 1, 5) - axes = (0, 2) - new_shape = (3, 5) - self.op_type = "squeeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + self.init_test_case() + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} def test_check_output(self): self.check_output() @@ -36,138 +33,81 @@ class TestSqueezeOp1(OpTest): def test_check_grad(self): self.check_grad(["X"], "Out") + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, 2) + self.new_shape = (3, 5) -# Correct: There is mins axis. -class TestSqueezeOp2(OpTest): - def setUp(self): - ori_shape = (1, 3, 1, 5) - axes = (0, -2) - new_shape = (3, 5) - - self.op_type = "squeeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": False} - def test_check_output(self): - self.check_output() - def test_check_grad(self): - self.check_grad(["X"], "Out") +# Correct: There is mins axis. +class TestSqueezeOp1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, -2) + self.new_shape = (3, 5) # Correct: No axes input. -class TestSqueezeOp3(OpTest): - def setUp(self): - ori_shape = (1, 3, 1, 5) - axes = () - new_shape = (3, 5) - - self.op_type = "squeeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") +class TestSqueezeOp2(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = () + self.new_shape = (3, 5) # Correct: Just part of axes be squeezed. -class TestSqueezeOp4(OpTest): - def setUp(self): - ori_shape = (3, 1, 5, 1, 4, 1) - axes = (1, -1) - new_shape = (3, 5, 1, 4) - - self.op_type = "squeeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") +class TestSqueezeOp3(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 1, 5, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (3, 5, 1, 4) # Correct: Inplace. -class TestSqueezeOpInplace1(OpTest): - def setUp(self): - ori_shape = (1, 3, 1, 5) - axes = (0, 2) - new_shape = (3, 5) +class TestSqueezeOpInplace1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, 2) + self.new_shape = (3, 5) - self.op_type = "squeeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": True} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} # Correct: Inplace. There is mins axis. -class TestSqueezeOpInplace2(OpTest): - def setUp(self): - ori_shape = (1, 3, 1, 5) - axes = (0, -2) - new_shape = (3, 5) +class TestSqueezeOpInplace2(TestSqueezeOp): + def inti_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = (0, -2) + self.new_shape = (3, 5) - self.op_type = "squeeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": True} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} # Correct: Inplace. No axes input. -class TestSqueezeOpInplace3(OpTest): - def setUp(self): - ori_shape = (1, 3, 1, 5) - axes = () - new_shape = (3, 5) - - self.op_type = "squeeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": True} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() +class TestSqueezeOpInplace3(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 5) + self.axes = () + self.new_shape = (3, 5) - def test_check_grad(self): - self.check_grad(["X"], "Out") + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} # Correct: Inpalce. Just part of axes be squeezed. -class TestSqueezeOpInplace4(OpTest): - def setUp(self): - ori_shape = (3, 1, 5, 1, 4, 1) - axes = (1, -1) - new_shape = (3, 5, 1, 4) - - self.op_type = "squeeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": True} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") +class TestSqueezeOpInplace4(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 1, 5, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (3, 5, 1, 4) + + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} if __name__ == "__main__": -- GitLab