diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py b/python/paddle/fluid/tests/unittests/test_squeeze_op.py index caedf1430ec47c442e4c95186120fc04948eb724..f0400f24667d5779947899204528250ee1e5ac15 100755 --- a/python/paddle/fluid/tests/unittests/test_squeeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py @@ -52,6 +52,31 @@ class TestSqueezeOp(OpTest): self.attrs = {"axes": self.axes} +class TestSqueezeFP16Op(OpTest): + def setUp(self): + self.op_type = "squeeze" + self.init_test_case() + self.inputs = {"X": np.random.random(self.ori_shape).astype("float16")} + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (1, 3, 1, 40) + self.axes = (0, 2) + self.new_shape = (3, 40) + + def init_attrs(self): + self.attrs = {"axes": self.axes} + + class TestSqueezeBF16Op(OpTest): def setUp(self): self.op_type = "squeeze" diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 80e6fee5dedcc8ac3959e887afbd07e6140d68c9..85d21f5646472b5fc478480dfa24e7ef95131cdb 100755 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -50,6 +50,29 @@ class TestUnsqueezeOp(OpTest): self.attrs = {"axes": self.axes} +class TestUnsqueezeFP16Op(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(self.ori_shape).astype("float16")} + self.init_attrs() + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + def init_test_case(self): + self.ori_shape = (3, 40) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 40) + + def init_attrs(self): + self.attrs = {"axes": self.axes} + + class TestUnsqueezeBF16Op(OpTest): def setUp(self): self.init_test_case()