diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index ba8d8a9e7399540483a0f3fa1b4fac005ac509e6..a64b4fd1ac2b729a5c751acd350e0cc02e83883c 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -3,9 +3,12 @@
 from contextlib import contextmanager
 from typing import Callable
 
 from ..core.autodiff.grad import Grad
-from ..tensor import Tensor, tensor
+from ..logger import get_logger
+from ..tensor import Tensor
 from ..utils.future import Future
 
+logger = get_logger(__name__)
+
 backwarding_grad_manager = None
 
@@ -67,7 +70,7 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()
 
-    def attach(self, params, callbacks=None):
+    def attach(self, params: list, callbacks=None):
         r"""Registers parameters that gradients should be calculated with respect to.
         Callback Functions should have a signature like this:
 
@@ -77,7 +80,7 @@ class GradManager:
                 # do something
                 return grad
 
-        :param params: registered parameters
+        :param params: parameters to be registered
         :param callbacks: list of callback functions
         """
         if callbacks is None:
@@ -95,6 +98,20 @@ class GradManager:
             self._record_param(id(p))
         return self
 
+    def detach(self, params: list):
+        r"""Removes specific registered parameters and their callback functions.
+
+        :param params: registered parameters
+        """
+        if isinstance(params, Tensor):
+            params = [params]
+        for idx, param in enumerate(params):
+            if id(param) in self._param_dict:
+                self._param_dict.pop(id(param))
+                self._call_back_dict.pop(id(param))
+            else:
+                logger.warning("param with index {} is not attached.".format(idx))
+
     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self
@@ -136,7 +153,7 @@ class GradManager:
                 else:
                     param.grad += grad
         finally:
-            self._stop_record()
+            self.release()
             backwarding_grad_manager = cache
 
     def record(self):
@@ -167,15 +184,10 @@
     def release(self):
         r"""Stops recording and releases resources for gradients calculation.
         """
-        if not self._recording:
-            raise RuntimeError("not recording")
-        self._stop_record()
-
-    def _stop_record(self):
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
+            self._grad = None
         self._recording = False
-        self._grad = None
         self._gradients = dict()
 
     def __enter__(self):
@@ -183,4 +195,4 @@
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        self._stop_record()
+        self.release()
diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py
index 5d13530a0452e728974c33b0d970233d64925cda..571489995c9f2df9ad2b232c11d432faed346d52 100644
--- a/imperative/python/megengine/tensor.py
+++ b/imperative/python/megengine/tensor.py
@@ -85,11 +85,8 @@ class Tensor(_Tensor):
 
     def detach(self):
         r"""
-        Returns a new tensor which is treated as constant during backward gradient calcuation,
-        i.e. its gradient is zero.
-
-        :param inp: input tensor
-
+        Returns a new tensor sharing the same data memory, which is treated as a constant
+        during backward gradient calculation, i.e. its gradient is zero.
""" Wrapper = type(self) Tensor = type(self.__wrapped__) diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py index d1afa203f0ebed58c9aadfb8a448c94b7451f6ad..372b3816c6e8fa33ff4ac40f9bf94f4ea909ac36 100644 --- a/imperative/python/test/unit/autodiff/test_grad_manger.py +++ b/imperative/python/test/unit/autodiff/test_grad_manger.py @@ -1,15 +1,51 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np +import pytest import megengine as mge -from megengine import autodiff as ad +import megengine.functional as F +from megengine.autodiff import GradManager + + +def test_basic(): + x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3) + w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1) + b = mge.tensor(-1.0) + + gm = GradManager().attach([w, b]) + gm.record() + + p = F.matmul(x, w) + y = p + b + + gm.backward(y) + gm.release() # is not necessary + np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]]) + np.testing.assert_equal(b.grad.numpy(), [1]) + + w.grad = None + b.grad = None + with gm: + p = F.matmul(x, w) + y = p + b + gm.backward(y) + + np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]]) + np.testing.assert_equal(b.grad.numpy(), [1]) def test_attach_in_with_block(): a = mge.Parameter([1.0]) - g = ad.GradManager() - with g: + gm = GradManager() + with gm: b = a * 3 - g.attach(b) + gm.attach(b) c = b + 1 - g.backward(c) + gm.backward(c) assert int(b.grad.numpy()) == 1 diff --git a/imperative/python/test/unit/module/test_qat.py b/imperative/python/test/unit/module/test_qat.py index 718d6acbb7491e09b94b8581fe14632a5c86c2ea..34176a3a8e5b9b70e4c6fc37de8e72f9b548db9d 100644 --- a/imperative/python/test/unit/module/test_qat.py +++ b/imperative/python/test/unit/module/test_qat.py @@ -27,8 +27,6 @@ def test_qat_convbn2d(): disable_fake_quant(qat_module) inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) normal_outputs = module(inputs) - # import pdb - # pdb.set_trace() qat_outputs = qat_module(inputs) np.testing.assert_allclose( normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6