提交 8096c193 编写于 作者: W wangruting

Modify CINN test: make layer_norm normalized shape derive from the FC output (Linear 4→64) instead of being passed in as a parameter.

上级 0a0248eb
......@@ -31,10 +31,11 @@ def apply_to_static(net, use_cinn):
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
self.fc = paddle.nn.Linear(4, 4)
self.fc = paddle.nn.Linear(4, 64)
def forward(self, x, n_shape, w, b):
def forward(self, x, w, b):
y = self.fc(x)
n_shape = y.shape[1:]
out = F.layer_norm(y, n_shape, w, b)
return out[0]
......@@ -49,9 +50,8 @@ class TestPrimForward(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.n_shape = self.x.shape[1:]
self.w = paddle.randn([4])
self.b = paddle.randn([4])
self.w = paddle.randn([64])
self.b = paddle.randn([64])
self.x.stop_gradient = False
def train(self, use_prim):
......@@ -66,7 +66,7 @@ class TestPrimForward(unittest.TestCase):
res = []
for _ in range(10):
out = net(self.x, self.n_shape, self.w, self.b)
out = net(self.x, self.w, self.b)
loss = paddle.mean(out)
loss.backward()
sgd.step()
......@@ -104,9 +104,8 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.n_shape = self.x.shape[1:]
self.w = paddle.randn([4])
self.b = paddle.randn([4])
self.w = paddle.randn([64])
self.b = paddle.randn([64])
self.x.stop_gradient = False
def train(self, use_prim):
......@@ -121,7 +120,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
res = []
for _ in range(10):
out = net(self.x, self.n_shape, self.w, self.b)
out = net(self.x, self.w, self.b)
loss = paddle.mean(out)
loss.backward()
sgd.step()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册