提交 66671006 编写于 作者: M Megvii Engine Team

feat(mge): use weakref for GradManager.attach

GitOrigin-RevId: 6df336c3c16f2c93359c9687ea1a39de7f591354
上级 75ca5bfe
import weakref
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable from typing import Callable
...@@ -16,6 +17,10 @@ def get_backwarding_grad_manager(): ...@@ -16,6 +17,10 @@ def get_backwarding_grad_manager():
return backwarding_grad_manager return backwarding_grad_manager
class AttachSpec:
    """Per-tensor bookkeeping for a GradManager attachment.

    Holds a weak reference to the attached tensor (so attaching does not
    keep the tensor alive) together with the list of gradient callbacks
    registered for it.  Instances are plain slotted records: attributes
    are assigned by the code that creates the spec, not by ``__init__``.
    """

    # tensor:    weakref.ref to the attached tensor (set by the creator)
    # callbacks: list of gradient-transform callbacks (set by the creator)
    __slots__ = ("tensor", "callbacks")
class GradManager: class GradManager:
r""" r"""
GradManager manages auto differentiation and all resources required to perform it. GradManager manages auto differentiation and all resources required to perform it.
...@@ -64,14 +69,13 @@ class GradManager: ...@@ -64,14 +69,13 @@ class GradManager:
""" """
def __init__(self): def __init__(self):
self._call_back_dict = defaultdict(list) self._attach_specs = {} # id(Tensor) -> AttachSpec
self._param_dict = dict()
self._recording = False self._recording = False
self._grad = None self._grad = None
self._after_backward_callback = [] self._after_backward_callback = []
self._gradients = dict() self._gradients = {}
def attach(self, params: list, callbacks=None): def attach(self, tensors: list, callbacks=None):
r""" r"""
Registers parameters that gradients should be calculated with respect to. Registers parameters that gradients should be calculated with respect to.
Callback Functions should have a signature like this: Callback Functions should have a signature like this:
...@@ -89,22 +93,39 @@ class GradManager: ...@@ -89,22 +93,39 @@ class GradManager:
callbacks = [] callbacks = []
if isinstance(callbacks, Callable): if isinstance(callbacks, Callable):
callbacks = [callbacks] callbacks = [callbacks]
if isinstance(params, Tensor): if isinstance(tensors, Tensor):
params = [params] tensors = [tensors]
for p in params:
self._param_dict[id(p)] = p def make_spec(tensor):
for cb in callbacks: selfref = weakref.ref(self)
self._call_back_dict[id(p)].append(cb) key = id(tensor)
if self._grad is not None:
for p in params: def deleter(_):
self._record_param(id(p)) self = selfref()
if self is not None:
del self._attach_specs[key]
spec = AttachSpec()
spec.tensor = weakref.ref(tensor, deleter)
spec.callbacks = []
return spec
for x in tensors:
spec = self._attach_specs.get(id(x))
new_attach = spec is None
if spec is None:
spec = make_spec(x)
self._attach_specs[id(x)] = spec
spec.callbacks.extend(callbacks)
if new_attach and self._recording:
self._do_record(spec)
return self return self
def _register_after_backward_callback(self, callback): def _register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback) self._after_backward_callback.append(callback)
return self return self
def backward(self, ys=None, dys=None): def backward(self, y=None, dy=None):
r""" r"""
Performs back-propagation and computes gradients. Performs back-propagation and computes gradients.
...@@ -135,14 +156,16 @@ class GradManager: ...@@ -135,14 +156,16 @@ class GradManager:
self._grad(ys, dys) self._grad(ys, dys)
for callback in self._after_backward_callback: for callback in self._after_backward_callback:
callback() callback()
for p, grad in self._gradients.items(): for id_, grad in self._gradients.items():
if isinstance(grad, Future): if isinstance(grad, Future):
grad = grad.get() grad = grad.get()
param = self._param_dict[p] spec = self._attach_specs.get(id_)
if param.grad is None: tensor = spec and spec.tensor()
param.grad = grad if tensor is not None:
else: if tensor.grad is None:
param.grad += grad tensor.grad = grad
else:
tensor.grad += grad
finally: finally:
self.release() self.release()
backwarding_grad_manager = cache backwarding_grad_manager = cache
...@@ -156,22 +179,22 @@ class GradManager: ...@@ -156,22 +179,22 @@ class GradManager:
grad = Grad() grad = Grad()
self._recording = True self._recording = True
self._grad = grad self._grad = grad
for param_id in self._param_dict.keys(): for spec in self._attach_specs.values():
self._record_param(param_id) self._do_record(spec)
grad.__enter__() grad.__enter__()
def _record_param(self, param_id): def _do_record(self, spec):
param_wrapper = self._param_dict[param_id] tensor = spec.tensor()
callbacks = self._call_back_dict[param_id] if tensor is None:
return
def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self): def callback(_, grad, callbacks=spec.callbacks):
ret = grad
for cb in callbacks: for cb in callbacks:
ret = cb(param, ret) grad = cb(tensor, grad)
gm._gradients[id(p)] = ret self._gradients[id(tensor)] = grad
# NOTE: override prev callback wrt when called several times # NOTE: override prev callback wrt when called several times
self._grad.wrt(param_wrapper, callback=callback) self._grad.wrt(tensor, callback=callback)
def release(self): def release(self):
r""" r"""
......
...@@ -224,6 +224,7 @@ class AllreduceCallback: ...@@ -224,6 +224,7 @@ class AllreduceCallback:
self._packing_size[dtype] = 0 self._packing_size[dtype] = 0
def __call__(self, param, grad): def __call__(self, param, grad):
param = param.__wrapped__
gm = get_backwarding_grad_manager() gm = get_backwarding_grad_manager()
assert isinstance(gm, GradManager) assert isinstance(gm, GradManager)
if gm not in self._marked_gm: if gm not in self._marked_gm:
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform import platform
import weakref
import numpy as np import numpy as np
import pytest import pytest
...@@ -59,6 +60,39 @@ def test_attach_in_with_block(): ...@@ -59,6 +60,39 @@ def test_attach_in_with_block():
assert int(b.grad.numpy()) == 1 assert int(b.grad.numpy()) == 1
def test_attach_temporary():
    """Verify that GradManager.attach holds tensors only weakly.

    A temporarily attached tensor must (a) still receive its gradient
    callback during backward, and (b) be garbage-collected once the last
    strong reference is dropped outside the recording block.
    """
    w = mge.Parameter(2.0)
    gm = GradManager()
    gm.attach(w)

    def cb(x, g):
        # The callback must be invoked with the very tensor that was
        # attached in this iteration (checked via the weakref below).
        assert x is ref()
        cb.called = True

    for i in range(3):
        with gm:
            cb.called = False
            x = mge.Tensor(i, dtype="float32")
            gm.attach(x, callbacks=cb)
            ref = weakref.ref(x)
            y = x * w
            gm.backward(y)
            assert cb.called
        # Dropping the last strong reference after recording ends must
        # let the tensor die immediately — attach must not keep it alive.
        del x
        assert ref() is None

    # NOTE: does not guarantee timely release when recording
    # for i in range(3):
    #     with gm:
    #         x = mge.Tensor(i, dtype='float32')
    #         gm.attach(x)
    #         ref = weakref.ref(x)
    #         y = x * w
    #         del x
    #         assert ref() is None
    #         gm.backward(y)
@pytest.mark.skipif( @pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now" platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册