Commit b5ec9dfe authored by: Megvii Engine Team

fix(mge/distributed): fix gather scatter reduce broadcast autodiff

GitOrigin-RevId: 1c2250a0795276b696c29d82b68c49eae4653078
Parent a49e202b
......@@ -8,9 +8,11 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple
import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..tensor import Tensor
......@@ -65,6 +67,77 @@ def collective_comm(inp, mode, group, device):
return result
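# Keep the collective's output alive in every GradManager attached to the
# input, so ranks that discard the Python-level result still have the node
# in the autodiff graph when backward runs.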
def _save_output_for_autodiff(inp, out):
for g in _grad_manager_dict.values():
if g._is_attached_to(inp):
g._refkeeper.append(out)
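# Rank 0 decides whether a gradient actually flows back and shares that flag
# with the other ranks through the key-value server (see Client.bcast_val).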
def _bcast_has_grad(group, grad):
if group.rank == 0:
has_grad = grad is not None
get_client().bcast_val(has_grad, group.key, group.size)
else:
has_grad = get_client().bcast_val(None, group.key, group.size)
return has_grad
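# Rank 0 publishes the shape and dtype of its input so the other ranks can
# build a matching dummy tensor when they hold no real data.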
def _bcast_shape_dtype(group, inp):
if group.rank == 0:
# FIXME: in some cases the shape is not available (e.g. the output of condtake)
shape = inp._tuple_shape
dtype = np.dtype(inp.dtype).name
get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
else:
val = get_client().bcast_val(None, group.key, group.size)
shape = val["shape"]
dtype = val["dtype"]
return shape, dtype
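# Rank 0 lists the GradManagers attached to its input; the other ranks attach
# the same managers to their (dummy) input so every rank participates in
# backward consistently.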
def _bcast_tracer_state(group, inp):
if group.rank == 0:
tracer_keys = []
for n, g in _grad_manager_dict.items():
if g._is_attached_to(inp):
tracer_keys.append(n)
get_client().bcast_val(tracer_keys, group.key, group.size)
else:
tracer_keys = get_client().bcast_val(None, group.key, group.size)
for n in tracer_keys:
g = _grad_manager_dict.get(n)
if g is not None:
g.wrt(inp)
g._refkeeper.append(inp)
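# Build a placeholder tensor of the given shape/dtype on the requested device,
# used as the op input on ranks without real data.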
def _dummy_input(shape, dtype, device=""):
if device == "":
device = get_default_device()
inp = Tensor(0, dtype=dtype, device=device)
if len(shape) > 0:
inp = inp._broadcast(shape)
return inp
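# Forward: REDUCE_SUM collective into rank 0. Backward: broadcast the gradient
# from rank 0 back to every rank, but only if rank 0 reports that a gradient
# exists.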
class _ReduceSum(Function):
def __init__(self, group=WORLD, device=""):
self.group = group
self.out_device = device
def forward(self, data):
self.in_device = str(data.device)
return collective_comm(
data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
)
def backward(self, grad):
has_grad = _bcast_has_grad(self.group, grad)
if has_grad:
return broadcast(grad, self.group, self.in_device)
def reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
......@@ -75,8 +148,30 @@ def reduce_sum(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveComm.Mode.REDUCE_SUM
return collective_comm(inp, mode, group, device)
op = _ReduceSum(group, device)
(out,) = apply(op, inp)
if group.rank == 0:
return out
else:
_save_output_for_autodiff(inp, out)
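# Forward: BROADCAST collective from rank 0. Backward: reduce (sum) the
# incoming gradients back onto rank 0.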
class _Broadcast(Function):
def __init__(self, group=WORLD, device=""):
self.group = group
self.out_device = device
def forward(self, data):
self.in_device = str(data.device)
return collective_comm(
data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
)
def backward(self, grad):
# TODO: support backward with only part of the grad
if grad is not None:
return reduce_sum(grad, self.group, self.in_device)
def broadcast(
......@@ -89,8 +184,16 @@ def broadcast(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveComm.Mode.BROADCAST
return collective_comm(inp, mode, group, device)
shape, dtype = _bcast_shape_dtype(group, inp)
if group.rank != 0:
# dummy input to infer shape
inp = _dummy_input(shape, dtype, device)
_bcast_tracer_state(group, inp)
op = _Broadcast(group, device)
(out,) = apply(op, inp)
return out
def all_gather(
......@@ -163,6 +266,23 @@ def all_reduce_min(
return collective_comm(inp, mode, group, device)
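# Forward: GATHER collective into rank 0. Backward: scatter the gradient from
# rank 0 back to the ranks, mirroring the forward layout.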
class _Gather(Function):
def __init__(self, group=WORLD, device=""):
self.group = group
self.out_device = device
def forward(self, data):
self.in_device = str(data.device)
return collective_comm(
data, CollectiveComm.Mode.GATHER, self.group, self.out_device
)
def backward(self, grad):
has_grad = _bcast_has_grad(self.group, grad)
if has_grad:
return scatter(grad, self.group, self.in_device)
def gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
......@@ -173,8 +293,31 @@ def gather(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveComm.Mode.GATHER
return collective_comm(inp, mode, group, device)
op = _Gather(group, device)
(out,) = apply(op, inp)
if group.rank == 0:
return out
else:
_save_output_for_autodiff(inp, out)
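# Forward: SCATTER collective from rank 0. Backward: gather the gradients back
# onto rank 0.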
class _Scatter(Function):
def __init__(self, group=WORLD, device=""):
self.group = group
self.out_device = device
def forward(self, data):
self.in_device = str(data.device)
return collective_comm(
data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
)
def backward(self, grad):
# TODO: support backward with only part of the grad
if grad is not None:
return gather(grad, self.group, self.in_device)
def scatter(
......@@ -187,8 +330,16 @@ def scatter(
:param group: communication group.
:param device: execution device.
"""
mode = CollectiveComm.Mode.SCATTER
return collective_comm(inp, mode, group, device)
shape, dtype = _bcast_shape_dtype(group, inp)
if group.rank != 0:
# dummy input to infer shape
inp = _dummy_input(shape, dtype, device)
_bcast_tracer_state(group, inp)
op = _Scatter(group, device)
(out,) = apply(op, inp)
return out
def all_to_all(
......@@ -205,44 +356,46 @@ def all_to_all(
return collective_comm(inp, mode, group, device)
class _RemoteSend(PyOpBase):
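# A two-member pseudo group for point-to-point send/recv, so the same _bcast_*
# helpers used by the collectives also work between a single sender (rank 0)
# and receiver (rank 1).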
class _SendRecvGroup:
def __init__(self, rank_from, rank_to):
self.key = "{}->{}".format(rank_from, rank_to)
self.rank_from = rank_from
self.rank_to = rank_to
self.size = 2
@property
def rank(self):
if get_rank() == self.rank_from:
return 0
else:
return 1
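# Send side of the point-to-point pair. Backward asks the receiver (via
# bcast_val on the shared key) whether a gradient exists and, if so, receives
# it with remote_recv.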
class _RemoteSend(Function):
def __init__(self, op: RemoteSend):
self.op = op
def _default_rule(self, data):
return apply(self.op, data)
def _grad_rule(self, data):
self.dtype = data.dtype
self.shape = data.shape
self.device = data.device
(self.dummy,) = self._default_rule(data)
return self.dummy, self.backward
def forward(self, data):
self.device = str(data.device)
(self.dummy,) = apply(self.op, data)
return self.dummy
def backward(self, grad):
assert grad is None
if get_client().check_is_grad(self.op.key):
return remote_recv(
self.op.rank_to,
self.shape,
self.dtype,
device=str(self.device),
inp=self.dummy,
)
has_grad = get_client().bcast_val(None, self.op.key, 2)
if has_grad:
return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)
class _RemoteRecv(PyOpBase):
class _RemoteRecv(Function):
def __init__(self, op: RemoteRecv):
self.op = op
def _default_rule(self, dummy):
def forward(self, dummy):
return apply(self.op, dummy)
def _grad_rule(self, dummy):
return self._default_rule(dummy), self.backward
def backward(self, grad):
get_client().set_is_grad(self.op.key, grad is not None)
get_client().bcast_val(grad is not None, self.op.key, 2)
if grad is not None:
remote_send(grad, self.op.rank_from)
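# Usage sketch (illustrative, not part of this patch): after this change the
# receiver no longer passes shape/dtype explicitly; they are fetched from the
# sender via _bcast_shape_dtype, e.g.
#
#     if get_rank() == 0:
#         remote_send(x, dest_rank=1)
#     else:
#         y = remote_recv(src_rank=0)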
......@@ -254,53 +407,38 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
:param inp: tensor to send.
:param dest_rank: destination process rank.
"""
key = "{}->{}".format(get_rank(), dest_rank)
grad_keys = {}
for n, g in _grad_manager_dict.items():
if g._is_attached_to(inp):
grad_keys[n] = g
get_client().set_remote_tracer(key, grad_keys)
group = _SendRecvGroup(get_rank(), dest_rank)
_bcast_shape_dtype(group, inp)
_bcast_tracer_state(group, inp)
op = RemoteSend()
op.key = key
op.key = group.key
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
op.backend = get_backend()
(dummy,) = apply(_RemoteSend(op), inp)
(out,) = apply(_RemoteSend(op), inp)
for g in grad_keys.values():
g._refkeeper.append(dummy)
_save_output_for_autodiff(inp, out)
def remote_recv(
src_rank: int,
shape: Tuple[int],
dtype: type,
device: Optional[str] = None,
inp=None,
) -> Tensor:
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
"""
Receive a Tensor from a remote process.
:param src_rank: source process rank.
:param shape: the shape of the tensor to receive.
:param dtype: the data type of the tensor to receive.
:param device: the device to place the received tensor.
:param inp: dummy input to determine the received tensor type.
"""
key = "{}->{}".format(src_rank, get_rank())
group = _SendRecvGroup(src_rank, get_rank())
shape, dtype = _bcast_shape_dtype(group, None)
if device is None:
device = get_default_device()
# dummy input
if inp is None:
inp = Tensor([0], device=device)
tracer_set = get_client().check_remote_tracer(key)
for n in tracer_set:
g = _grad_manager_dict.get(n)
if g is not None:
g.wrt(inp)
g._refkeeper.append(inp)
inp = Tensor(0, device=device)
_bcast_tracer_state(group, inp)
_isscalar = False
if len(shape) == 0:
......@@ -308,7 +446,7 @@ def remote_recv(
_isscalar = True
op = RemoteRecv()
op.key = key
op.key = group.key
op.cn = device
op.shape = shape
op.dtype = dtype
......
......@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
import os
import queue
from ..core._imperative_rt.core2 import sync
......@@ -43,6 +44,8 @@ def _run_wrapped(
device=dev,
device_type=device_type,
)
# set NCCL_LAUNCH_MODE to PARALLEL to avoid deadlock
os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
if is_multimachine:
group_barrier()
ret = func(*args, **kwargs)
......
......@@ -253,6 +253,17 @@ class Client:
"""Get user defined key-value pairs across processes."""
return self.proxy.user_get(key)
def bcast_val(self, val, key, size):
if val is not None:
self.user_set(key + "_sync", val)
self.group_barrier(key, size)
self.group_barrier(key, size)
else:
self.group_barrier(key, size)
val = self.user_get(key + "_sync")
self.group_barrier(key, size)
return val
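# Illustrative sketch (not part of this patch): bcast_val is a one-to-many
# value broadcast over the key-value server, implemented as a write plus a
# double barrier. The same rendezvous, modelled in-process with a dict and a
# threading.Barrier for a group of size 2:

import threading

_store = {}
_barrier = threading.Barrier(2)

def bcast_val_sketch(val, key):
    if val is not None:                  # the participant holding the value
        _store[key + "_sync"] = val      # publish before the first barrier
        _barrier.wait()                  # readers may fetch after this point
        _barrier.wait()                  # key can be reused after this point
    else:                                # the other participants
        _barrier.wait()
        val = _store[key + "_sync"]
        _barrier.wait()
    return val

# two "ranks" as threads: the writer publishes True, the reader prints it
t = threading.Thread(target=lambda: print(bcast_val_sketch(None, "has_grad")))
t.start()
bcast_val_sketch(True, "has_grad")
t.join()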
def main(port=0, verbose=True):
mm_server_port = create_mm_server("0.0.0.0", 0)
......