diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 47456fa674fdfd5a30d7d258f1a70aee04212459..19652fa30611f09c76d1e455101263aabc3de1e5 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -345,12 +345,12 @@ class GradManager: def __exit__(self, exc_type, exc_val, exc_tb): self.release() - def __and__(self, other): + def __or__(self, other): if isinstance(other, GradManager): return GradManagerGroup([self, other]) return NotImplemented - __rand__ = __and__ + __ror__ = __or__ class GradManagerGroup: @@ -364,8 +364,6 @@ class GradManagerGroup: return NotImplemented return GradManagerGroup([*self._gms, *other._gms]) - __and__ = merge_with - __rand__ = merge_with __or__ = merge_with __ror__ = merge_with diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index 23507848952b811daffc8ac3d844f295c10b32f9..e47f733d8bcaa09f7a7de0e2baad6631465c47f6 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -468,7 +468,7 @@ PyObject* GradKeyWrapper::get_priority() { } void GradKeyWrapper::set_priority(pybind11::handle priority) { - m_key->name = py::cast(priority); + m_key->priority = py::cast(priority); } void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) { @@ -535,7 +535,7 @@ void GradKey::backward(std::vector tensors, std::vectorpriority; + sm_min_priority = owner->priority + 1; } ~CleanupGuard() { owner->cleanup(); @@ -636,7 +636,7 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) { Py_RETURN_FALSE; } -int GradKey::sm_min_priority = 0; +int GradKey::sm_min_priority = std::numeric_limits<int>::min(); GradKey::~GradKey() { cleanup(); diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 87c254e59e93d639ab9f90ac06fd3f4475181a7e..bec4cb77d642be507aa441510926e56be5ec288a 100644 --- 
a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -966,6 +966,7 @@ void init_tensor(py::module m) { .def<&GradKeyWrapper::attach>("attach") .def<&GradKeyWrapper::is_attached_to>("is_attached_to") .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name") + .def_getset<&GradKeyWrapper::get_priority, &GradKeyWrapper::set_priority>("priority") .finalize(); if (!grad_key_type) throw py::error_already_set(); py::setattr(m, "GradKey", grad_key_type); diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py index 6919a525ddf1b6277caf521cc1878f33f0cad779..2f82f6c0b9b48ec066f3be3a0e63533a08ca6790 100644 --- a/imperative/python/test/unit/autodiff/test_grad_manger.py +++ b/imperative/python/test/unit/autodiff/test_grad_manger.py @@ -279,3 +279,69 @@ def test_broadcast_grad(trace_mode): func() worker() + + +def test_2nd_grad_with_manager(): + x_np = np.random.rand(10).astype("float32") + x = mge.tensor(x_np) + + gm = GradManager().attach([x]) + gm2 = GradManager().attach([x]) + + with gm: + with gm2: + y = F.cos(x) + gm2.backward(y) + np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) + gm.backward(x.grad) + np.testing.assert_almost_equal( + x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5 + ) + + +def test_grad_manager_group(): + x_np = np.random.rand(10).astype("float32") + x = mge.tensor(x_np) + + gm = GradManager().attach([x]) + gm2 = GradManager().attach([x]) + + with gm | gm2: + y = F.cos(x) + gm.backward(y) + gm2.backward(y) + np.testing.assert_almost_equal(x.grad.numpy(), -2 * np.sin(x_np), decimal=5) + + x.grad = None + + +def test_grad_manager_group_visibility(): + x_np = np.random.rand(10).astype("float32") + x = mge.tensor(x_np) + + gm = GradManager().attach([x]) + gm2 = GradManager().attach([x]) + + with gm | gm2: + y = F.cos(x) + gm2.backward(y) + np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) + 
gm.backward(x.grad) + np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) + + +def test_grad_manager_visibility_by_order(): + x_np = np.random.rand(10).astype("float32") + x = mge.tensor(x_np) + + gm = GradManager().attach([x]) + gm2 = GradManager().attach([x]) + + with gm2: + with gm: + y = F.cos(x) + gm2.backward(y) + np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) + gm.backward(x.grad) + + np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 248d7ca1bccbee6c016a8126587685ea5187716c..9765ba88d94d895613149630e084c41d6740eef9 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -126,7 +126,7 @@ def test_2nd_grad(): x.grad = None grad2(z, ones) - np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np)) + np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5) def test_grad_with_tensor_wrapper():