提交 0a0248eb 编写于 作者: W wangruting

modify cinn test

上级 e43ea422
......@@ -49,7 +49,7 @@ class TestPrimForward(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.n_shape = self.x.shape
self.n_shape = self.x.shape[1:]
self.w = paddle.randn([4])
self.b = paddle.randn([4])
self.x.stop_gradient = False
......@@ -104,7 +104,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.n_shape = self.x.shape
self.n_shape = self.x.shape[1:]
self.w = paddle.randn([4])
self.b = paddle.randn([4])
self.x.stop_gradient = False
......@@ -140,7 +140,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
# Ensure that layer_norm is splitted into small ops
self.assertTrue('layer_norm' not in fwd_ops)
def test_cinn_prim_forward(self):
def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
dy_res = self.train(use_prim=False)
......@@ -148,7 +148,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
for i in range(len(dy_res)):
np.testing.assert_allclose(
cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
cinn_res[i], dy_res[i], rtol=1e-5, atol=1e-5
)
else:
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册