提交 b3278229 编写于 作者: M Megvii Engine Team

feat(mge/grad_manager): add `clear_grad` method for GradManager

GitOrigin-RevId: aa9540e09018697110b672f3772473b68305751c
上级 2627e1f7
......@@ -112,6 +112,14 @@ class GradManager:
else:
logger.warning("params with index {} is not attached.".format(idx))
def clear_grad(self):
r"""
For advanced usage: set the grad attribute to None for registered parameters.
It could be more convenient when there is more than one Optimizer.
"""
for param in self._param_dict.values():
param.grad = None
def _register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback)
return self
......
......@@ -91,7 +91,7 @@ class Optimizer(metaclass=ABCMeta):
if not isinstance(param, Parameter):
raise TypeError(
"optimizer can only optimize Parameters, but one of the params is "
+ type(param)
+ str(type(param))
)
for name, default in self._defaults.items():
......@@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta):
def clear_grad(self):
r"""Set the grad attribute to None for all parameters.
"""
for param_group in self.param_groups:
for param in param_group["params"]:
......
......@@ -29,8 +29,7 @@ def test_basic():
np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
np.testing.assert_equal(b.grad.numpy(), [1])
w.grad = None
b.grad = None
gm.clear_grad()
with gm:
p = F.matmul(x, w)
y = p + b
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册