Commit 69a7c55f authored by Megvii Engine Team

test(mge/function): fix test for new optimizer api

GitOrigin-RevId: 8ae7720fe6340d6ff60ce86981111173a8c1e447
Parent a66d4b8b
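The tests below move gradient recording off the optimizer and onto a `GradManager`: `optim.zero_grad()`/`optim.record()`/`optim.backward()` become `optim.clear_grad()`, `gm.record()`, and `gm.backward()`. A minimal sketch of the migrated training step, assuming `ad` is imported as `megengine.autodiff` (the import sits outside this diff) and using a hypothetical stand-in `Simple` module in place of the test fixtures:

```python
import numpy as np
import megengine as mge
import megengine.autodiff as ad
import megengine.module as M
import megengine.optimizer as optimizer

class Simple(M.Module):
    # Stand-in for the test fixtures: a single learnable parameter.
    def __init__(self):
        super().__init__()
        self.a = mge.Parameter(np.array([1.0], dtype=np.float32))

    def forward(self):
        return self.a * 2.0

net = Simple()
optim = optimizer.SGD(net.parameters(), lr=1.0)

# Gradient bookkeeping now lives in a GradManager; the optimizer
# only clears and applies gradients.
gm = ad.GradManager().register(net.parameters())
optim.clear_grad()     # renamed from optim.zero_grad()
with gm.record():      # was: with optim.record():
    loss = net().sum()
    gm.backward(loss)  # was: optim.backward(loss)
optim.step()
```

This split keeps autodiff state out of the optimizer, so the same recording pattern works with any optimizer. Note that `register`/`record` is the transitional API this commit targets; later MegEngine releases renamed it (e.g. `attach`).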
@@ -163,10 +163,11 @@ def test_skip_invalid_grad():
     net = Simple(av, bv)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss = net().sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(net.a.numpy(), av - c)
     np.testing.assert_almost_equal(net.b.numpy(), bv - c)
@@ -197,11 +198,12 @@ def test_ste():
     av = np.random.random(data_shape).astype(np.float32)
     net = Simple(av)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
+        gm.backward(loss.sum())
     optim.step()
     np.testing.assert_almost_equal(
@@ -254,10 +256,11 @@ def test_none_in_out_grad():
     b = tensor(np.array([2.0], dtype=np.float32))
     net = Simple(a, b)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss, _ = net()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(
@@ -290,11 +293,12 @@ def test_zero_grad():
     a = tensor(np.array([1.0], dtype=np.float32))
     net = Simple(a)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
+        gm.backward(loss.sum())
     optim.step()
     np.testing.assert_almost_equal(
         net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32),