提交 0a0248eb 编写于 作者: W wangruting

modify cinn test

上级 e43ea422
...@@ -49,7 +49,7 @@ class TestPrimForward(unittest.TestCase): ...@@ -49,7 +49,7 @@ class TestPrimForward(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.seed(2022) paddle.seed(2022)
self.x = paddle.randn([2, 4]) self.x = paddle.randn([2, 4])
self.n_shape = self.x.shape self.n_shape = self.x.shape[1:]
self.w = paddle.randn([4]) self.w = paddle.randn([4])
self.b = paddle.randn([4]) self.b = paddle.randn([4])
self.x.stop_gradient = False self.x.stop_gradient = False
...@@ -104,7 +104,7 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -104,7 +104,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.seed(2022) paddle.seed(2022)
self.x = paddle.randn([2, 4]) self.x = paddle.randn([2, 4])
self.n_shape = self.x.shape self.n_shape = self.x.shape[1:]
self.w = paddle.randn([4]) self.w = paddle.randn([4])
self.b = paddle.randn([4]) self.b = paddle.randn([4])
self.x.stop_gradient = False self.x.stop_gradient = False
...@@ -140,7 +140,7 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -140,7 +140,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
# Ensure that layer_norm is splitted into small ops # Ensure that layer_norm is splitted into small ops
self.assertTrue('layer_norm' not in fwd_ops) self.assertTrue('layer_norm' not in fwd_ops)
def test_cinn_prim_forward(self): def test_cinn_prim(self):
plat = platform.system() plat = platform.system()
if plat == "Linux": if plat == "Linux":
dy_res = self.train(use_prim=False) dy_res = self.train(use_prim=False)
...@@ -148,7 +148,7 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -148,7 +148,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
for i in range(len(dy_res)): for i in range(len(dy_res)):
np.testing.assert_allclose( np.testing.assert_allclose(
cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6 cinn_res[i], dy_res[i], rtol=1e-5, atol=1e-5
) )
else: else:
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册