Commit b5ec9dfe authored by Megvii Engine Team

fix(mge/distributed): fix gather scatter reduce broadcast autodiff

GitOrigin-RevId: 1c2250a0795276b696c29d82b68c49eae4653078
Parent a49e202b
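For context, a minimal, hedged sketch of the behaviour this commit fixes: differentiating through a collective. It is not part of the commit; names such as `dist.launcher(n_gpus=2)` and the import path `megengine.distributed.functional` follow the public MegEngine API as I understand it, and the 2-GPU setup is an assumption.

```python
# Hypothetical usage sketch (not from this commit): gradients through broadcast.
# Rank 0 owns x; broadcast copies it to both ranks, and the new _Broadcast.backward
# reduce-sums the per-rank gradients back onto rank 0's x.grad.
import megengine as mge
import megengine.distributed as dist
from megengine.autodiff import GradManager
from megengine.distributed.functional import broadcast


@dist.launcher(n_gpus=2)  # assumed launcher signature; adjust to the local setup
def worker():
    x = mge.tensor([1.0, 2.0])          # only rank 0's value is actually used
    gm = GradManager().attach([x])
    with gm:
        y = broadcast(x)                 # every rank receives rank 0's x
        loss = (y * (dist.get_rank() + 1)).sum()
        gm.backward(loss)
    if dist.get_rank() == 0:
        print(x.grad)                    # expected: [3., 3.] (1 + 2 from the two ranks)


if __name__ == "__main__":
    worker()
```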
@@ -8,9 +8,11 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..tensor import Tensor
@@ -65,6 +67,77 @@ def collective_comm(inp, mode, group, device):
    return result

def _save_output_for_autodiff(inp, out):
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    if group.rank == 0:
        # FIXME in some cases, shape is not available (output of condtake)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]

    return shape, dtype

def _bcast_tracer_state(group, inp):
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=""):
    if device == "":
        device = get_default_device()

    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp

class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)

def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
@@ -75,8 +148,30 @@ def reduce_sum(
    :param group: communication group.
    :param device: execution device.
    """
    op = _ReduceSum(group, device)
    (out,) = apply(op, inp)

    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)

class _Broadcast(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)

def broadcast(
@@ -89,8 +184,16 @@ def broadcast(
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)

    _bcast_tracer_state(group, inp)
    op = _Broadcast(group, device)
    (out,) = apply(op, inp)
    return out

def all_gather(
@@ -163,6 +266,23 @@ def all_reduce_min(
    return collective_comm(inp, mode, group, device)

class _Gather(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)

def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
@@ -173,8 +293,31 @@ def gather(
    :param group: communication group.
    :param device: execution device.
    """
    op = _Gather(group, device)
    (out,) = apply(op, inp)

    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)

class _Scatter(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return gather(grad, self.group, self.in_device)

def scatter(
@@ -187,8 +330,16 @@ def scatter(
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)

    _bcast_tracer_state(group, inp)
    op = _Scatter(group, device)
    (out,) = apply(op, inp)
    return out

def all_to_all(
@@ -205,44 +356,46 @@ def all_to_all(
    return collective_comm(inp, mode, group, device)

class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1


class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)

    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)

@@ -254,53 +407,38 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)

    _bcast_tracer_state(group, inp)

    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = get_backend()
    (out,) = apply(_RemoteSend(op), inp)
    _save_output_for_autodiff(inp, out)

def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    """
    Receive a Tensor from a remote process.

    :param src_rank: source process rank.
    :param device: the device to place the received tensor.
    :param inp: dummy input to determine the received tensor's type.
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)

    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)

    _isscalar = False
    if len(shape) == 0:
@@ -308,7 +446,7 @@ def remote_recv(
        _isscalar = True

    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
...
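And a hedged sketch of the reworked point-to-point path: `remote_send` on rank 0, `remote_recv` on rank 1, with the gradient travelling back through `_RemoteRecv.backward` and `_RemoteSend.backward`. Calling `gm.backward()` with no arguments on the sending rank is the pattern these Functions appear designed for; the launcher setup and tensor values are illustrative assumptions.

```python
# Hypothetical sketch: autodiff across a send/recv pair between two ranks.
import megengine as mge
import megengine.distributed as dist
from megengine.autodiff import GradManager
from megengine.distributed.functional import remote_recv, remote_send


@dist.launcher(n_gpus=2)
def worker():
    rank = dist.get_rank()
    gm = GradManager()
    if rank == 0:
        x = mge.tensor([1.0, 2.0, 3.0])
        gm.attach([x])
        with gm:
            remote_send(x * 2, dest_rank=1)   # forward: ship the activation to rank 1
            gm.backward()                     # no local loss; grad arrives from rank 1
        print(x.grad)                         # expected: [2., 2., 2.]
    else:
        with gm:
            y = remote_recv(src_rank=0)       # shape/dtype negotiated via bcast_val
            gm.backward(y.sum())              # backward: send dL/dy back to rank 0


if __name__ == "__main__":
    worker()
```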
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
import os
import queue

from ..core._imperative_rt.core2 import sync
@@ -43,6 +44,8 @@ def _run_wrapped(
        device=dev,
        device_type=device_type,
    )
    # set NCCL_LAUNCH_MODE to avoid deadlock
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    if is_multimachine:
        group_barrier()
    ret = func(*args, **kwargs)
...
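The launcher now exports `NCCL_LAUNCH_MODE=PARALLEL` before running the worker. If processes are started by some other means, the same workaround can be applied by hand; this snippet is an illustration, not part of the commit.

```python
# Hedged sketch: apply the same NCCL workaround when not going through dist.launcher.
import os

# must be set before the first collective kernel is launched in this process
os.environ.setdefault("NCCL_LAUNCH_MODE", "PARALLEL")
```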
@@ -253,6 +253,17 @@ class Client:
        """Get user defined key-value pairs across processes."""
        return self.proxy.user_get(key)

    def bcast_val(self, val, key, size):
        if val is not None:
            # writer: publish the value, then wait twice so the key is not
            # reused before every reader has fetched it
            self.user_set(key + "_sync", val)
            self.group_barrier(key, size)
            self.group_barrier(key, size)
        else:
            # reader: wait until the writer has published, then fetch
            self.group_barrier(key, size)
            val = self.user_get(key + "_sync")
            self.group_barrier(key, size)
        return val

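`bcast_val` is the small rendezvous primitive the autodiff fixes above rely on (`_bcast_has_grad`, `_bcast_shape_dtype`, `_bcast_tracer_state`, and the send/recv backward passes). A hedged sketch of the calling convention follows; `share_meta` is a hypothetical helper, and `client` is assumed to be the per-process Client returned by `get_client()`.

```python
# Hypothetical usage sketch of Client.bcast_val as a one-to-many rendezvous.
def share_meta(client, rank, key, size, meta=None):
    if rank == 0:
        # writer: publish the value and wait until every reader has fetched it
        return client.bcast_val(meta, key, size)
    # readers: pass None and receive the writer's value
    return client.bcast_val(None, key, size)
```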
def main(port=0, verbose=True):
    mm_server_port = create_mm_server("0.0.0.0", 0)
...