Commit c7acba41 authored by: Megvii Engine Team

refactor(mge/optimizer): refine gradmanager api, record = __enter__

GitOrigin-RevId: 5376177237c7fab5ffc9f7fe61d13cb737f7866e
Parent 8c482b67
 from collections import defaultdict
 from contextlib import contextmanager
+from typing import Callable

 from ..core.autodiff.grad import Grad
 from ..tensor import tensor

@@ -21,7 +22,11 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()

-    def register(self, params, callbacks=[]):
+    def register(self, params, callbacks=None):
+        if callbacks is None:
+            callbacks = []
+        if isinstance(callbacks, Callable):
+            callbacks = [callbacks]
         for p in params:
             self._param_dict[id(p)] = p
             for cb in callbacks:
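With this hunk, `register` drops the mutable default argument and accepts either a single callable or a list of callbacks. A minimal usage sketch, where `gm` is an existing GradManager, `params` is an iterable of its parameters, and `scale_grad` is a made-up callback (none of these names appear in the diff):

    def scale_grad(param, grad):
        # Per-parameter callback: receives the parameter and its gradient and
        # returns the (possibly transformed) gradient to be accumulated.
        return grad * 0.5

    gm.register(params, scale_grad)    # a bare callable is now wrapped into [scale_grad]
    # equivalent to the list form that worked before:
    # gm.register(params, [scale_grad])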
@@ -62,37 +67,37 @@ class GradManager:
                 else:
                     param.grad += grad
         finally:
-            self._grad = None
-            self._gradients = dict()
+            self._stop_record()
         backwarding_grad_manager = cache

-    def record(self):
-        @contextmanager
-        def recorder():
-            grad = Grad()
-            if self._recording:
-                raise RuntimeError("already recording!")
-            try:
-                self._recording = True
-                self._grad = grad
-                for param_id in self._param_dict.keys():
-                    param_wrapper = self._param_dict[param_id]
-                    callbacks = self._call_back_dict[param_id]
-
-                    def callback(
-                        param, grad, callbacks=callbacks, p=param_wrapper, gm=self
-                    ):
-                        ret = grad
-                        for cb in callbacks:
-                            ret = cb(param, ret)
-                        gm._gradients[id(p)] = ret
-
-                    grad.wrt(param_wrapper, callback=callback)
-                with grad:
-                    yield
-            finally:
-                self._recording = False
-                self._grad = None
-                self._gradients = dict()
-
-        return recorder()
+    def __enter__(self):
+        if self._recording:
+            return self
+        grad = Grad()
+        self._recording = True
+        self._grad = grad
+        for param_id in self._param_dict.keys():
+            param_wrapper = self._param_dict[param_id]
+            callbacks = self._call_back_dict[param_id]
+
+            def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
+                ret = grad
+                for cb in callbacks:
+                    ret = cb(param, ret)
+                gm._gradients[id(p)] = ret
+
+            grad.wrt(param_wrapper, callback=callback)
+        grad.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stop_record()
+
+    record = __enter__
+
+    def _stop_record(self):
+        if self._grad is not None:
+            self._grad.__exit__(None, None, None)
+        self._recording = False
+        self._grad = None
+        self._gradients = dict()
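After this hunk, the GradManager instance is itself the context manager and `record` is just another name for `__enter__`, so recording can be started either with a `with` block or imperatively; `backward()` ends it through `_stop_record()` in its `finally` clause. A usage sketch, where `gm` is a registered GradManager and `model`/`data` stand in for a network and its input (only `record`, `__enter__`, and `backward` come from this diff):

    # Style 1: scoped recording via the context-manager protocol.
    with gm:
        loss = model(data)
        gm.backward(loss)   # accumulates into param.grad and stops recording

    # Style 2: imperative recording; record() is literally the same method as __enter__().
    gm.record()
    loss = model(data)
    gm.backward(loss)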
@@ -10,6 +10,7 @@
 import functools
 import multiprocessing as mp
 from collections import defaultdict
 from typing import Callable
+from weakref import WeakSet

 import numpy as np
@@ -23,7 +24,7 @@
 from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed


-class FakeTensor(Future):
+class TensorFuture(Future):
     def device(self):
         raise "Sorry, this tensor is not ready"
@@ -77,7 +78,7 @@ class AllreduceCallback:
         assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
-        self._marked_gm = set()
+        self._marked_gm = WeakSet()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
@@ -107,7 +108,7 @@ class AllreduceCallback:
             gm._register_after_backward_callback(self._flush)
             self._marked_gm.add(gm)
         self._params.append(param)
-        self._futures_dict[param] = FakeTensor(ack=False)
+        self._futures_dict[param] = TensorFuture(ack=False)
         self._gradients_dict[param] = grad
         self._grad_origin_device[param] = str(grad.device)
...
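Replacing `set()` with `WeakSet()` means the callback only holds weak references to the GradManager objects it has marked, so marking a manager no longer keeps it alive after user code drops it. A self-contained illustration of that behaviour (`GradManagerStub` is a stand-in class, not part of MegEngine):

    import gc
    from weakref import WeakSet

    class GradManagerStub:
        """Stand-in object used only to demonstrate WeakSet semantics."""

    marked = WeakSet()
    gm = GradManagerStub()
    marked.add(gm)
    assert len(marked) == 1

    del gm        # drop the last strong reference
    gc.collect()  # the WeakSet entry disappears once the object is collected
    assert len(marked) == 0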
@@ -140,7 +140,7 @@ class Optimizer(metaclass=ABCMeta):
             params.append(param)
         return params

-    def step(self):
+    def step(self, clear_grad=False):
         r"""Performs a single optimization step.
         """
@@ -152,6 +152,8 @@
                     "Please use a list instead."
                 )
             self._updates(group)
+        if clear_grad:
+            self.clear_grad()

     def clear_grad(self):
         r"""Clear the grad buffer.
...
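The new `clear_grad` flag lets a single call both apply the parameter update and empty the grad buffers. A short sketch of the intended call site (`opt` is an instance of any Optimizer subclass with gradients already populated; it is not constructed in this diff):

    opt.step(clear_grad=True)
    # shorthand for:
    # opt.step()
    # opt.clear_grad()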