From b5ec9dfef626f9608b14876b442d83e75ff32db2 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 15 Apr 2021 19:28:02 +0800
Subject: [PATCH] fix(mge/distributed): fix gather scatter reduce broadcast
 autodiff

GitOrigin-RevId: 1c2250a0795276b696c29d82b68c49eae4653078
---
 .../megengine/distributed/functional.py      | 262 +++++++++++++-----
 .../python/megengine/distributed/launcher.py |   3 +
 .../python/megengine/distributed/server.py   |  11 +
 3 files changed, 214 insertions(+), 62 deletions(-)

diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index f5c58db99..453d3ba55 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -8,9 +8,11 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Optional, Tuple
 
+import numpy as np
+
 from ..core._imperative_rt.core2 import apply
-from ..core.autodiff.grad import _grad_manager_dict
-from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
+from ..core.autodiff.grad import Function, _grad_manager_dict
+from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.utils import isscalar, setscalar
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -65,6 +67,77 @@ def collective_comm(inp, mode, group, device):
     return result
 
 
+def _save_output_for_autodiff(inp, out):
+    for g in _grad_manager_dict.values():
+        if g._is_attached_to(inp):
+            g._refkeeper.append(out)
+
+
+def _bcast_has_grad(group, grad):
+    if group.rank == 0:
+        has_grad = grad is not None
+        get_client().bcast_val(has_grad, group.key, group.size)
+    else:
+        has_grad = get_client().bcast_val(None, group.key, group.size)
+    return has_grad
+
+
+def _bcast_shape_dtype(group, inp):
+    if group.rank == 0:
+        # FIXME: in some cases the shape is not available (e.g. the output of cond_take)
+        shape = inp._tuple_shape
+        dtype = np.dtype(inp.dtype).name
+        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
+    else:
+        val = get_client().bcast_val(None, group.key, group.size)
+        shape = val["shape"]
+        dtype = val["dtype"]
+
+    return shape, dtype
+
+
+def _bcast_tracer_state(group, inp):
+    if group.rank == 0:
+        tracer_keys = []
+        for n, g in _grad_manager_dict.items():
+            if g._is_attached_to(inp):
+                tracer_keys.append(n)
+        get_client().bcast_val(tracer_keys, group.key, group.size)
+    else:
+        tracer_keys = get_client().bcast_val(None, group.key, group.size)
+        for n in tracer_keys:
+            g = _grad_manager_dict.get(n)
+            if g is not None:
+                g.wrt(inp)
+                g._refkeeper.append(inp)
+
+
+def _dummy_input(shape, dtype, device=""):
+    if device == "":
+        device = get_default_device()
+    inp = Tensor(0, dtype=dtype, device=device)
+    if len(shape) > 0:
+        inp = inp._broadcast(shape)
+    return inp
+
+
+class _ReduceSum(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        has_grad = _bcast_has_grad(self.group, grad)
+        if has_grad:
+            return broadcast(grad, self.group, self.in_device)
+
+
 def reduce_sum(
     inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
 ) -> Tensor:
@@ -75,8 +148,30 @@ def reduce_sum(
     :param group: communication group.
     :param device: execution device.
""" - mode = CollectiveComm.Mode.REDUCE_SUM - return collective_comm(inp, mode, group, device) + op = _ReduceSum(group, device) + (out,) = apply(op, inp) + + if group.rank == 0: + return out + else: + _save_output_for_autodiff(inp, out) + + +class _Broadcast(Function): + def __init__(self, group=WORLD, device=""): + self.group = group + self.out_device = device + + def forward(self, data): + self.in_device = str(data.device) + return collective_comm( + data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device + ) + + def backward(self, grad): + # TODO backward with a part of grad + if grad is not None: + return reduce_sum(grad, self.group, self.in_device) def broadcast( @@ -89,8 +184,16 @@ def broadcast( :param group: communication group. :param device: execution device. """ - mode = CollectiveComm.Mode.BROADCAST - return collective_comm(inp, mode, group, device) + shape, dtype = _bcast_shape_dtype(group, inp) + if group.rank != 0: + # dummy input to infer shape + inp = _dummy_input(shape, dtype, device) + + _bcast_tracer_state(group, inp) + + op = _Broadcast(group, device) + (out,) = apply(op, inp) + return out def all_gather( @@ -163,6 +266,23 @@ def all_reduce_min( return collective_comm(inp, mode, group, device) +class _Gather(Function): + def __init__(self, group=WORLD, device=""): + self.group = group + self.out_device = device + + def forward(self, data): + self.in_device = str(data.device) + return collective_comm( + data, CollectiveComm.Mode.GATHER, self.group, self.out_device + ) + + def backward(self, grad): + has_grad = _bcast_has_grad(self.group, grad) + if has_grad: + return scatter(grad, self.group, self.in_device) + + def gather( inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" ) -> Tensor: @@ -173,8 +293,31 @@ def gather( :param group: communication group. :param device: execution device. """ - mode = CollectiveComm.Mode.GATHER - return collective_comm(inp, mode, group, device) + + op = _Gather(group, device) + (out,) = apply(op, inp) + + if group.rank == 0: + return out + else: + _save_output_for_autodiff(inp, out) + + +class _Scatter(Function): + def __init__(self, group=WORLD, device=""): + self.group = group + self.out_device = device + + def forward(self, data): + self.in_device = str(data.device) + return collective_comm( + data, CollectiveComm.Mode.SCATTER, self.group, self.out_device + ) + + def backward(self, grad): + # TODO backward with a part of grad + if grad is not None: + return gather(grad, self.group, self.in_device) def scatter( @@ -187,8 +330,16 @@ def scatter( :param group: communication group. :param device: execution device. 
""" - mode = CollectiveComm.Mode.SCATTER - return collective_comm(inp, mode, group, device) + shape, dtype = _bcast_shape_dtype(group, inp) + if group.rank != 0: + # dummy input to infer shape + inp = _dummy_input(shape, dtype, device) + + _bcast_tracer_state(group, inp) + + op = _Scatter(group, device) + (out,) = apply(op, inp) + return out def all_to_all( @@ -205,44 +356,46 @@ def all_to_all( return collective_comm(inp, mode, group, device) -class _RemoteSend(PyOpBase): +class _SendRecvGroup: + def __init__(self, rank_from, rank_to): + self.key = "{}->{}".format(rank_from, rank_to) + self.rank_from = rank_from + self.rank_to = rank_to + self.size = 2 + + @property + def rank(self): + if get_rank() == self.rank_from: + return 0 + else: + return 1 + + +class _RemoteSend(Function): def __init__(self, op: RemoteSend): self.op = op - def _default_rule(self, data): - return apply(self.op, data) - - def _grad_rule(self, data): - self.dtype = data.dtype - self.shape = data.shape - self.device = data.device - (self.dummy,) = self._default_rule(data) - return self.dummy, self.backward + def forward(self, data): + self.device = str(data.device) + (self.dummy,) = apply(self.op, data) + return self.dummy def backward(self, grad): assert grad is None - if get_client().check_is_grad(self.op.key): - return remote_recv( - self.op.rank_to, - self.shape, - self.dtype, - device=str(self.device), - inp=self.dummy, - ) + has_grad = get_client().bcast_val(None, self.op.key, 2) + if has_grad: + return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,) -class _RemoteRecv(PyOpBase): +class _RemoteRecv(Function): def __init__(self, op: RemoteRecv): self.op = op - def _default_rule(self, dummy): + def forward(self, dummy): return apply(self.op, dummy) - def _grad_rule(self, dummy): - return self._default_rule(dummy), self.backward - def backward(self, grad): - get_client().set_is_grad(self.op.key, grad is not None) + get_client().bcast_val(grad is not None, self.op.key, 2) if grad is not None: remote_send(grad, self.op.rank_from) @@ -254,53 +407,38 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor: :param inp: tensor to send. :param dest_rank: destination process rank. """ - key = "{}->{}".format(get_rank(), dest_rank) - grad_keys = {} - for n, g in _grad_manager_dict.items(): - if g._is_attached_to(inp): - grad_keys[n] = g - get_client().set_remote_tracer(key, grad_keys) + group = _SendRecvGroup(get_rank(), dest_rank) + _bcast_shape_dtype(group, inp) + + _bcast_tracer_state(group, inp) op = RemoteSend() - op.key = key + op.key = group.key op.addr, op.port = get_mm_server_addr() op.rank_to = dest_rank op.backend = get_backend() - (dummy,) = apply(_RemoteSend(op), inp) + (out,) = apply(_RemoteSend(op), inp) - for g in grad_keys.values(): - g._refkeeper.append(dummy) + _save_output_for_autodiff(inp, out) -def remote_recv( - src_rank: int, - shape: Tuple[int], - dtype: type, - device: Optional[str] = None, - inp=None, -) -> Tensor: +def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor: """ Receive a Tensor from a remote process. :param src_rank: source process rank. - :param shape: the shape of the tensor to receive. - :param dtype: the data type of the tensor to receive. :param device: the device to place the received tensor. 
    :param inp: dummy input to determine the received tensor type
     """
-    key = "{}->{}".format(src_rank, get_rank())
+    group = _SendRecvGroup(src_rank, get_rank())
+    shape, dtype = _bcast_shape_dtype(group, None)
 
     if device is None:
         device = get_default_device()
     # dummy input
     if inp is None:
-        inp = Tensor([0], device=device)
-    tracer_set = get_client().check_remote_tracer(key)
-    for n in tracer_set:
-        g = _grad_manager_dict.get(n)
-        if g is not None:
-            g.wrt(inp)
-            g._refkeeper.append(inp)
+        inp = Tensor(0, device=device)
+    _bcast_tracer_state(group, inp)
 
     _isscalar = False
     if len(shape) == 0:
@@ -308,7 +446,7 @@ def remote_recv(
         _isscalar = True
 
     op = RemoteRecv()
-    op.key = key
+    op.key = group.key
     op.cn = device
     op.shape = shape
     op.dtype = dtype
diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py
index 950feb0ac..1e2a3dff9 100644
--- a/imperative/python/megengine/distributed/launcher.py
+++ b/imperative/python/megengine/distributed/launcher.py
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import multiprocessing as mp
+import os
 import queue
 
 from ..core._imperative_rt.core2 import sync
@@ -43,6 +44,8 @@ def _run_wrapped(
         device=dev,
         device_type=device_type,
     )
+    # set NCCL_LAUNCH_MODE to PARALLEL to avoid deadlock
+    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    if is_multimachine:
         group_barrier()
     ret = func(*args, **kwargs)
diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py
index 4e83c19a7..27f23dfd1 100644
--- a/imperative/python/megengine/distributed/server.py
+++ b/imperative/python/megengine/distributed/server.py
@@ -253,6 +253,17 @@ class Client:
         """Get user defined key-value pairs across processes."""
         return self.proxy.user_get(key)
 
+    def bcast_val(self, val, key, size):
+        if val is not None:
+            self.user_set(key + "_sync", val)
+            self.group_barrier(key, size)
+            self.group_barrier(key, size)
+        else:
+            self.group_barrier(key, size)
+            val = self.user_get(key + "_sync")
+            self.group_barrier(key, size)
+        return val
+
 
 def main(port=0, verbose=True):
     mm_server_port = create_mm_server("0.0.0.0", 0)
-- 
GitLab
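
Note on the autodiff round trip in functional.py: each collective's backward is itself a collective. _Broadcast.backward reduce-sums the per-rank gradients onto rank 0, and _ReduceSum.backward broadcasts the gradient back out; _Gather and _Scatter pair up the same way. A minimal two-process sketch of the behavior this patch enables; the launcher decorator, GradManager usage, n_gpus value, and the concrete gradient arithmetic are illustrative assumptions, not part of the patch:

    import megengine as mge
    import megengine.distributed as dist
    from megengine.autodiff import GradManager

    @dist.launcher(n_gpus=2)
    def worker():
        gm = GradManager()
        x = mge.tensor(2.0)
        if dist.get_rank() == 0:
            gm.attach(x)  # only rank 0 owns the real leaf tensor
        with gm:
            # non-root ranks get a dummy input internally (see _dummy_input),
            # and _bcast_tracer_state attaches it to their GradManager
            y = dist.functional.broadcast(x)
            gm.backward(y * y)
        # d(y*y)/dy == 2*y == 4 on each rank; _Broadcast.backward reduce-sums
        # the two contributions onto rank 0, so x.grad there should be 8
        if dist.get_rank() == 0:
            print(x.grad)

    worker()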
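The user-visible consequence for point-to-point transfer is that remote_recv no longer takes shape and dtype arguments: _bcast_shape_dtype ships both from the sender through the key-value server. A sketch of the new pairing, again assuming a two-GPU launcher purely for illustration:

    import megengine as mge
    import megengine.distributed as dist

    @dist.launcher(n_gpus=2)
    def worker():
        if dist.get_rank() == 0:
            x = mge.tensor([[1.0, 2.0], [3.0, 4.0]])
            dist.functional.remote_send(x, dest_rank=1)
        else:
            # before this patch: remote_recv(0, shape=(2, 2), dtype="float32")
            y = dist.functional.remote_recv(src_rank=0)
            print(y.shape)  # (2, 2), brokered via Client.bcast_val

    worker()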
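Client.bcast_val builds a one-to-all rendezvous out of the existing user_set/user_get store plus two group barriers: the first barrier guarantees the value is published before any receiver reads it, and the second guarantees every receiver has read it before the same key can be reused (the "{src}->{dst}" key is reused on every backward pass). A standalone sketch of the same protocol, with threads standing in for ranks; the store dict and threading.Barrier model the server and are not MegEngine API:

    import threading

    store = {}
    barrier = threading.Barrier(3)  # plays the role of group_barrier(key, size=3)

    def bcast_val(val, key):
        if val is not None:              # broadcaster
            store[key + "_sync"] = val
            barrier.wait()               # value is now visible to everyone
            barrier.wait()               # all receivers are done reading
        else:                            # receiver
            barrier.wait()               # wait for the publish
            val = store[key + "_sync"]
            barrier.wait()               # release the key for reuse
        return val

    results = []
    threads = [
        threading.Thread(target=lambda v=v: results.append(bcast_val(v, "k")))
        for v in (42, None, None)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert results == [42, 42, 42]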