提交 3af10563 编写于 作者: M Megvii Engine Team

feat(mge/grad): attach grad immediately

GitOrigin-RevId: e3a168c03ab78aafabfe3d41b0e18c61ee07c256
上级 dc3c17ba
...@@ -3,7 +3,7 @@ from contextlib import contextmanager ...@@ -3,7 +3,7 @@ from contextlib import contextmanager
from typing import Callable from typing import Callable
from ..core.autodiff.grad import Grad from ..core.autodiff.grad import Grad
from ..tensor import tensor from ..tensor import Tensor, tensor
from ..utils.future import Future from ..utils.future import Future
backwarding_grad_manager = None backwarding_grad_manager = None
...@@ -84,10 +84,15 @@ class GradManager: ...@@ -84,10 +84,15 @@ class GradManager:
callbacks = [] callbacks = []
if isinstance(callbacks, Callable): if isinstance(callbacks, Callable):
callbacks = [callbacks] callbacks = [callbacks]
if isinstance(params, Tensor):
params = [params]
for p in params: for p in params:
self._param_dict[id(p)] = p self._param_dict[id(p)] = p
for cb in callbacks: for cb in callbacks:
self._call_back_dict[id(p)].append(cb) self._call_back_dict[id(p)].append(cb)
if self._grad is not None:
for p in params:
self._record_param(id(p))
return self return self
def _register_after_backward_callback(self, callback): def _register_after_backward_callback(self, callback):
...@@ -143,17 +148,21 @@ class GradManager: ...@@ -143,17 +148,21 @@ class GradManager:
self._recording = True self._recording = True
self._grad = grad self._grad = grad
for param_id in self._param_dict.keys(): for param_id in self._param_dict.keys():
param_wrapper = self._param_dict[param_id] self._record_param(param_id)
callbacks = self._call_back_dict[param_id] grad.__enter__()
def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self): def _record_param(self, param_id):
ret = grad param_wrapper = self._param_dict[param_id]
for cb in callbacks: callbacks = self._call_back_dict[param_id]
ret = cb(param, ret)
gm._gradients[id(p)] = ret
grad.wrt(param_wrapper, callback=callback) def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
grad.__enter__() ret = grad
for cb in callbacks:
ret = cb(param, ret)
gm._gradients[id(p)] = ret
# NOTE: overrides the previously registered `wrt` callback when called several times
self._grad.wrt(param_wrapper, callback=callback)
def release(self): def release(self):
r"""Stops recording and releases resources for gradients calculation. r"""Stops recording and releases resources for gradients calculation.
......
import numpy as np
import megengine as mge
from megengine import autodiff as ad
def test_attach_in_with_block():
    """Regression test for "attach grad immediately": a tensor attached
    *after* recording has started (inside an active ``with GradManager()``
    block) must still receive a gradient when ``backward`` runs.
    """
    a = mge.Parameter([1.0])
    g = ad.GradManager()
    # NOTE(review): indentation was lost in the scraped diff; the nesting
    # below is reconstructed — backward is assumed to be called while the
    # manager is still recording. Confirm against the upstream test file.
    with g:
        b = a * 3
        # b is attached while g is already recording — the behavior this
        # commit adds.
        g.attach(b)
        c = b + 1
        g.backward(c)
    # c = b + 1, so d(c)/d(b) == 1
    assert int(b.grad.numpy()) == 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册