From a76f2b33d287d5f7faec7b8fe08eb8d611dc7175 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Fri, 27 May 2022 16:02:59 +0800 Subject: [PATCH] Refine trunc uinttest logic (#43016) * refine trunc uinttest * refine unittest * refine ut * refine fp64 grad check --- python/paddle/fluid/tests/unittests/test_trunc_op.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py index 5bb3e99ee3..1a6790728b 100644 --- a/python/paddle/fluid/tests/unittests/test_trunc_op.py +++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py @@ -30,7 +30,7 @@ class TestTruncOp(OpTest): def setUp(self): self.op_type = "trunc" self.python_api = paddle.trunc - self.dtype = np.float64 + self.init_dtype_type() np.random.seed(2021) self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)} self.outputs = {'Out': (np.trunc(self.inputs['X']))} @@ -48,11 +48,19 @@ class TestTruncOp(OpTest): class TestFloatTruncOp(TestTruncOp): def init_dtype_type(self): self.dtype = np.float32 + self.__class__.exist_fp64_check_grad = True + + def test_check_grad(self): + pass class TestIntTruncOp(TestTruncOp): def init_dtype_type(self): self.dtype = np.int32 + self.__class__.exist_fp64_check_grad = True + + def test_check_grad(self): + pass class TestTruncAPI(unittest.TestCase): -- GitLab