diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 2e7fc29b40a7dd1aa71afafb2755eea7131ce2bc..889aee7e99f921d50642bd45dae6001c322907cd 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -259,7 +259,18 @@ def test_reshape(): x = mge.Tensor(x_np) grad = Grad().wrt(x, callback=save_to(x)) - y = x.reshape(5, 2) + + refs = {} + + def f(x): + x = x * 1 + y = x.reshape(5, 2) + refs["x"] = TensorWeakRef(x) + return y + + y = f(x) + for _, r in refs.items(): + assert r() is None grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy()) @@ -270,7 +281,18 @@ def test_subtensor(): x = mge.Tensor(x_np) grad = Grad().wrt(x, callback=save_to(x)) - y = x[1:-1, :2] + + refs = {} + + def f(x): + x = x * 1 + y = x[1:-1, :2] + refs["x"] = TensorWeakRef(x) + return y + + y = f(x) + for _, r in refs.items(): + assert r() is None grad(y, F.ones_like(y)) np.testing.assert_equal( @@ -283,7 +305,18 @@ def test_IndexingMultiAxisVec(): x = mge.Tensor(x_np) grad = Grad().wrt(x, callback=save_to(x)) - y = x[[0, 2], [0, 2]] + + refs = {} + + def f(x): + x = x * 1 + y = x[[0, 2], [0, 2]] + refs["x"] = TensorWeakRef(x) + return y + + y = f(x) + for _, r in refs.items(): + assert r() is None grad(y, F.ones_like(y)) np.testing.assert_equal( @@ -296,7 +329,18 @@ def test_AxisAddRemove(): x = mge.Tensor(x_np) grad = Grad().wrt(x, callback=save_to(x)) - y = F.squeeze(F.expand_dims(x, 2), 0) + + refs = {} + + def f(x): + x = x * 1 + y = F.squeeze(F.expand_dims(x, 2), 0) + refs["x"] = TensorWeakRef(x) + return y + + y = f(x) + for _, r in refs.items(): + assert r() is None grad(y, F.ones_like(y)) np.testing.assert_equal( @@ -342,7 +386,18 @@ def test_addAxis(): x = mge.Tensor(x_np) grad = Grad().wrt(x, callback=save_to(x)) - y = F.expand_dims(x, [2, 3]) + + refs = {} + + def f(x): + x = x * 1 + y = F.expand_dims(x, [2, 3]) + refs["x"] = TensorWeakRef(x) + return y + + y = f(x) + for _, r in refs.items(): + assert r() is None grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) @@ -353,7 +408,18 @@ def test_removeAxis(): x = mge.Tensor(x_np) grad = Grad().wrt(x, callback=save_to(x)) - y = F.squeeze(x, [2, 3]) + + refs = {} + + def f(x): + x = x * 1 + y = F.squeeze(x, [2, 3]) + refs["x"] = TensorWeakRef(x) + return y + + y = f(x) + for _, r in refs.items(): + assert r() is None grad(y, F.ones_like(y)) np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())