未验证 提交 a76f2b33 编写于 作者: Z zhangbo9674 提交者: GitHub

Refine trunc uinttest logic (#43016)

* refine trunc uinttest

* refine unittest

* refine ut

* refine fp64 grad check
上级 ba157929
...@@ -30,7 +30,7 @@ class TestTruncOp(OpTest): ...@@ -30,7 +30,7 @@ class TestTruncOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "trunc" self.op_type = "trunc"
self.python_api = paddle.trunc self.python_api = paddle.trunc
self.dtype = np.float64 self.init_dtype_type()
np.random.seed(2021) np.random.seed(2021)
self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)} self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)}
self.outputs = {'Out': (np.trunc(self.inputs['X']))} self.outputs = {'Out': (np.trunc(self.inputs['X']))}
...@@ -48,11 +48,19 @@ class TestTruncOp(OpTest): ...@@ -48,11 +48,19 @@ class TestTruncOp(OpTest):
class TestFloatTruncOp(TestTruncOp): class TestFloatTruncOp(TestTruncOp):
def init_dtype_type(self): def init_dtype_type(self):
self.dtype = np.float32 self.dtype = np.float32
self.__class__.exist_fp64_check_grad = True
def test_check_grad(self):
pass
class TestIntTruncOp(TestTruncOp): class TestIntTruncOp(TestTruncOp):
def init_dtype_type(self): def init_dtype_type(self):
self.dtype = np.int32 self.dtype = np.int32
self.__class__.exist_fp64_check_grad = True
def test_check_grad(self):
pass
class TestTruncAPI(unittest.TestCase): class TestTruncAPI(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册