From 6667100638cdffb50b33d9f1381c257f8dfe877e Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 19 Nov 2020 18:14:45 +0800
Subject: [PATCH] feat(mge): use weakref for GradManager.attach

GitOrigin-RevId: 6df336c3c16f2c93359c9687ea1a39de7f591354
---
 .../python/megengine/autodiff/grad_manager.py | 83 ++++++++++++-------
 .../python/megengine/distributed/helper.py    |  1 +
 .../test/unit/autodiff/test_grad_manger.py    | 34 ++++++++
 3 files changed, 88 insertions(+), 30 deletions(-)

diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index db46b12e1..609dc2ac6 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -1,3 +1,4 @@
+import weakref
 from collections import defaultdict
 from contextlib import contextmanager
 from typing import Callable
@@ -16,6 +17,10 @@ def get_backwarding_grad_manager():
     return backwarding_grad_manager


+class AttachSpec:
+    __slots__ = "tensor", "callbacks"
+
+
 class GradManager:
     r"""
     GradManager manages auto differentiation and all resources required to perform it.
@@ -64,14 +69,13 @@ class GradManager:
     """

     def __init__(self):
-        self._call_back_dict = defaultdict(list)
-        self._param_dict = dict()
+        self._attach_specs = {}  # id(Tensor) -> AttachSpec
         self._recording = False
         self._grad = None
         self._after_backward_callback = []
-        self._gradients = dict()
+        self._gradients = {}

-    def attach(self, params: list, callbacks=None):
+    def attach(self, tensors: list, callbacks=None):
         r"""
         Registers parameters that gradients should be calculated with respect to.
         Callback Functions should have a signature like this:
@@ -89,22 +93,39 @@ class GradManager:
             callbacks = []
         if isinstance(callbacks, Callable):
             callbacks = [callbacks]
-        if isinstance(params, Tensor):
-            params = [params]
-        for p in params:
-            self._param_dict[id(p)] = p
-            for cb in callbacks:
-                self._call_back_dict[id(p)].append(cb)
-        if self._grad is not None:
-            for p in params:
-                self._record_param(id(p))
+        if isinstance(tensors, Tensor):
+            tensors = [tensors]
+
+        def make_spec(tensor):
+            selfref = weakref.ref(self)
+            key = id(tensor)
+
+            def deleter(_):
+                self = selfref()
+                if self is not None:
+                    del self._attach_specs[key]
+
+            spec = AttachSpec()
+            spec.tensor = weakref.ref(tensor, deleter)
+            spec.callbacks = []
+            return spec
+
+        for x in tensors:
+            spec = self._attach_specs.get(id(x))
+            new_attach = spec is None
+            if spec is None:
+                spec = make_spec(x)
+                self._attach_specs[id(x)] = spec
+            spec.callbacks.extend(callbacks)
+            if new_attach and self._recording:
+                self._do_record(spec)
         return self

     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self

-    def backward(self, ys=None, dys=None):
+    def backward(self, y=None, dy=None):
         r"""
         Performs back-propagation and computes gradients.
@@ -135,14 +156,16 @@ class GradManager:
             self._grad(ys, dys)
             for callback in self._after_backward_callback:
                 callback()
-            for p, grad in self._gradients.items():
+            for id_, grad in self._gradients.items():
                 if isinstance(grad, Future):
                     grad = grad.get()
-                param = self._param_dict[p]
-                if param.grad is None:
-                    param.grad = grad
-                else:
-                    param.grad += grad
+                spec = self._attach_specs.get(id_)
+                tensor = spec and spec.tensor()
+                if tensor is not None:
+                    if tensor.grad is None:
+                        tensor.grad = grad
+                    else:
+                        tensor.grad += grad
         finally:
             self.release()
             backwarding_grad_manager = cache
@@ -156,22 +179,22 @@ class GradManager:
         grad = Grad()
         self._recording = True
         self._grad = grad
-        for param_id in self._param_dict.keys():
-            self._record_param(param_id)
+        for spec in self._attach_specs.values():
+            self._do_record(spec)
         grad.__enter__()

-    def _record_param(self, param_id):
-        param_wrapper = self._param_dict[param_id]
-        callbacks = self._call_back_dict[param_id]
+    def _do_record(self, spec):
+        tensor = spec.tensor()
+        if tensor is None:
+            return

-        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
-            ret = grad
+        def callback(_, grad, callbacks=spec.callbacks):
             for cb in callbacks:
-                ret = cb(param, ret)
-            gm._gradients[id(p)] = ret
+                grad = cb(tensor, grad)
+            self._gradients[id(tensor)] = grad

         # NOTE: override prev callback wrt when called serval times
-        self._grad.wrt(param_wrapper, callback=callback)
+        self._grad.wrt(tensor, callback=callback)

     def release(self):
         r"""
diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py
index 0755ea2ce..d990cc263 100644
--- a/imperative/python/megengine/distributed/helper.py
+++ b/imperative/python/megengine/distributed/helper.py
@@ -224,6 +224,7 @@ class AllreduceCallback:
         self._packing_size[dtype] = 0

     def __call__(self, param, grad):
+        param = param.__wrapped__
         gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
         if gm not in self._marked_gm:
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index 947fa5203..8e0cc901e 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -6,6 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import platform
+import weakref

 import numpy as np
 import pytest
@@ -59,6 +60,39 @@ def test_attach_in_with_block():
     assert int(b.grad.numpy()) == 1


+def test_attach_temporary():
+    w = mge.Parameter(2.0)
+    gm = GradManager()
+    gm.attach(w)
+
+    def cb(x, g):
+        assert x is ref()
+        cb.called = True
+
+    for i in range(3):
+        with gm:
+            cb.called = False
+            x = mge.Tensor(i, dtype="float32")
+            gm.attach(x, callbacks=cb)
+            ref = weakref.ref(x)
+            y = x * w
+            gm.backward(y)
+            assert cb.called
+        del x
+        assert ref() is None
+
+    # NOTE: does not guarantee timely release when recording
+    # for i in range(3):
+    #     with gm:
+    #         x = mge.Tensor(i, dtype='float32')
+    #         gm.attach(x)
+    #         ref = weakref.ref(x)
+    #         y = x * w
+    #         del x
+    #         assert ref() is None
+    #         gm.backward(y)
+
+
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
--
GitLab
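
The patch boils down to one bookkeeping pattern: GradManager.attach now stores only a weakref.ref to each attached tensor inside an AttachSpec, and registers a deleter callback so the spec is dropped from _attach_specs as soon as the tensor is garbage-collected; the manager itself is captured through weakref.ref(self) so the deleter closure does not keep it alive either. Below is a minimal standalone sketch of that pattern, assuming only the Python standard library; Registry and Obj are hypothetical stand-ins for GradManager and Tensor, not MegEngine APIs.

import weakref


class AttachSpec:
    __slots__ = "tensor", "callbacks"


class Registry:
    def __init__(self):
        self._attach_specs = {}  # id(obj) -> AttachSpec

    def attach(self, obj, callbacks=()):
        key = id(obj)
        spec = self._attach_specs.get(key)
        if spec is None:
            selfref = weakref.ref(self)  # weak, so the deleter closure cannot keep the registry alive

            def deleter(_):
                # fires when `obj` is garbage-collected; drop its spec if the registry still exists
                owner = selfref()
                if owner is not None:
                    del owner._attach_specs[key]

            spec = AttachSpec()
            spec.tensor = weakref.ref(obj, deleter)  # weak, so attach() does not keep `obj` alive
            spec.callbacks = []
            self._attach_specs[key] = spec
        spec.callbacks.extend(callbacks)
        return self


class Obj:
    pass


reg = Registry()
x = Obj()
reg.attach(x, [lambda tensor, grad: grad])
assert len(reg._attach_specs) == 1
del x  # on CPython the refcount drops to zero here and the deleter runs,
assert len(reg._attach_specs) == 0  # so the registry entry disappears with the object

This is why test_attach_temporary can attach a fresh tensor on every loop iteration without leaking entries; as the commented-out half of that test notes, prompt cleanup relies on reference counting and is not guaranteed while recording is still in progress.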