diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
index a93f935e2bd6376420e153b15358f5526888a757..8770b6ee5c9037293c9a38d795c362045ea7df1d 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
@@ -49,7 +49,7 @@ class TestPrimForward(unittest.TestCase):
     def setUp(self):
         paddle.seed(2022)
         self.x = paddle.randn([2, 4])
-        self.n_shape = self.x.shape
+        self.n_shape = self.x.shape[1:]
         self.w = paddle.randn([4])
         self.b = paddle.randn([4])
         self.x.stop_gradient = False
@@ -104,7 +104,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def setUp(self):
         paddle.seed(2022)
         self.x = paddle.randn([2, 4])
-        self.n_shape = self.x.shape
+        self.n_shape = self.x.shape[1:]
         self.w = paddle.randn([4])
         self.b = paddle.randn([4])
         self.x.stop_gradient = False
@@ -140,7 +140,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)
 
-    def test_cinn_prim_forward(self):
+    def test_cinn_prim(self):
         plat = platform.system()
         if plat == "Linux":
             dy_res = self.train(use_prim=False)
@@ -148,7 +148,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
 
         for i in range(len(dy_res)):
             np.testing.assert_allclose(
-                cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
+                cinn_res[i], dy_res[i], rtol=1e-5, atol=1e-5
             )
         else:
             pass
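
For context, a minimal sketch of why the `n_shape` change matters (illustrative only, not part of the patch): `paddle.nn.functional.layer_norm` expects `normalized_shape` to describe only the trailing dimensions being normalized, and the `weight`/`bias` shapes must match it. Passing the full input shape, batch dimension included, is inconsistent with `w` and `b` of shape `[4]`; `x.shape[1:]` gives the intended `[4]`.

```python
import paddle
import paddle.nn.functional as F

paddle.seed(2022)
x = paddle.randn([2, 4])

# normalized_shape covers only the trailing axes to normalize over,
# matching the shapes of weight and bias.
n_shape = x.shape[1:]   # [4], as in the patched setUp()
w = paddle.randn([4])
b = paddle.randn([4])

out = F.layer_norm(x, n_shape, weight=w, bias=b)
print(out.shape)  # [2, 4]
```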