From 8096c1930d249d45a35ad9d8f3de0ac3dc599cd9 Mon Sep 17 00:00:00 2001 From: wangruting Date: Mon, 13 Feb 2023 03:12:36 +0000 Subject: [PATCH] modify cinn test --- .../test_cinn_prim_layer_norm.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py index 8770b6ee5c9..f69e3a1b3eb 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py @@ -31,10 +31,11 @@ def apply_to_static(net, use_cinn): class PrimeNet(paddle.nn.Layer): def __init__(self): super(PrimeNet, self).__init__() - self.fc = paddle.nn.Linear(4, 4) + self.fc = paddle.nn.Linear(4, 64) - def forward(self, x, n_shape, w, b): + def forward(self, x, w, b): y = self.fc(x) + n_shape = y.shape[1:] out = F.layer_norm(y, n_shape, w, b) return out[0] @@ -49,9 +50,8 @@ class TestPrimForward(unittest.TestCase): def setUp(self): paddle.seed(2022) self.x = paddle.randn([2, 4]) - self.n_shape = self.x.shape[1:] - self.w = paddle.randn([4]) - self.b = paddle.randn([4]) + self.w = paddle.randn([64]) + self.b = paddle.randn([64]) self.x.stop_gradient = False def train(self, use_prim): @@ -66,7 +66,7 @@ class TestPrimForward(unittest.TestCase): res = [] for _ in range(10): - out = net(self.x, self.n_shape, self.w, self.b) + out = net(self.x, self.w, self.b) loss = paddle.mean(out) loss.backward() sgd.step() @@ -104,9 +104,8 @@ class TestPrimForwardAndBackward(unittest.TestCase): def setUp(self): paddle.seed(2022) self.x = paddle.randn([2, 4]) - self.n_shape = self.x.shape[1:] - self.w = paddle.randn([4]) - self.b = paddle.randn([4]) + self.w = paddle.randn([64]) + self.b = paddle.randn([64]) self.x.stop_gradient = False def train(self, use_prim): @@ -121,7 +120,7 @@ class TestPrimForwardAndBackward(unittest.TestCase): res = [] for _ in range(10): - out = net(self.x, self.n_shape, self.w, self.b) + out = net(self.x, self.w, self.b) loss = paddle.mean(out) loss.backward() sgd.step() -- GitLab