Commit 3af10563 authored by Megvii Engine Team

feat(mge/grad): attach grad immediately

GitOrigin-RevId: e3a168c03ab78aafabfe3d41b0e18c61ee07c256
Parent dc3c17ba
@@ -3,7 +3,7 @@ from contextlib import contextmanager
 from typing import Callable
 
 from ..core.autodiff.grad import Grad
-from ..tensor import tensor
+from ..tensor import Tensor, tensor
 from ..utils.future import Future
 
 backwarding_grad_manager = None
@@ -84,10 +84,15 @@ class GradManager:
             callbacks = []
         if isinstance(callbacks, Callable):
             callbacks = [callbacks]
+        if isinstance(params, Tensor):
+            params = [params]
         for p in params:
             self._param_dict[id(p)] = p
             for cb in callbacks:
                 self._call_back_dict[id(p)].append(cb)
+        if self._grad is not None:
+            for p in params:
+                self._record_param(id(p))
         return self
 
     def _register_after_backward_callback(self, callback):
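A minimal usage sketch of what the two additions above enable (not part of the diff; names and values are illustrative): attach() now accepts a single Tensor directly, and a tensor attached while recording is already active is wired into the current Grad immediately via _record_param, so it still receives a gradient from a later backward():

    import megengine as mge
    from megengine import autodiff as ad

    w = mge.Parameter([2.0])
    gm = ad.GradManager()
    with gm:                  # recording has already started
        gm.attach(w)          # single Tensor, attached mid-recording
        loss = w * 4
        gm.backward(loss)
    print(w.grad.numpy())     # expected: [4.]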
@@ -143,17 +148,21 @@ class GradManager:
         self._recording = True
         self._grad = grad
         for param_id in self._param_dict.keys():
-            param_wrapper = self._param_dict[param_id]
-            callbacks = self._call_back_dict[param_id]
+            self._record_param(param_id)
+        grad.__enter__()
 
-            def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
-                ret = grad
-                for cb in callbacks:
-                    ret = cb(param, ret)
-                gm._gradients[id(p)] = ret
+    def _record_param(self, param_id):
+        param_wrapper = self._param_dict[param_id]
+        callbacks = self._call_back_dict[param_id]
 
-            grad.wrt(param_wrapper, callback=callback)
-        grad.__enter__()
+        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
+            ret = grad
+            for cb in callbacks:
+                ret = cb(param, ret)
+            gm._gradients[id(p)] = ret
+
+        # NOTE: overrides the previous wrt callback when called several times
+        self._grad.wrt(param_wrapper, callback=callback)
 
     def release(self):
         r"""Stops recording and releases resources for gradients calculation.
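As a sketch of the callback chain that _record_param sets up (assumed behavior based on the code above; names are illustrative): each callback registered through attach() receives (param, grad), and its return value replaces the gradient before it is stored in gm._gradients:

    import megengine as mge
    from megengine import autodiff as ad

    def halve(param, grad):
        # applied to the raw gradient before it is stored
        return grad * 0.5

    w = mge.Parameter([3.0])
    gm = ad.GradManager()
    gm.attach(w, callbacks=halve)   # a single callable is wrapped into a list
    with gm:
        loss = w * 2
        gm.backward(loss)
    print(w.grad.numpy())           # expected: [1.], i.e. 2 * 0.5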
import numpy as np

import megengine as mge
from megengine import autodiff as ad


def test_attach_in_with_block():
    a = mge.Parameter([1.0])
    g = ad.GradManager()
    with g:
        b = a * 3
        g.attach(b)
        c = b + 1
        g.backward(c)
    assert int(b.grad.numpy()) == 1
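Since c = b + 1, dc/db is 1, so the assertion checks that a tensor attached inside an active with-block is recorded immediately and receives its gradient after backward().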