提交 2627e1f7 编写于 作者: M Megvii Engine Team

fix(mge/grad_manager): allow multiple calls of `release`

GitOrigin-RevId: 38ca4c78ff1fb7c8b76716a6fe347333b33478ef
上级 67a543f3
......@@ -3,9 +3,12 @@ from contextlib import contextmanager
from typing import Callable
from ..core.autodiff.grad import Grad
from ..tensor import Tensor, tensor
from ..logger import get_logger
from ..tensor import Tensor
from ..utils.future import Future
logger = get_logger(__name__)
backwarding_grad_manager = None
......@@ -67,7 +70,7 @@ class GradManager:
self._after_backward_callback = []
self._gradients = dict()
def attach(self, params, callbacks=None):
def attach(self, params: list, callbacks=None):
r"""Registers parameters that gradients should be calculated with respect to.
Callback Functions should have a signature like this:
......@@ -77,7 +80,7 @@ class GradManager:
# do something
return grad
:param params: registered parameters
:param params: to be registered parameters
:param callbacks: list of callback functions
"""
if callbacks is None:
......@@ -95,6 +98,20 @@ class GradManager:
self._record_param(id(p))
return self
def detach(self, params: list):
r"""Remove specific registered parameters and callback functions.
:param params: registered parameters
"""
if isinstance(params, Tensor):
params = [params]
for idx, param in enumerate(params):
if id(param) in self._param_dict:
self._param_dict.pop(id(param))
self._call_back_dict.pop(id(param))
else:
logger.warning("params with index {} is not attached.".format(idx))
def _register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback)
return self
......@@ -136,7 +153,7 @@ class GradManager:
else:
param.grad += grad
finally:
self._stop_record()
self.release()
backwarding_grad_manager = cache
def record(self):
......@@ -167,15 +184,10 @@ class GradManager:
def release(self):
r"""Stops recording and releases resources for gradients calculation.
"""
if not self._recording:
raise RuntimeError("not recording")
self._stop_record()
def _stop_record(self):
if self._grad is not None:
self._grad.__exit__(None, None, None)
self._grad = None
self._recording = False
self._grad = None
self._gradients = dict()
def __enter__(self):
......@@ -183,4 +195,4 @@ class GradManager:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._stop_record()
self.release()
......@@ -85,11 +85,8 @@ class Tensor(_Tensor):
def detach(self):
r"""
Returns a new tensor which is treated as constant during backward gradient calcuation,
i.e. its gradient is zero.
:param inp: input tensor
Returns a new tensor sharing the same data memory, which is treated as a constant
during backward gradient calcuation, i.e. its gradient is zero.
"""
Wrapper = type(self)
Tensor = type(self.__wrapped__)
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
import megengine as mge
from megengine import autodiff as ad
import megengine.functional as F
from megengine.autodiff import GradManager
def test_basic():
x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
b = mge.tensor(-1.0)
gm = GradManager().attach([w, b])
gm.record()
p = F.matmul(x, w)
y = p + b
gm.backward(y)
gm.release() # is not necessary
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
np.testing.assert_equal(b.grad.numpy(), [1])
w.grad = None
b.grad = None
with gm:
p = F.matmul(x, w)
y = p + b
gm.backward(y)
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
np.testing.assert_equal(b.grad.numpy(), [1])
def test_attach_in_with_block():
a = mge.Parameter([1.0])
g = ad.GradManager()
with g:
gm = GradManager()
with gm:
b = a * 3
g.attach(b)
gm.attach(b)
c = b + 1
g.backward(c)
gm.backward(c)
assert int(b.grad.numpy()) == 1
......@@ -27,8 +27,6 @@ def test_qat_convbn2d():
disable_fake_quant(qat_module)
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
normal_outputs = module(inputs)
# import pdb
# pdb.set_trace()
qat_outputs = qat_module(inputs)
np.testing.assert_allclose(
normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册