提交 697f70c0 编写于 作者: M Megvii Engine Team

feat(mge/pytest): add more tests for specialized grad rules

GitOrigin-RevId: 509ef5a2205c92c5bec2981249741418e2a88c28
上级 cf3f58cb
......@@ -259,7 +259,18 @@ def test_reshape():
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = x.reshape(5, 2)
refs = {}
def f(x):
x = x * 1
y = x.reshape(5, 2)
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy())
......@@ -270,7 +281,18 @@ def test_subtensor():
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = x[1:-1, :2]
refs = {}
def f(x):
x = x * 1
y = x[1:-1, :2]
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))
np.testing.assert_equal(
......@@ -283,7 +305,18 @@ def test_IndexingMultiAxisVec():
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = x[[0, 2], [0, 2]]
refs = {}
def f(x):
x = x * 1
y = x[[0, 2], [0, 2]]
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))
np.testing.assert_equal(
......@@ -296,7 +329,18 @@ def test_AxisAddRemove():
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = F.squeeze(F.expand_dims(x, 2), 0)
refs = {}
def f(x):
x = x * 1
y = F.squeeze(F.expand_dims(x, 2), 0)
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))
np.testing.assert_equal(
......@@ -342,7 +386,18 @@ def test_addAxis():
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = F.expand_dims(x, [2, 3])
refs = {}
def f(x):
x = x * 1
y = F.expand_dims(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())
......@@ -353,7 +408,18 @@ def test_removeAxis():
x = mge.Tensor(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = F.squeeze(x, [2, 3])
refs = {}
def f(x):
x = x * 1
y = F.squeeze(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册