Commit 2627e1f7 authored by Megvii Engine Team

fix(mge/grad_manager): allow multiple calls of `release`

GitOrigin-RevId: 38ca4c78ff1fb7c8b76716a6fe347333b33478ef
Parent 67a543f3
@@ -3,9 +3,12 @@ from contextlib import contextmanager
 from typing import Callable

 from ..core.autodiff.grad import Grad
-from ..tensor import Tensor, tensor
+from ..logger import get_logger
+from ..tensor import Tensor
 from ..utils.future import Future

+logger = get_logger(__name__)
+
 backwarding_grad_manager = None
@@ -67,7 +70,7 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()

-    def attach(self, params, callbacks=None):
+    def attach(self, params: list, callbacks=None):
         r"""Registers parameters that gradients should be calculated with respect to.

         Callback functions should have a signature like this:
@@ -77,7 +80,7 @@ class GradManager:
             # do something
             return grad

-        :param params: registered parameters
+        :param params: parameters to be registered
         :param callbacks: list of callback functions
         """
         if callbacks is None:
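Note: a minimal sketch of a callback matching the documented signature, not part of this diff. It assumes the callback receives the attached tensor and its incoming gradient (per the signature elided above) and returns the gradient to accumulate; `scale_grad` and `w` are placeholder names.

    import megengine as mge
    from megengine.autodiff import GradManager

    def scale_grad(param, grad):
        # transform the gradient before it is accumulated into param.grad
        return grad * 0.5

    w = mge.Parameter([1.0])
    gm = GradManager().attach([w], callbacks=[scale_grad])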
@@ -95,6 +98,20 @@ class GradManager:
             self._record_param(id(p))
         return self

+    def detach(self, params: list):
+        r"""Removes registered parameters and their callback functions.
+
+        :param params: registered parameters
+        """
+        if isinstance(params, Tensor):
+            params = [params]
+        for idx, param in enumerate(params):
+            if id(param) in self._param_dict:
+                self._param_dict.pop(id(param))
+                self._call_back_dict.pop(id(param))
+            else:
+                logger.warning("param with index {} is not attached.".format(idx))
+
     def _register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
         return self
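Note: a quick usage sketch for the new `detach` method, not part of this diff; `w1` and `w2` are hypothetical parameters. Detaching removes a parameter and its callbacks from tracking, and detaching an unattached tensor only logs a warning instead of raising.

    import megengine as mge
    from megengine.autodiff import GradManager

    w1 = mge.Parameter([1.0])
    w2 = mge.Parameter([2.0])
    gm = GradManager().attach([w1, w2])
    gm.detach([w2])  # w2 and its callbacks are no longer tracked
    gm.detach([w2])  # already removed: logs a warning, does not raise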
@@ -136,7 +153,7 @@ class GradManager:
             else:
                 param.grad += grad
         finally:
-            self._stop_record()
+            self.release()
             backwarding_grad_manager = cache

     def record(self):
@@ -167,15 +184,10 @@ class GradManager:
     def release(self):
         r"""Stops recording and releases resources for gradients calculation.
         """
-        if not self._recording:
-            raise RuntimeError("not recording")
-        self._stop_record()
-
-    def _stop_record(self):
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
-        self._recording = False
         self._grad = None
+        self._recording = False
         self._gradients = dict()

     def __enter__(self):
@@ -183,4 +195,4 @@ class GradManager:
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        self._stop_record()
+        self.release()
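Note: taken together, these hunks fold `_stop_record` into `release` and drop the `not self._recording` guard, making `release` safe to call repeatedly; `backward` and `__exit__` now route through it as well. A hypothetical sketch, not from the diff, of the pattern that previously raised `RuntimeError("not recording")`:

    import megengine as mge
    from megengine.autodiff import GradManager

    w = mge.Parameter([1.0])
    gm = GradManager().attach([w])
    gm.record()
    gm.release()
    gm.release()  # second call is now a harmless no-op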
@@ -85,11 +85,8 @@ class Tensor(_Tensor):
     def detach(self):
         r"""
-        Returns a new tensor which is treated as constant during backward gradient calculation,
-        i.e. its gradient is zero.
-
-        :param inp: input tensor
+        Returns a new tensor sharing the same data memory, which is treated as a constant
+        during backward gradient calculation, i.e. its gradient is zero.
         """
         Wrapper = type(self)
         Tensor = type(self.__wrapped__)
......
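Note: a minimal sketch of the clarified `Tensor.detach` semantics, with hypothetical values not taken from the diff. The detached tensor shares storage with the original but contributes no gradient, so `c` below behaves as the constant 2.0 during backward.

    import megengine as mge
    from megengine.autodiff import GradManager

    w = mge.Parameter([2.0])
    gm = GradManager().attach([w])
    with gm:
        c = w.detach()  # shares w's data, but is a constant for autodiff
        y = c * w       # only the second factor is differentiated
        gm.backward(y)
    print(w.grad)  # 2.0, the value of c; no product-rule term through c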
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import pytest

 import megengine as mge
-from megengine import autodiff as ad
+import megengine.functional as F
+from megengine.autodiff import GradManager
+
+def test_basic():
+    x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
+    w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
+    b = mge.tensor(-1.0)
+
+    gm = GradManager().attach([w, b])
+
+    gm.record()
+    p = F.matmul(x, w)
+    y = p + b
+    gm.backward(y)
+    gm.release()  # not necessary: backward() already releases the resources
+    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
+    np.testing.assert_equal(b.grad.numpy(), [1])
+
+    w.grad = None
+    b.grad = None
+    with gm:
+        p = F.matmul(x, w)
+        y = p + b
+        gm.backward(y)
+    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
+    np.testing.assert_equal(b.grad.numpy(), [1])
+
 def test_attach_in_with_block():
     a = mge.Parameter([1.0])
-    g = ad.GradManager()
-    with g:
+    gm = GradManager()
+    with gm:
         b = a * 3
-        g.attach(b)
+        gm.attach(b)
         c = b + 1
-        g.backward(c)
+        gm.backward(c)
     assert int(b.grad.numpy()) == 1
@@ -27,8 +27,6 @@ def test_qat_convbn2d():
     disable_fake_quant(qat_module)
     inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
     normal_outputs = module(inputs)
-    # import pdb
-    # pdb.set_trace()
     qat_outputs = qat_module(inputs)
     np.testing.assert_allclose(
         normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
......