From 3af105637750b0b1eb9b1a0560388957fa029d31 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 29 Sep 2020 14:55:11 +0800 Subject: [PATCH] feat(mge/grad): attach grad immediately GitOrigin-RevId: e3a168c03ab78aafabfe3d41b0e18c61ee07c256 --- .../python/megengine/autodiff/grad_manager.py | 29 ++++++++++++------- .../test/unit/autodiff/test_grad_manger.py | 15 ++++++++++ 2 files changed, 34 insertions(+), 10 deletions(-) create mode 100644 imperative/python/test/unit/autodiff/test_grad_manger.py diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index f63deede4..ba8d8a9e7 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from typing import Callable from ..core.autodiff.grad import Grad -from ..tensor import tensor +from ..tensor import Tensor, tensor from ..utils.future import Future backwarding_grad_manager = None @@ -84,10 +84,15 @@ class GradManager: callbacks = [] if isinstance(callbacks, Callable): callbacks = [callbacks] + if isinstance(params, Tensor): + params = [params] for p in params: self._param_dict[id(p)] = p for cb in callbacks: self._call_back_dict[id(p)].append(cb) + if self._grad is not None: + for p in params: + self._record_param(id(p)) return self def _register_after_backward_callback(self, callback): @@ -143,17 +148,21 @@ class GradManager: self._recording = True self._grad = grad for param_id in self._param_dict.keys(): - param_wrapper = self._param_dict[param_id] - callbacks = self._call_back_dict[param_id] + self._record_param(param_id) + grad.__enter__() - def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self): - ret = grad - for cb in callbacks: - ret = cb(param, ret) - gm._gradients[id(p)] = ret + def _record_param(self, param_id): + param_wrapper = self._param_dict[param_id] + callbacks = 
self._call_back_dict[param_id] - grad.wrt(param_wrapper, callback=callback) - grad.__enter__() + def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self): + ret = grad + for cb in callbacks: + ret = cb(param, ret) + gm._gradients[id(p)] = ret + + # NOTE: override prev callback wrt when called several times + self._grad.wrt(param_wrapper, callback=callback) def release(self): r"""Stops recording and releases resources for gradients calculation. diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py new file mode 100644 index 000000000..d1afa203f --- /dev/null +++ b/imperative/python/test/unit/autodiff/test_grad_manger.py @@ -0,0 +1,15 @@ +import numpy as np + +import megengine as mge +from megengine import autodiff as ad + + +def test_attach_in_with_block(): + a = mge.Parameter([1.0]) + g = ad.GradManager() + with g: + b = a * 3 + g.attach(b) + c = b + 1 + g.backward(c) + assert int(b.grad.numpy()) == 1 -- GitLab