diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py index 5bb3e99ee302fc8812635f2905086a44a0b95447..1a6790728b137660f644ef3065a270056c4039ce 100644 --- a/python/paddle/fluid/tests/unittests/test_trunc_op.py +++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py @@ -30,7 +30,7 @@ class TestTruncOp(OpTest): def setUp(self): self.op_type = "trunc" self.python_api = paddle.trunc - self.dtype = np.float64 + self.init_dtype_type() np.random.seed(2021) self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)} self.outputs = {'Out': (np.trunc(self.inputs['X']))} @@ -48,11 +48,19 @@ class TestTruncOp(OpTest): class TestFloatTruncOp(TestTruncOp): def init_dtype_type(self): self.dtype = np.float32 + self.__class__.exist_fp64_check_grad = True + + def test_check_grad(self): + pass class TestIntTruncOp(TestTruncOp): def init_dtype_type(self): self.dtype = np.int32 + self.__class__.exist_fp64_check_grad = True + + def test_check_grad(self): + pass class TestTruncAPI(unittest.TestCase):