From b327822994f72b2df0ed0f1e41772bc6bf53556a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 29 Oct 2020 11:37:57 +0800 Subject: [PATCH] feat(mge/grad_manager): add `clear_grad` method for GradManager GitOrigin-RevId: aa9540e09018697110b672f3772473b68305751c --- imperative/python/megengine/autodiff/grad_manager.py | 8 ++++++++ imperative/python/megengine/optimizer/optimizer.py | 3 +-- imperative/python/test/unit/autodiff/test_grad_manger.py | 3 +-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index a64b4fd1a..54575a705 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -112,6 +112,14 @@ class GradManager: else: logger.warning("params with index {} is not attached.".format(idx)) + def clear_grad(self): + r""" + For advanced usage: set the grad attribute to None for registered parameters. + It could be more convenient when there is more than one Optimizer. + """ + for param in self._param_dict.values(): + param.grad = None + def _register_after_backward_callback(self, callback): self._after_backward_callback.append(callback) return self diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py index cf869cdff..fc9ebf0b3 100644 --- a/imperative/python/megengine/optimizer/optimizer.py +++ b/imperative/python/megengine/optimizer/optimizer.py @@ -91,7 +91,7 @@ class Optimizer(metaclass=ABCMeta): if not isinstance(param, Parameter): raise TypeError( "optimizer can only optimize Parameters, but one of the params is " - + type(param) + + str(type(param)) ) for name, default in self._defaults.items(): @@ -159,7 +159,6 @@ class Optimizer(metaclass=ABCMeta): def clear_grad(self): r"""Set the grad attribute to None for all parameters. - """ for param_group in self.param_groups: for param in param_group["params"]: diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py index 372b3816c..f54bd02a1 100644 --- a/imperative/python/test/unit/autodiff/test_grad_manger.py +++ b/imperative/python/test/unit/autodiff/test_grad_manger.py @@ -29,8 +29,7 @@ def test_basic(): np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]]) np.testing.assert_equal(b.grad.numpy(), [1]) - w.grad = None - b.grad = None + gm.clear_grad() with gm: p = F.matmul(x, w) y = p + b -- GitLab