diff --git a/imperative/python/test/unit/test_function.py b/imperative/python/test/unit/test_function.py
index 990ced2690336a4a40620f4c3c77f908651ea4c6..cef30bd94d43c10375d7caf146288838e613bed7 100644
--- a/imperative/python/test/unit/test_function.py
+++ b/imperative/python/test/unit/test_function.py
@@ -163,10 +163,11 @@ def test_skip_invalid_grad():
     net = Simple(av, bv)
 
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss = net().sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(net.a.numpy(), av - c)
     np.testing.assert_almost_equal(net.b.numpy(), bv - c)
@@ -197,11 +198,12 @@ def test_ste():
     av = np.random.random(data_shape).astype(np.float32)
     net = Simple(av)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
 
-    with optim.record():
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
+        gm.backward(loss.sum())
     optim.step()
 
     np.testing.assert_almost_equal(
@@ -254,10 +256,11 @@ def test_none_in_out_grad():
     b = tensor(np.array([2.0], dtype=np.float32))
     net = Simple(a, b)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss, _ = net()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
 
     np.testing.assert_almost_equal(
@@ -290,11 +293,12 @@ def test_zero_grad():
     a = tensor(np.array([1.0], dtype=np.float32))
     net = Simple(a)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
 
-    with optim.record():
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
+        gm.backward(loss.sum())
     optim.step()
     np.testing.assert_almost_equal(
         net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32),
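
Note for reviewers: all four hunks apply the same migration, so here is a minimal, self-contained sketch of the training-step idiom they converge on. The GradManager calls (register, record, backward) and optimizer calls (clear_grad, step) are taken verbatim from the diff; the Simple module, its forward, and the import layout are assumptions about the surrounding test module, which appears to import megengine.autodiff as ad and megengine.optimizer as optimizer.

import numpy as np

import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter
from megengine.module import Module


class Simple(Module):
    # Hypothetical stand-in for the nets in these tests: returns a * 2.
    def __init__(self, a):
        super().__init__()
        self.a = Parameter(a, dtype=np.float32)

    def forward(self):
        return self.a * 2.0


net = Simple(np.array([1.0], dtype=np.float32))
optim = optimizer.SGD(net.parameters(), lr=1.0)

# Gradient bookkeeping now lives in a GradManager; the optimizer only
# clears stored grads (clear_grad, formerly zero_grad) and applies updates.
gm = ad.GradManager().register(net.parameters())

optim.clear_grad()
with gm.record():        # formerly: with optim.record():
    loss = net().sum()
    gm.backward(loss)    # formerly: optim.backward(loss)
optim.step()

The effect of the change is a separation of concerns: recording and backpropagation move from the optimizer to the autodiff GradManager, leaving the optimizer responsible only for clearing stored gradients and applying parameter updates.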