From 0a0248eb0431150c09ccf2f6789d33ec35937e9b Mon Sep 17 00:00:00 2001
From: wangruting
Date: Mon, 13 Feb 2023 03:09:06 +0000
Subject: [PATCH] modify cinn test

---
 .../dygraph_to_static/test_cinn_prim_layer_norm.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
index a93f935e2bd..8770b6ee5c9 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
@@ -49,7 +49,7 @@ class TestPrimForward(unittest.TestCase):
     def setUp(self):
         paddle.seed(2022)
         self.x = paddle.randn([2, 4])
-        self.n_shape = self.x.shape
+        self.n_shape = self.x.shape[1:]
         self.w = paddle.randn([4])
         self.b = paddle.randn([4])
         self.x.stop_gradient = False
@@ -104,7 +104,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def setUp(self):
         paddle.seed(2022)
         self.x = paddle.randn([2, 4])
-        self.n_shape = self.x.shape
+        self.n_shape = self.x.shape[1:]
         self.w = paddle.randn([4])
         self.b = paddle.randn([4])
         self.x.stop_gradient = False
@@ -140,7 +140,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)
 
-    def test_cinn_prim_forward(self):
+    def test_cinn_prim(self):
         plat = platform.system()
         if plat == "Linux":
             dy_res = self.train(use_prim=False)
@@ -148,7 +148,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
 
             for i in range(len(dy_res)):
                 np.testing.assert_allclose(
-                    cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
+                    cinn_res[i], dy_res[i], rtol=1e-5, atol=1e-5
                 )
         else:
             pass
--
GitLab
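
For context on the first two hunks: paddle.nn.functional.layer_norm expects
normalized_shape to name only the trailing dimensions being normalized, and the
weight and bias tensors must have exactly that shape. With x of shape [2, 4]
and w, b of shape [4], passing the full shape would not match w and b, hence
the slice to x.shape[1:]. Below is a minimal standalone sketch of the corrected
call, not part of the patch; it assumes a working Paddle install, and the
variable names simply mirror the test's setUp:

    import paddle
    import paddle.nn.functional as F

    paddle.seed(2022)
    x = paddle.randn([2, 4])
    w = paddle.randn([4])
    b = paddle.randn([4])

    # normalized_shape must be the trailing dims of x ([4] here),
    # not the full shape ([2, 4]), so that w and b match it.
    n_shape = x.shape[1:]
    out = F.layer_norm(x, n_shape, w, b)
    print(out.shape)  # [2, 4]

The last hunk relaxes rtol/atol from 1e-6 to 1e-5, a common allowance when
comparing CINN-compiled results against dygraph results in float32.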