Commit 60702667 authored by Megvii Engine Team

refactor(mge/grad_manager): refactor gradmanager, add allreduce callback

GitOrigin-RevId: 086e2871e8141bc2d6c4067b7c42eff85330ebca
Parent 3f2eac2f
from contextlib import contextmanager

from ..core.autodiff.grad import Grad
from ..tensor import tensor


class GradManager:
    def __init__(self):
        # (params, callback) pairs registered before recording starts.
        self._call_back_pair = []
        self._recording = False
        self._grad = None

    def register(self, params, callback=None):
        # Track `params` for gradient computation; `callback`, if given,
        # receives (param, grad) and returns the gradient to accumulate.
        self._call_back_pair.append([params, callback])

    def backward(self, ys, dys=None):
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        assert self._grad is not None
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            # Default to d(y)/d(y) = 1, broadcast to each output's shape.
            dys = [tensor(1).broadcast(y.shape) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
        finally:
            # The tape is consumed by backward; drop it so it cannot be reused.
            self._grad = None

    def record(self):
        @contextmanager
        def recorder():
            grad = Grad()
            if self._recording:
                raise RuntimeError("already recording!")
            try:
                self._recording = True
                self._grad = grad
                # Attach every registered (params, callback) pair to the tape.
                for params, callbacks in self._call_back_pair:
                    grad.wrt(*params, callback=callbacks)
                with grad:
                    yield
            finally:
                self._recording = False
                self._grad = None

        return recorder()
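For orientation, a minimal usage sketch of the refactored manager. The `model`, `data`, and `optimizer` names and the import path are assumptions for illustration, not part of this commit:

```python
# Minimal sketch, assuming GradManager is importable from this module and
# that `model`, `data`, and `optimizer` are an ordinary MegEngine module,
# input batch, and optimizer.
from megengine.autodiff.grad_manager import GradManager  # assumed path

gm = GradManager()
gm.register(model.parameters())  # no callback: plain gradient accumulation

with gm.record():                # tape the forward pass
    loss = model(data).sum()
    gm.backward(loss)            # dys defaults to ones broadcast to loss.shape

optimizer.step()                 # consumes the .grad attributes set above
```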
@@ -260,9 +260,13 @@ class Grad:
                 cache[v] = g
             if last_written_to[v] == (seqno, i):
                 if v.callback:
-                    v.callback(
+                    grad = v.callback(
                         v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
                     )
+                    if getattr(v.owner(), "grad", None) is None:
+                        v.owner().grad = grad
+                    else:
+                        v.owner().grad += grad
                 if v.opnode is None:
                     # won't read by backward, mark consumed
                     cache[v] = None
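The behavioral change in this hunk: a callback registered via `wrt` now returns the (possibly transformed) gradient, and `Grad` itself accumulates that return value onto the owning tensor's `grad` attribute. A hypothetical callback written against the new contract:

```python
# Hypothetical callback under the new contract: it receives (param, grad)
# and returns the gradient; Grad then performs param.grad = ret on the
# first write and param.grad += ret on subsequent writes.
def scaled_grad_cb(param, grad):
    return grad * 0.5  # e.g. halve every gradient before accumulation
```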
@@ -19,7 +19,7 @@ from .group import (
     is_distributed,
     new_group,
 )
-from .helper import synchronized
+from .helper import bcast_params_, make_allreduce_cb, synchronized
 from .launcher import launcher
 from .server import Client, Server
 from .util import get_free_ports
@@ -12,7 +12,8 @@ from typing import Callable
 from megengine.device import get_device_count

-from .group import group_barrier, is_distributed
+from .functional import all_reduce_sum, broadcast
+from .group import WORLD, group_barrier, is_distributed


 def synchronized(func: Callable):
@@ -42,3 +43,23 @@ def get_device_count_by_fork(device_type: str):
     p.start()
     p.join()
     return q.get()
+
+
+def bcast_params_(params, group):
+    for p in params:
+        p._reset(broadcast(p, group))
+
+
+class AllreduceCallback:
+    def __init__(self, reduce_method, group=WORLD):
+        self._reduce_method = reduce_method
+        self._group = group
+
+    def __call__(self, param, grad):
+        ret = all_reduce_sum(grad, self._group)
+        if self._reduce_method == "MEAN":
+            ret = ret / self._group.size
+        return ret
+
+
+make_allreduce_cb = AllreduceCallback
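Taken together, the new helpers cover data-parallel training: broadcast parameters once so all ranks start identical, then register an allreduce callback so each gradient is averaged across the group before the optimizer step. A sketch; the trainer names (`model`, `batch`, `optimizer`) and the `dist` namespacing of the new exports are assumptions:

```python
# Data-parallel sketch; we assume bcast_params_, make_allreduce_cb, and WORLD
# are reachable via megengine.distributed, as the __init__ change suggests.
import megengine.distributed as dist

dist.bcast_params_(model.parameters(), dist.WORLD)  # sync initial weights

gm = GradManager()
gm.register(model.parameters(), callback=dist.make_allreduce_cb("MEAN"))

with gm.record():
    loss = model(batch).sum()
    gm.backward(loss)  # each grad is allreduce-summed, then divided by size

optimizer.step()
```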