提交 b3278229 编写于 作者: M Megvii Engine Team

feat(mge/grad_manager): add `clear_grad` method for GradManager

GitOrigin-RevId: aa9540e09018697110b672f3772473b68305751c
上级 2627e1f7
...@@ -112,6 +112,14 @@ class GradManager: ...@@ -112,6 +112,14 @@ class GradManager:
else: else:
logger.warning("params with index {} is not attached.".format(idx)) logger.warning("params with index {} is not attached.".format(idx))
def clear_grad(self):
    r"""Reset the ``grad`` attribute of every attached parameter to ``None``.

    For advanced usage: this is convenient when more than one Optimizer
    manages the same set of parameters.
    """
    registered = self._param_dict.values()
    for parameter in registered:
        parameter.grad = None
def _register_after_backward_callback(self, callback): def _register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback) self._after_backward_callback.append(callback)
return self return self
......
...@@ -91,7 +91,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -91,7 +91,7 @@ class Optimizer(metaclass=ABCMeta):
if not isinstance(param, Parameter): if not isinstance(param, Parameter):
raise TypeError( raise TypeError(
"optimizer can only optimize Parameters, but one of the params is " "optimizer can only optimize Parameters, but one of the params is "
+ type(param) + str(type(param))
) )
for name, default in self._defaults.items(): for name, default in self._defaults.items():
...@@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta): ...@@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta):
def clear_grad(self): def clear_grad(self):
r"""Set the grad attribute to None for all parameters. r"""Set the grad attribute to None for all parameters.
""" """
for param_group in self.param_groups: for param_group in self.param_groups:
for param in param_group["params"]: for param in param_group["params"]:
......
...@@ -29,8 +29,7 @@ def test_basic(): ...@@ -29,8 +29,7 @@ def test_basic():
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]]) np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
np.testing.assert_equal(b.grad.numpy(), [1]) np.testing.assert_equal(b.grad.numpy(), [1])
w.grad = None gm.clear_grad()
b.grad = None
with gm: with gm:
p = F.matmul(x, w) p = F.matmul(x, w)
y = p + b y = p + b
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册