diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index f63deede49ab155fedd98098ec9b3403cf52ee42..ba8d8a9e7399540483a0f3fa1b4fac005ac509e6 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from typing import Callable from ..core.autodiff.grad import Grad -from ..tensor import tensor +from ..tensor import Tensor, tensor from ..utils.future import Future backwarding_grad_manager = None @@ -84,10 +84,15 @@ class GradManager: callbacks = [] if isinstance(callbacks, Callable): callbacks = [callbacks] + if isinstance(params, Tensor): + params = [params] for p in params: self._param_dict[id(p)] = p for cb in callbacks: self._call_back_dict[id(p)].append(cb) + if self._grad is not None: + for p in params: + self._record_param(id(p)) return self def _register_after_backward_callback(self, callback): @@ -143,17 +148,21 @@ class GradManager: self._recording = True self._grad = grad for param_id in self._param_dict.keys(): - param_wrapper = self._param_dict[param_id] - callbacks = self._call_back_dict[param_id] + self._record_param(param_id) + grad.__enter__() - def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self): - ret = grad - for cb in callbacks: - ret = cb(param, ret) - gm._gradients[id(p)] = ret + def _record_param(self, param_id): + param_wrapper = self._param_dict[param_id] + callbacks = self._call_back_dict[param_id] - grad.wrt(param_wrapper, callback=callback) - grad.__enter__() + def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self): + ret = grad + for cb in callbacks: + ret = cb(param, ret) + gm._gradients[id(p)] = ret + + # NOTE: override prev callback wrt when called serval times + self._grad.wrt(param_wrapper, callback=callback) def release(self): r"""Stops recording and releases resources for gradients calculation. diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py new file mode 100644 index 0000000000000000000000000000000000000000..d1afa203f0ebed58c9aadfb8a448c94b7451f6ad --- /dev/null +++ b/imperative/python/test/unit/autodiff/test_grad_manger.py @@ -0,0 +1,15 @@ +import numpy as np + +import megengine as mge +from megengine import autodiff as ad + + +def test_attach_in_with_block(): + a = mge.Parameter([1.0]) + g = ad.GradManager() + with g: + b = a * 3 + g.attach(b) + c = b + 1 + g.backward(c) + assert int(b.grad.numpy()) == 1