提交 8096c193 编写于 作者: W wangruting

modify cinn test

上级 0a0248eb
...@@ -31,10 +31,11 @@ def apply_to_static(net, use_cinn): ...@@ -31,10 +31,11 @@ def apply_to_static(net, use_cinn):
class PrimeNet(paddle.nn.Layer): class PrimeNet(paddle.nn.Layer):
def __init__(self): def __init__(self):
super(PrimeNet, self).__init__() super(PrimeNet, self).__init__()
self.fc = paddle.nn.Linear(4, 4) self.fc = paddle.nn.Linear(4, 64)
def forward(self, x, n_shape, w, b): def forward(self, x, w, b):
y = self.fc(x) y = self.fc(x)
n_shape = y.shape[1:]
out = F.layer_norm(y, n_shape, w, b) out = F.layer_norm(y, n_shape, w, b)
return out[0] return out[0]
...@@ -49,9 +50,8 @@ class TestPrimForward(unittest.TestCase): ...@@ -49,9 +50,8 @@ class TestPrimForward(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.seed(2022) paddle.seed(2022)
self.x = paddle.randn([2, 4]) self.x = paddle.randn([2, 4])
self.n_shape = self.x.shape[1:] self.w = paddle.randn([64])
self.w = paddle.randn([4]) self.b = paddle.randn([64])
self.b = paddle.randn([4])
self.x.stop_gradient = False self.x.stop_gradient = False
def train(self, use_prim): def train(self, use_prim):
...@@ -66,7 +66,7 @@ class TestPrimForward(unittest.TestCase): ...@@ -66,7 +66,7 @@ class TestPrimForward(unittest.TestCase):
res = [] res = []
for _ in range(10): for _ in range(10):
out = net(self.x, self.n_shape, self.w, self.b) out = net(self.x, self.w, self.b)
loss = paddle.mean(out) loss = paddle.mean(out)
loss.backward() loss.backward()
sgd.step() sgd.step()
...@@ -104,9 +104,8 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -104,9 +104,8 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.seed(2022) paddle.seed(2022)
self.x = paddle.randn([2, 4]) self.x = paddle.randn([2, 4])
self.n_shape = self.x.shape[1:] self.w = paddle.randn([64])
self.w = paddle.randn([4]) self.b = paddle.randn([64])
self.b = paddle.randn([4])
self.x.stop_gradient = False self.x.stop_gradient = False
def train(self, use_prim): def train(self, use_prim):
...@@ -121,7 +120,7 @@ class TestPrimForwardAndBackward(unittest.TestCase): ...@@ -121,7 +120,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
res = [] res = []
for _ in range(10): for _ in range(10):
out = net(self.x, self.n_shape, self.w, self.b) out = net(self.x, self.w, self.b)
loss = paddle.mean(out) loss = paddle.mean(out)
loss.backward() loss.backward()
sgd.step() sgd.step()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册