From c7acba41fc3ca5a22d2a0b11b5d4b05e6d66178f Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 9 Sep 2020 13:03:45 +0800
Subject: [PATCH] refactor(mge/optimizer): refine gradmanager api, record = __enter__

GitOrigin-RevId: 5376177237c7fab5ffc9f7fe61d13cb737f7866e
---
 .../python/megengine/autodiff/grad_manager.py | 65 ++++++++++---------
 .../python/megengine/distributed/helper.py    |  7 +-
 .../python/megengine/optimizer/optimizer.py   |  4 +-
 3 files changed, 42 insertions(+), 34 deletions(-)

diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index fefd2a682..f1790c3f5 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -1,5 +1,6 @@
 from collections import defaultdict
 from contextlib import contextmanager
+from typing import Callable
 
 from ..core.autodiff.grad import Grad
 from ..tensor import tensor
@@ -21,7 +22,11 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()
 
-    def register(self, params, callbacks=[]):
+    def register(self, params, callbacks=None):
+        if callbacks is None:
+            callbacks = []
+        if isinstance(callbacks, Callable):
+            callbacks = [callbacks]
         for p in params:
             self._param_dict[id(p)] = p
             for cb in callbacks:
@@ -62,37 +67,37 @@
                 else:
                     param.grad += grad
         finally:
-            self._grad = None
-            self._gradients = dict()
+            self._stop_record()
             backwarding_grad_manager = cache
 
-    def record(self):
-        @contextmanager
-        def recorder():
-            grad = Grad()
-            if self._recording:
-                raise RuntimeError("already recording!")
-            try:
-                self._recording = True
-                self._grad = grad
-                for param_id in self._param_dict.keys():
-                    param_wrapper = self._param_dict[param_id]
-                    callbacks = self._call_back_dict[param_id]
+    def __enter__(self):
+        if self._recording:
+            return self
+        grad = Grad()
+        self._recording = True
+        self._grad = grad
+        for param_id in self._param_dict.keys():
+            param_wrapper = self._param_dict[param_id]
+            callbacks = self._call_back_dict[param_id]
 
-                    def callback(
-                        param, grad, callbacks=callbacks, p=param_wrapper, gm=self
-                    ):
-                        ret = grad
-                        for cb in callbacks:
-                            ret = cb(param, ret)
-                        gm._gradients[id(p)] = ret
+            def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
+                ret = grad
+                for cb in callbacks:
+                    ret = cb(param, ret)
+                gm._gradients[id(p)] = ret
 
-                    grad.wrt(param_wrapper, callback=callback)
-                with grad:
-                    yield
-            finally:
-                self._recording = False
-                self._grad = None
-                self._gradients = dict()
+            grad.wrt(param_wrapper, callback=callback)
+        grad.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stop_record()
+
+    record = __enter__
 
-        return recorder()
+    def _stop_record(self):
+        if self._grad is not None:
+            self._grad.__exit__(None, None, None)
+        self._recording = False
+        self._grad = None
+        self._gradients = dict()
diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py
index aebc9b087..978c2854b 100644
--- a/imperative/python/megengine/distributed/helper.py
+++ b/imperative/python/megengine/distributed/helper.py
@@ -10,6 +10,7 @@ import functools
 import multiprocessing as mp
 from collections import defaultdict
 from typing import Callable
+from weakref import WeakSet
 
 import numpy as np
 
@@ -23,7 +24,7 @@ from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed
 
 
-class FakeTensor(Future):
+class TensorFuture(Future):
     def device(self):
         raise "Sorry, this tensor is not ready"
 
@@ -77,7 +78,7 @@ class AllreduceCallback:
         assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
-        self._marked_gm = set()
+        self._marked_gm = WeakSet()
        self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
 
@@ -107,7 +108,7 @@ class AllreduceCallback:
             gm._register_after_backward_callback(self._flush)
             self._marked_gm.add(gm)
         self._params.append(param)
-        self._futures_dict[param] = FakeTensor(ack=False)
+        self._futures_dict[param] = TensorFuture(ack=False)
         self._gradients_dict[param] = grad
         self._grad_origin_device[param] = str(grad.device)
 
diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py
index e4205bde9..2063a5851 100644
--- a/imperative/python/megengine/optimizer/optimizer.py
+++ b/imperative/python/megengine/optimizer/optimizer.py
@@ -140,7 +140,7 @@ class Optimizer(metaclass=ABCMeta):
                 params.append(param)
         return params
 
-    def step(self):
+    def step(self, clear_grad=False):
         r"""Performs a single optimization step.
 
         """
@@ -152,6 +152,8 @@
                     "Please use a list instead."
                 )
             self._updates(group)
+        if clear_grad:
+            self.clear_grad()
 
     def clear_grad(self):
         r"""Clear the grad buffer.
--
GitLab