Commit 66671006 authored by: Megvii Engine Team

feat(mge): use weakref for GradManager.attach

GitOrigin-RevId: 6df336c3c16f2c93359c9687ea1a39de7f591354
Parent 75ca5bfe
+import weakref
 from collections import defaultdict
 from contextlib import contextmanager
 from typing import Callable
@@ -16,6 +17,10 @@ def get_backwarding_grad_manager():
     return backwarding_grad_manager


+class AttachSpec:
+    __slots__ = "tensor", "callbacks"
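Editor's note: AttachSpec is a minimal per-tensor record: `tensor` holds a weak reference and `callbacks` a list of gradient callbacks. Declaring `__slots__` means instances carry no per-object `__dict__`, which keeps the bookkeeping cheap and the attribute set fixed. A quick illustration; the `T` stand-in class is hypothetical (plain `object` instances are not weak-referenceable, user-defined classes are):

import weakref

class AttachSpec:
    __slots__ = "tensor", "callbacks"

class T:  # hypothetical stand-in for a Tensor
    pass

t = T()
spec = AttachSpec()
spec.tensor = weakref.ref(t)  # weak: does not keep t alive
spec.callbacks = []
# spec.extra = 1  # AttributeError: a slotted instance has no __dict__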
 class GradManager:
     r"""
     GradManager manages auto differentiation and all resources required to perform it.
@@ -64,14 +69,13 @@ class GradManager:
     """

     def __init__(self):
-        self._call_back_dict = defaultdict(list)
-        self._param_dict = dict()
+        self._attach_specs = {}  # id(Tensor) -> AttachSpec
         self._recording = False
         self._grad = None
         self._after_backward_callback = []
-        self._gradients = dict()
+        self._gradients = {}

-    def attach(self, params: list, callbacks=None):
+    def attach(self, tensors: list, callbacks=None):
         r"""
         Registers parameters that gradients should be calculated with respect to.
         Callback functions should have a signature like this:
@@ -89,22 +93,39 @@ class GradManager:
             callbacks = []
         if isinstance(callbacks, Callable):
             callbacks = [callbacks]
-        if isinstance(params, Tensor):
-            params = [params]
-        for p in params:
-            self._param_dict[id(p)] = p
-            for cb in callbacks:
-                self._call_back_dict[id(p)].append(cb)
-        if self._grad is not None:
-            for p in params:
-                self._record_param(id(p))
+        if isinstance(tensors, Tensor):
+            tensors = [tensors]
+
+        def make_spec(tensor):
+            selfref = weakref.ref(self)
+            key = id(tensor)
+
+            def deleter(_):
+                self = selfref()
+                if self is not None:
+                    del self._attach_specs[key]
+
+            spec = AttachSpec()
+            spec.tensor = weakref.ref(tensor, deleter)
+            spec.callbacks = []
+            return spec
+
+        for x in tensors:
+            spec = self._attach_specs.get(id(x))
+            new_attach = spec is None
+            if spec is None:
+                spec = make_spec(x)
+                self._attach_specs[id(x)] = spec
+            spec.callbacks.extend(callbacks)
+            if new_attach and self._recording:
+                self._do_record(spec)
         return self
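Editor's note: the deleter closure is the core of this change. Each spec stores `weakref.ref(tensor, deleter)`, so when an attached tensor is garbage-collected the deleter drops the spec from `_attach_specs`, and attach() no longer keeps tensors alive. Note that the closure reaches the manager only through `weakref.ref(self)`, avoiding a reference cycle between manager and specs. A standalone sketch of this self-cleaning-registry pattern, with hypothetical `Registry`/`Obj` names; prompt collection assumes CPython refcounting:

import weakref

class Obj:  # hypothetical weak-referenceable payload
    pass

class Registry:  # hypothetical; mirrors the _attach_specs bookkeeping
    def __init__(self):
        self._specs = {}  # id(obj) -> weakref.ref

    def attach(self, obj):
        selfref = weakref.ref(self)  # weak, so the closure cannot cycle back
        key = id(obj)

        def deleter(_):  # invoked by the GC when obj is finalized
            owner = selfref()
            if owner is not None:  # the registry itself may be gone first
                del owner._specs[key]

        self._specs[key] = weakref.ref(obj, deleter)

r = Registry()
o = Obj()
r.attach(o)
assert len(r._specs) == 1
del o  # on CPython, refcounting fires the deleter immediately
assert len(r._specs) == 0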
     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self

-    def backward(self, ys=None, dys=None):
+    def backward(self, y=None, dy=None):
         r"""
         Performs back-propagation and computes gradients.
@@ -135,14 +156,16 @@ class GradManager:
             self._grad(ys, dys)
             for callback in self._after_backward_callback:
                 callback()
-            for p, grad in self._gradients.items():
+            for id_, grad in self._gradients.items():
                 if isinstance(grad, Future):
                     grad = grad.get()
-                param = self._param_dict[p]
-                if param.grad is None:
-                    param.grad = grad
-                else:
-                    param.grad += grad
+                spec = self._attach_specs.get(id_)
+                tensor = spec and spec.tensor()
+                if tensor is not None:
+                    if tensor.grad is None:
+                        tensor.grad = grad
+                    else:
+                        tensor.grad += grad
         finally:
             self.release()
             backwarding_grad_manager = cache
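Editor's note: backward() now resolves each gradient back through `_attach_specs` and silently skips tensors that have already been collected; live tensors accumulate into `tensor.grad` via `+=`. A minimal usage sketch consistent with the tests below, assuming the usual `megengine` import names:

import megengine as mge
from megengine.autodiff import GradManager

w = mge.Parameter(2.0)
gm = GradManager().attach(w)  # attach returns self, so calls chain
with gm:                      # start recording
    x = mge.Tensor(3.0)
    y = x * w
    gm.backward(y)            # fills w.grad; repeated calls accumulate
print(w.grad)                 # Tensor(3.0) = dy/dw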
@@ -156,22 +179,22 @@ class GradManager:
         grad = Grad()
         self._recording = True
         self._grad = grad
-        for param_id in self._param_dict.keys():
-            self._record_param(param_id)
+        for spec in self._attach_specs.values():
+            self._do_record(spec)
         grad.__enter__()

-    def _record_param(self, param_id):
-        param_wrapper = self._param_dict[param_id]
-        callbacks = self._call_back_dict[param_id]
+    def _do_record(self, spec):
+        tensor = spec.tensor()
+        if tensor is None:
+            return

-        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
-            ret = grad
+        def callback(_, grad, callbacks=spec.callbacks):
             for cb in callbacks:
-                ret = cb(param, ret)
-            gm._gradients[id(p)] = ret
+                grad = cb(tensor, grad)
+            self._gradients[id(tensor)] = grad

         # NOTE: overrides the previous wrt callback when attach() is called several times
-        self._grad.wrt(param_wrapper, callback=callback)
+        self._grad.wrt(tensor, callback=callback)
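Editor's note: each attached callback sees `(tensor, grad)` and returns a (possibly transformed) gradient; callbacks run in attach order, and the final value is what lands in `_gradients`. A plain-Python sketch of that chain, with a hypothetical helper and scalar grads for brevity:

def run_callbacks(tensor, grad, callbacks):  # hypothetical; mirrors the inner loop above
    for cb in callbacks:
        grad = cb(tensor, grad)
    return grad

scale = lambda t, g: g * 0.5     # e.g. gradient scaling
clip = lambda t, g: min(g, 1.0)  # toy clipping for scalars
print(run_callbacks(None, 4.0, [scale, clip]))  # 4.0 -> 2.0 -> 1.0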
     def release(self):
         r"""
......
@@ -224,6 +224,7 @@ class AllreduceCallback:
             self._packing_size[dtype] = 0

     def __call__(self, param, grad):
+        param = param.__wrapped__
         gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
         if gm not in self._marked_gm:
......
@@ -6,6 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import platform
+import weakref

 import numpy as np
 import pytest
@@ -59,6 +60,39 @@ def test_attach_in_with_block():
     assert int(b.grad.numpy()) == 1


+def test_attach_temporary():
+    w = mge.Parameter(2.0)
+    gm = GradManager()
+    gm.attach(w)
+
+    def cb(x, g):
+        assert x is ref()
+        cb.called = True
+
+    for i in range(3):
+        with gm:
+            cb.called = False
+            x = mge.Tensor(i, dtype="float32")
+            gm.attach(x, callbacks=cb)
+            ref = weakref.ref(x)
+            y = x * w
+            gm.backward(y)
+            assert cb.called
+        del x
+        assert ref() is None
+
+    # NOTE: does not guarantee timely release when recording
+    # for i in range(3):
+    #     with gm:
+    #         x = mge.Tensor(i, dtype='float32')
+    #         gm.attach(x)
+    #         ref = weakref.ref(x)
+    #         y = x * w
+    #         del x
+    #         assert ref() is None
+    #         gm.backward(y)
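Editor's note: the commented-out variant documents a deliberate caveat: while recording, the `callback` closure registered via `wrt` in `_do_record` holds a strong reference to the tensor, so deleting the last user reference before backward() does not guarantee collection. A minimal illustration of a closure pinning its referent (hypothetical names; prompt collection assumes CPython):

import weakref

class T:  # hypothetical weak-referenceable stand-in
    pass

t = T()
ref = weakref.ref(t)
cb = lambda g: (t, g)      # closes over t, like _do_record's callback
del t
assert ref() is not None   # still alive: the closure keeps a strong reference
del cb                     # dropping the closure (cf. GradManager.release) frees it
assert ref() is None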
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
......