Commit c7acba41 authored by Megvii Engine Team

refactor(mge/optimizer): refine gradmanager api, record = __enter__

GitOrigin-RevId: 5376177237c7fab5ffc9f7fe61d13cb737f7866e
Parent 8c482b67
 from collections import defaultdict
 from contextlib import contextmanager
+from typing import Callable

 from ..core.autodiff.grad import Grad
 from ..tensor import tensor
@@ -21,7 +22,11 @@ class GradManager:
         self._after_backward_callback = []
         self._gradients = dict()

-    def register(self, params, callbacks=[]):
+    def register(self, params, callbacks=None):
+        if callbacks is None:
+            callbacks = []
+        if isinstance(callbacks, Callable):
+            callbacks = [callbacks]
         for p in params:
             self._param_dict[id(p)] = p
             for cb in callbacks:
@@ -62,37 +67,37 @@ class GradManager:
                 else:
                     param.grad += grad
         finally:
-            self._grad = None
-            self._gradients = dict()
+            self._stop_record()
             backwarding_grad_manager = cache

-    def record(self):
-        @contextmanager
-        def recorder():
-            grad = Grad()
-            if self._recording:
-                raise RuntimeError("already recording!")
-            try:
-                self._recording = True
-                self._grad = grad
-                for param_id in self._param_dict.keys():
-                    param_wrapper = self._param_dict[param_id]
-                    callbacks = self._call_back_dict[param_id]
-
-                    def callback(
-                        param, grad, callbacks=callbacks, p=param_wrapper, gm=self
-                    ):
-                        ret = grad
-                        for cb in callbacks:
-                            ret = cb(param, ret)
-                        gm._gradients[id(p)] = ret
-
-                    grad.wrt(param_wrapper, callback=callback)
-                with grad:
-                    yield
-            finally:
-                self._recording = False
-                self._grad = None
-                self._gradients = dict()
-
-        return recorder()
+    def __enter__(self):
+        if self._recording:
+            return self
+        grad = Grad()
+        self._recording = True
+        self._grad = grad
+        for param_id in self._param_dict.keys():
+            param_wrapper = self._param_dict[param_id]
+            callbacks = self._call_back_dict[param_id]
+
+            def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
+                ret = grad
+                for cb in callbacks:
+                    ret = cb(param, ret)
+                gm._gradients[id(p)] = ret
+
+            grad.wrt(param_wrapper, callback=callback)
+        grad.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stop_record()
+
+    record = __enter__
+
+    def _stop_record(self):
+        if self._grad is not None:
+            self._grad.__exit__(None, None, None)
+        self._recording = False
+        self._grad = None
+        self._gradients = dict()
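
Note: with record aliased to __enter__, recording can be started either by calling gm.record() as before or by using the manager as a context manager, and both paths are torn down through _stop_record(). A minimal usage sketch under that assumption (net, data, label, loss_fn and the exact backward() call are illustrative placeholders, not part of this diff):

# hypothetical usage of the refactored GradManager API
gm = GradManager()
gm.register(net.parameters())      # callbacks now defaults to None instead of []

# style 1: context manager; __exit__ calls _stop_record()
with gm:
    loss = loss_fn(net(data), label)
    gm.backward(loss)              # backward() also calls _stop_record() in its finally block

# style 2: record = __enter__, so existing record()-based code keeps working
gm.record()
loss = loss_fn(net(data), label)
gm.backward(loss)

Calling _stop_record() twice is harmless because it first checks that self._grad is not None.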
@@ -10,6 +10,7 @@ import functools
 import multiprocessing as mp
 from collections import defaultdict
 from typing import Callable
+from weakref import WeakSet

 import numpy as np
@@ -23,7 +24,7 @@ from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed


-class FakeTensor(Future):
+class TensorFuture(Future):
     def device(self):
         raise "Sorry, this tensor is not ready"
@@ -77,7 +78,7 @@ class AllreduceCallback:
         assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
-        self._marked_gm = set()
+        self._marked_gm = WeakSet()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
@@ -107,7 +108,7 @@ class AllreduceCallback:
         gm._register_after_backward_callback(self._flush)
         self._marked_gm.add(gm)
         self._params.append(param)
-        self._futures_dict[param] = FakeTensor(ack=False)
+        self._futures_dict[param] = TensorFuture(ack=False)
         self._gradients_dict[param] = grad
         self._grad_origin_device[param] = str(grad.device)
......
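
Note: switching _marked_gm from set() to weakref.WeakSet means the callback only holds weak references to the GradManager instances it has marked, so marking no longer keeps a manager alive after user code drops it. A small standalone illustration of the WeakSet behaviour relied on here (plain standard library, unrelated to MegEngine internals):

import weakref

class Dummy:
    pass

marked = weakref.WeakSet()
gm = Dummy()
marked.add(gm)
print(len(marked))   # 1: tracked while a strong reference exists
del gm               # drop the last strong reference
print(len(marked))   # 0: the entry disappears instead of leaking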
@@ -140,7 +140,7 @@ class Optimizer(metaclass=ABCMeta):
                 params.append(param)
         return params

-    def step(self):
+    def step(self, clear_grad=False):
         r"""Performs a single optimization step.
         """
@@ -152,6 +152,8 @@ class Optimizer(metaclass=ABCMeta):
                     "Please use a list instead."
                 )
             self._updates(group)
+        if clear_grad:
+            self.clear_grad()

     def clear_grad(self):
         r"""Clear the grad buffer.
......
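
Note: the new clear_grad flag folds the usual post-step cleanup into step() itself; the two call patterns below are equivalent after this change (opt stands for any concrete Optimizer subclass, e.g. SGD):

# before: two explicit calls
opt.step()
opt.clear_grad()

# after: optional one-liner
opt.step(clear_grad=True)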