Commit b5016b9d authored by Megvii Engine Team

feat(mge/parampack): add parampack in allreduce callback

GitOrigin-RevId: 73d53eeba16a792a427f57783f382267f3d33d43
Parent 5ae89c79
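This commit defers the per-parameter allreduce: instead of reducing each gradient the moment it arrives, the new AllreduceCallback buckets gradients by dtype and fires one fused allreduce once a bucket passes a ~10 MiB threshold. A rough standalone sketch of that bucketing policy, with numpy stand-ins; on_grad and flush are hypothetical names, not part of this diff:

import numpy as np
from collections import defaultdict

PACK_THRESHOLD = 10 * 1024 * 1024  # bytes; mirrors _param_pack_thd below

buckets = defaultdict(list)   # dtype -> pending (name, grad) pairs
sizes = defaultdict(int)      # dtype -> pending bytes

def flush(dtype):
    # the real code concatenates the bucket, runs one all_reduce_sum,
    # then splits the result back (see pack_allreduce_split below)
    print(f"fused allreduce of {len(buckets[dtype])} {dtype} grads")
    buckets[dtype].clear()
    sizes[dtype] = 0

def on_grad(name, grad):
    dtype = str(grad.dtype)
    buckets[dtype].append((name, grad))
    sizes[dtype] += grad.size * grad.itemsize
    if sizes[dtype] > PACK_THRESHOLD:
        flush(dtype)

# e.g. three 4 MiB float32 gradients trip the 10 MiB threshold on the third
for step in range(3):
    on_grad(f"w{step}", np.zeros(1024 * 1024, dtype=np.float32))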
from collections import defaultdict
from contextlib import contextmanager
from ..core.autodiff.grad import Grad
from ..distributed.util import Future
from ..tensor import tensor
backwarding_grad_manager = None
def get_backwarding_grad_manager():
return backwarding_grad_manager
class GradManager:
def __init__(self):
self._call_back_pair = []
self._call_back_dict = defaultdict(list)
self._param_dict = dict()
self._recording = False
self._grad = None
self._after_backward_callback = []
self._gradients = dict()
def register(self, params, callbacks=[]):
for p in params:
self._param_dict[id(p)] = p
for cb in callbacks:
self._call_back_dict[id(p)].append(cb)
def register(self, params, callbacks=None):
self._call_back_pair.append([list(params), callbacks or []])
def register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback)
def backward(self, ys, dys=None):
global backwarding_grad_manager
cache = backwarding_grad_manager
backwarding_grad_manager = self
if not self._recording:
raise RuntimeError(
"no computation history. "
@@ -29,8 +49,20 @@ class GradManager:
dys = [dys]
try:
self._grad(ys, dys)
for callback in self._after_backward_callback:
callback()
for p, grad in self._gradients.items():
if isinstance(grad, Future):
grad = grad.get()
param = self._param_dict[p]
if getattr(param, "grad", None) is None:
param.grad = grad
else:
param.grad += grad
finally:
self._grad = None
self._gradients = dict()
backwarding_grad_manager = cache
def record(self):
@contextmanager
@@ -41,20 +73,24 @@ class GradManager:
try:
self._recording = True
self._grad = grad
for params, callbacks in self._call_back_pair:
for p in params:
for param_id in self._param_dict.keys():
param_wrapper = self._param_dict[param_id]
callbacks = self._call_back_dict[param_id]
def callback(param, grad, callbacks=callbacks, p=p):
ret = grad
for cb in callbacks:
ret = cb(param, ret)
p.grad = ret
def callback(
param, grad, callbacks=callbacks, p=param_wrapper, gm=self
):
ret = grad
for cb in callbacks:
ret = cb(param, ret)
gm._gradients[id(p)] = ret
grad.wrt(p, callback=callback)
grad.wrt(param_wrapper, callback=callback)
with grad:
yield
finally:
self._recording = False
self._grad = None
self._gradients = dict()
return recorder()
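For orientation, a minimal usage sketch of the manager above. The megengine import names, the Parameter construction, and the eager tensor arithmetic are assumptions for illustration, not part of this diff:

import megengine as mge  # assumed top-level API
from megengine.autodiff.grad_manager import GradManager

w = mge.Parameter([1.0, 2.0])     # hypothetical parameter
x = mge.tensor([3.0, 4.0])

gm = GradManager()
gm.register([w])                  # no per-param callbacks

with gm.record():                 # record() yields a @contextmanager
    loss = (w * x).sum()
    gm.backward(loss)             # resolves futures, then fills w.grad

print(w.grad)                     # d(loss)/dw, i.e. x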
@@ -8,12 +8,33 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable
from megengine.device import get_device_count
import numpy as np
from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count
from ..functional.param_pack import get_offsets, pack_allreduce_split
from ..functional.utils import copy
from .functional import all_reduce_sum, broadcast
from .group import WORLD, group_barrier, is_distributed
from .util import Future
class FakeTensor(Future):
    """Placeholder handed back while the real gradient is still being
    packed and allreduced; any tensor-like access before the value is
    ready fails loudly."""

    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")
def synchronized(func: Callable):
@@ -52,14 +73,58 @@ def bcast_params_(params, group):
class AllreduceCallback:
def __init__(self, reduce_method, group=WORLD):
reduce_method = reduce_method.lower()
assert reduce_method in ["sum", "mean"]
self._reduce_method = reduce_method
self._group = group
self._gm_set = set()
self._param_pack_thd = 10 * 1024 * 1024
self._reset()
def _reset(self):
self._params = []
self._gradients_dict = dict()
self._futures_dict = dict()
self._packing_list = defaultdict(list)
self._packing_size = defaultdict(int)
def _pack(self, dtype):
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
shapes = [p.shape for p in self._packing_list[dtype]]
reduced_grads = pack_allreduce_split(
grad_list, shapes, self._group, self._reduce_method
)
for param, grad in zip(self._packing_list[dtype], reduced_grads):
self._gradients_dict[param] = grad
self._packing_list[dtype] = []
self._packing_size[dtype] = 0
def __call__(self, param, grad):
ret = all_reduce_sum(grad, self._group)
if self._reduce_method == "mean":
ret = ret / self._group.size
return ret
gm = get_backwarding_grad_manager()
assert isinstance(gm, GradManager)
if gm not in self._gm_set:
gm.register_after_backward_callback(self._flush)
self._gm_set.add(gm)
self._params.append(param)
self._futures_dict[param] = FakeTensor(ack=False)
self._gradients_dict[param] = grad
self._packing_list[param.dtype].append(param)
self._packing_size[param.dtype] += (
int(np.prod(list(param.shape))) * np.dtype(param.dtype).itemsize
)
if self._packing_size[param.dtype] > self._param_pack_thd:
self._pack(param.dtype)
return self._futures_dict[param]
def _flush(self):
for dtype in self._packing_list.keys():
self._pack(dtype)
for param in self._params:
grad = self._gradients_dict[param]
grad = copy(grad, get_default_device())
self._futures_dict[param].set(grad)
self._reset()
make_allreduce_cb = AllreduceCallback
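Putting the pieces together: registering the callback on a GradManager routes every gradient through the pack-and-defer path above. A hedged sketch that assumes an already-initialized distributed worker; the Parameter setup and tensor arithmetic are illustrative assumptions:

import megengine as mge                       # assumed top-level API
from megengine.autodiff.grad_manager import GradManager
from megengine.distributed.helper import make_allreduce_cb

w = mge.Parameter([1.0, 2.0])                 # hypothetical parameter
gm = GradManager()
# every gradient of w now comes back as a FakeTensor and is fused into a
# per-dtype allreduce bucket, flushed by _flush after backward finishes
gm.register([w], callbacks=[make_allreduce_cb("mean")])

with gm.record():
    loss = (w * mge.tensor([3.0, 4.0])).sum()
    gm.backward(loss)                         # futures resolve here; w.grad is real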
@@ -16,25 +16,7 @@ from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer
from ..core._imperative_rt.utils import create_mm_server
from .util import get_free_ports
class Future:
def __init__(self, ack=True):
self.ready = threading.Event()
self.ack = threading.Event() if ack else None
def set(self, value):
self.value = value
self.ready.set()
if self.ack:
self.ack.wait()
def get(self):
self.ready.wait()
if self.ack:
self.ack.set()
return self.value
from .util import Future, get_free_ports
class Methods:
    ...
@@ -8,9 +8,28 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import socket
import threading
from typing import List
class Future:
def __init__(self, ack=True):
self.ready = threading.Event()
self.ack = threading.Event() if ack else None
def set(self, value):
self.value = value
self.ready.set()
if self.ack:
self.ack.wait()
def get(self):
self.ready.wait()
if self.ack:
self.ack.set()
return self.value
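The optional ack event turns set() into a rendezvous: the producing thread blocks until some consumer has actually picked the value up. A small self-contained demonstration using the Future class above:

import threading

fut = Future(ack=True)

def producer():
    fut.set(42)          # blocks on fut.ack until a get() acknowledges
    print("consumer has the value")

t = threading.Thread(target=producer)
t.start()
print(fut.get())         # waits for ready, sets ack, returns 42
t.join()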
def get_free_ports(num: int) -> List[int]:
"""Get one or more free ports.
"""
    ...
@@ -7,7 +7,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from . import distributed
from .elemwise import *
from .graph import add_update
from .loss import (
@@ -27,6 +26,8 @@ from .quantized import conv_bias_activation
from .tensor import *
from .utils import accuracy, copy, zero_grad
from . import distributed # isort:skip
# delete namespace
# pylint: disable=undefined-variable
# del elemwise, graph, loss, math, nn, tensor # type: ignore[name-defined]
@@ -8,9 +8,9 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from ..functional import param_pack_concat, param_pack_split
from ..functional.distributed import all_reduce_sum
from ..tensor import Tensor
from .tensor import param_pack_concat, param_pack_split
def get_offsets(shapes):
@@ -23,57 +23,12 @@ def get_offsets(shapes):
return offsets
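The body of get_offsets is collapsed in this diff. Judging from how param_pack_split consumes offsets_val, it plausibly emits cumulative begin/end element offsets for each tensor in the pack; under that assumed reading, shapes [(2, 3), (4,)] would yield [0, 6, 6, 10]. A stand-in with that behavior, an illustration rather than the committed code:

import numpy as np

def get_offsets_sketch(shapes):
    # assumed behavior: begin/end element offsets of each tensor in the pack
    offsets, offset = [], 0
    for shape in shapes:
        offsets.append(offset)
        offset += int(np.prod(shape))
        offsets.append(offset)
    return offsets

print(get_offsets_sketch([(2, 3), (4,)]))  # [0, 6, 6, 10]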
def get_pack_list(param_group, param_pack_thd):
pack_list = dict()
shape_list = dict()
pack_sum = dict()
pack_ret, shape_ret = [], []
ignore_first = 8
ignore_last = 0
orders_len = len(param_group["orders"])
for i, idx in enumerate(param_group["orders"]):
param = param_group["params"][idx]
dtype = str(np.dtype(param.dtype))
dtype_size = np.dtype(param.dtype).itemsize
shape = param.shape
if ignore_first > 0:
ignore_first -= 1
pack_ret.append([idx])
shape_ret.append([shape])
continue
if dtype in pack_list.keys():
pack_list[dtype].append(idx)
shape_list[dtype].append(shape)
pack_sum[dtype] += int(np.prod(shape))
else:
pack_list[dtype] = [idx]
shape_list[dtype] = [shape]
pack_sum[dtype] = int(np.prod(shape))
if (
pack_sum[dtype] * dtype_size > param_pack_thd
or i + ignore_last > orders_len
):
pack_ret.append(pack_list[dtype])
shape_ret.append(shape_list[dtype])
pack_list[dtype] = []
shape_list[dtype] = []
pack_sum[dtype] = 0
for key in sorted(pack_list.keys()):
if len(pack_list[key]) > 0:
pack_ret.append(pack_list[key])
shape_ret.append(shape_list[key])
return pack_ret, shape_ret
def pack_allreduce_split(group, pack, shapes, reduce_method):
dist_group = group["dist_group"]
grads = [group["grads"][idx] for idx in pack]
def pack_allreduce_split(pack_list, shapes, group, reduce_method):
offsets_val = get_offsets(shapes)
offsets = Tensor(offsets_val)
packed_grads = param_pack_concat(grads, offsets, offsets_val)
packed_grads = all_reduce_sum(packed_grads, dist_group, dist_group.comp_node)
packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
if reduce_method == "mean":
packed_grads /= dist_group.size
grads = param_pack_split(packed_grads, offsets_val, shapes)
for i, grad in enumerate(grads):
group["grads"][pack[i]] = grad
return grads
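The dataflow of pack_allreduce_split, replayed single-process with numpy stand-ins for param_pack_concat, all_reduce_sum, and param_pack_split:

import numpy as np

grads = [np.ones((2, 3)), np.full((4,), 2.0)]
shapes = [g.shape for g in grads]

packed = np.concatenate([g.ravel() for g in grads])       # param_pack_concat
reduced = packed * 2                                       # all_reduce_sum over 2 workers
counts = np.cumsum([int(np.prod(s)) for s in shapes])[:-1]
splits = np.split(reduced, counts)                         # param_pack_split
grads = [s.reshape(shape) for s, shape in zip(splits, shapes)]
print([g.tolist() for g in grads])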