Commit 809d5056 authored by Megvii Engine Team, committed by huangxinda

feat(mge/distributed): enable pt shm allreduce

GitOrigin-RevId: 1dd5a02a512b210f2c75afd0062e4bfad1fcdddc
Parent 02455941
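In short, this commit adds a shared-memory ("shm") MegRay backend, an "auto" backend mode, and the plumbing (per-machine ranks, backend override) that lets gradient all-reduce fall back to shm on single-machine setups without GPU peer-to-peer access. A minimal usage sketch, assuming two local GPUs and the public `megengine.functional.distributed` collectives:

```python
# Minimal sketch, assuming 2 local GPUs; "auto" is the new default backend added below.
import megengine as mge
import megengine.distributed as dist
import megengine.functional.distributed as F_dist

@dist.launcher(n_gpus=2, backend="auto")   # "backend" is the new launcher argument
def worker():
    x = mge.tensor([float(dist.get_rank() + 1)])
    # with "auto", direct collectives resolve to nccl (CUDA) or rccl (ROCm)
    y = F_dist.all_reduce_sum(x)
    print(dist.get_rank(), y.numpy())      # both ranks print [3.]

worker()
```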
......@@ -1018,6 +1018,7 @@ endif()
if(MGE_WITH_DISTRIBUTED)
set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
set(MEGRAY_WITH_SHM ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
endif()
......
......@@ -6,6 +6,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from mprop import mproperty
from . import group
from .group import (
WORLD,
Group,
......@@ -19,7 +22,20 @@ from .group import (
init_process_group,
is_distributed,
new_group,
override_backend,
)
from .helper import bcast_list_, make_allreduce_cb, synchronized
from .launcher import launcher
from .server import Client, Server
@mproperty
def backend(mod):
assert group._sd, "please call init_process_group first"
return group._sd.backend
@backend.setter
def backend(mod, val):
assert group._sd, "please call init_process_group first"
group._sd.backend = val
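The `mproperty`-based accessor above exposes the process-group backend as a module-level attribute. A small usage sketch, assuming it is read inside a worker where `init_process_group` has already run (e.g. under `dist.launcher`):

```python
# Sketch: module-level backend property (only valid after init_process_group).
import megengine.distributed as dist

@dist.launcher(n_gpus=2)
def worker():
    print(dist.backend)     # "auto" with the new default
    dist.backend = "nccl"   # routed through the @backend.setter above

worker()
```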
......@@ -14,9 +14,10 @@ from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank
__all__ = [
"reduce_sum",
......@@ -34,14 +35,30 @@ __all__ = [
]
_device2backend = {
"gpu": "nccl",
"cuda": "nccl",
"rocm": "rccl",
}
def _backend():
if group._sd.backend == "auto":
return _device2backend[what_is_xpu()]
else:
return group._sd.backend
def collective_comm(inp, mode, group, device):
"""Helper function for applying collective communication functions."""
assert isinstance(group, Group)
if group is None:
return inp
if device is None:
device = ""
addr, port = get_mm_server_addr()
op = CollectiveComm(
key=group.key,
key=group.key + _backend(),
nr_devices=group.size,
rank=group.rank,
is_root=(group.rank == 0),
......@@ -50,7 +67,7 @@ def collective_comm(inp, mode, group, device):
port=port,
mode=mode,
dtype=inp.dtype,
backend=get_backend(),
backend=_backend(),
comp_node=device,
)
(result,) = apply(op, inp)
......@@ -112,8 +129,8 @@ def _bcast_tracer_state(group, inp):
g._refkeeper.append(inp)
def _dummy_input(shape, dtype, device=""):
if device == "":
def _dummy_input(shape, dtype, device=None):
if device is None:
device = get_default_device()
inp = Tensor(0, dtype=dtype, device=device)
if len(shape) > 0:
......@@ -122,14 +139,14 @@ def _dummy_input(shape, dtype, device=""):
class _ReduceSum(Function):
def __init__(self, group=WORLD, device=""):
def __init__(self, group=WORLD, device=None):
self.group = group
self.out_device = device
def forward(self, data):
self.in_device = str(data.device)
return collective_comm(
data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
)
def backward(self, grad):
......@@ -139,7 +156,7 @@ class _ReduceSum(Function):
def reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create reduce_sum operator for collective communication.
......@@ -158,14 +175,14 @@ def reduce_sum(
class _Broadcast(Function):
def __init__(self, group=WORLD, device=""):
def __init__(self, group=WORLD, device=None):
self.group = group
self.out_device = device
def forward(self, data):
self.in_device = str(data.device)
return collective_comm(
data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
)
def backward(self, grad):
......@@ -175,7 +192,7 @@ class _Broadcast(Function):
def broadcast(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create broadcast operator for collective communication.
......@@ -197,14 +214,14 @@ def broadcast(
def _bcast_param(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
mode = CollectiveComm.Mode.BROADCAST
return collective_comm(inp, mode, group, device)
def all_gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create all_gather operator for collective communication.
......@@ -218,7 +235,7 @@ def all_gather(
def reduce_scatter_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create reduce_scatter_sum operator for collective communication.
......@@ -232,7 +249,7 @@ def reduce_scatter_sum(
def all_reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create all_reduce_sum operator for collective communication.
......@@ -246,7 +263,7 @@ def all_reduce_sum(
def all_reduce_max(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create all_reduce_max operator for collective communication.
......@@ -260,7 +277,7 @@ def all_reduce_max(
def all_reduce_min(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create all_reduce_min operator for collective communication.
......@@ -274,7 +291,7 @@ def all_reduce_min(
class _Gather(Function):
def __init__(self, group=WORLD, device=""):
def __init__(self, group=WORLD, device=None):
self.group = group
self.out_device = device
......@@ -291,7 +308,7 @@ class _Gather(Function):
def gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create gather operator for collective communication.
......@@ -311,7 +328,7 @@ def gather(
class _Scatter(Function):
def __init__(self, group=WORLD, device=""):
def __init__(self, group=WORLD, device=None):
self.group = group
self.out_device = device
......@@ -328,7 +345,7 @@ class _Scatter(Function):
def scatter(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create scatter operator for collective communication.
......@@ -350,7 +367,7 @@ def scatter(
def all_to_all(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
Create all_to_all operator for collective communication.
......@@ -407,7 +424,7 @@ class _RemoteRecv(Function):
remote_send(grad, self.op.rank_from)
def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
def remote_send(inp: Tensor, dest_rank: int):
"""
Send a Tensor to a remote process.
......@@ -423,13 +440,13 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
op.key = group.key
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
op.backend = get_backend()
op.backend = _backend()
(out,) = apply(_RemoteSend(op), inp)
_save_output_for_autodiff(inp, out)
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
"""
Receive a Tensor from a remote process.
......@@ -459,7 +476,7 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tenso
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank
op.backend = get_backend()
op.backend = _backend()
(ret,) = apply(_RemoteRecv(op), inp)
if _isscalar:
......
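The key change in `functional.py` is that every collective now resolves its backend through `_backend()`: an explicit setting is used verbatim, while "auto" maps the current device type to a concrete MegRay backend, and the resolved name is appended to the op key so ops issued under different backends do not share a communicator. A standalone sketch of that resolution (names copied from the diff; `what_is_xpu()` is assumed to return one of "gpu", "cuda", "rocm"):

```python
# Standalone sketch of the "auto" resolution used by _backend()/collective_comm above.
_device2backend = {"gpu": "nccl", "cuda": "nccl", "rocm": "rccl"}

def resolve_backend(configured: str, xpu: str) -> str:
    """Map 'auto' to a concrete MegRay backend; pass explicit choices through."""
    return _device2backend[xpu] if configured == "auto" else configured

assert resolve_backend("auto", "cuda") == "nccl"
assert resolve_backend("auto", "rocm") == "rccl"
assert resolve_backend("shm", "cuda") == "shm"
# collective_comm then builds key = group.key + resolved backend, so the same group
# can run collectives on different backends without reusing a communicator.
```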
......@@ -7,8 +7,11 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple
from mprop import mproperty
from ..device import set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server
......@@ -26,6 +29,7 @@ class StaticData:
backend = None
next_stream = None
device_type = None
machine_ranks = None
_sd = None
......@@ -55,6 +59,7 @@ class Group:
self.proc_ranks = proc_ranks
self.stream = _sd.next_stream
_sd.next_stream += 1
self.is_single_machine_cache = None
def check(self, proc_ranks):
assert _sd is not None, "please call init_process_group first"
......@@ -83,17 +88,23 @@ class Group:
assert len(self.proc_ranks) > 0, "invalid group"
return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)
WORLD = Group([])
@property
def is_single_machine(self):
if self.is_single_machine_cache is not None:
return self.is_single_machine_cache
assert _sd is not None, "please call init_process_group first"
for rank in self.proc_ranks:
if rank not in _sd.machine_ranks:
self.is_single_machine_cache = False
return False
self.is_single_machine_cache = True
return True
_device2backend = {
"gpu": "nccl",
"cuda": "nccl",
"rocm": "rccl",
}
WORLD = Group([])
_backends = {"nccl", "rccl", "ucx"}
_devices = {"gpu", "cuda", "rocm"}
_backends = {"nccl", "rccl", "ucx", "auto"}
def init_process_group(
......@@ -102,7 +113,7 @@ def init_process_group(
world_size: int,
rank: int,
device: int,
backend: Optional[str] = None,
backend: Optional[str] = "auto",
device_type: str = "xpu",
) -> None:
"""
......@@ -113,10 +124,9 @@ def init_process_group(
:param world_size: total number of processes participating in the job.
:param rank: rank of the current process.
:param device: the GPU device id to bind this process to.
:param backend: communicator backend, currently supports 'nccl' and 'ucx'.
:param backend: communicator backend, currently supports 'nccl' and 'shm'.
"""
physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
backend = _device2backend[physical_device_type] if backend is None else backend
if not isinstance(master_ip, str):
raise TypeError("Expect type str but got {}".format(type(master_ip)))
if not isinstance(port, int):
......@@ -131,7 +141,7 @@ def init_process_group(
raise ValueError(
"backend should be one of {} but got {}".format(_backends, backend)
)
if physical_device_type not in _device2backend:
if physical_device_type not in _devices:
raise ValueError(
"{} is not a valid distributed device type".format(device_type)
)
......@@ -161,6 +171,30 @@ def init_process_group(
seed(int(time.time()) + rank)
def _set_machine_ranks(ranks) -> None:
global _sd
assert _sd is not None
_sd.machine_ranks = ranks
@contextmanager
def override_backend(new_backend: str):
"""
Override distributed backend
:param new_backend: communicator backend set in this context.
"""
global _sd
assert _sd, "please call init_process_group first"
old_backend = _sd.backend
_sd.backend = new_backend
try:
yield
finally:
_sd.backend = old_backend
def is_distributed() -> bool:
"""Return True if the distributed process group has been initialized."""
return _sd is not None
......
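`group.py` now carries the backend in `StaticData`, accepts "auto" as a legal choice, caches a per-group `is_single_machine` flag derived from `machine_ranks`, and adds the `override_backend` context manager that swaps the backend for a scoped region and restores it afterwards. A hedged usage sketch inside a launched worker:

```python
# Sketch: scoped backend override plus the new single-machine check.
import megengine as mge
import megengine.distributed as dist
import megengine.functional.distributed as F_dist

@dist.launcher(n_gpus=2)
def worker():
    x = mge.tensor([1.0])
    with dist.override_backend("shm"):     # previous backend is restored on exit
        y = F_dist.all_reduce_sum(x)
    assert dist.WORLD.is_single_machine    # True: all ranks are in machine_ranks
    return y.numpy()

worker()
```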
......@@ -22,8 +22,9 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy
from ..tensor import Tensor
from ..utils.future import Future
from . import group as _group
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed
from .group import WORLD, Group, group_barrier, is_distributed, override_backend
def param_pack_split(inp: Tensor, offsets: list, shapes: list):
......@@ -118,10 +119,30 @@ def get_offsets(shapes):
return offsets
_enable_p2p_cache = None
def _check_enable_p2p():
global _enable_p2p_cache
if _enable_p2p_cache is not None:
return _enable_p2p_cache
cmd = ["nvidia-smi", "topo", "-p2p", "w"]
import subprocess
output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
if output.count(b"OK") > 1:
_enable_p2p_cache = True
return True
else:
_enable_p2p_cache = False
return False
def pack_allreduce_split(pack_list, shapes, group, reduce_method):
offsets_val = get_offsets(shapes)
offsets = Tensor(offsets_val)
packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
if reduce_method == "mean":
packed_grads /= group.size
......@@ -207,9 +228,10 @@ class AllreduceCallback:
:param reduce_method: the method to reduce gradients.
:param group: communication group.
:param backend: override distributed backend in allreduce
"""
def __init__(self, reduce_method: str, group: Group = WORLD):
def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
reduce_method = reduce_method.lower()
assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
self._reduce_method = reduce_method
......@@ -217,6 +239,15 @@ class AllreduceCallback:
self._marked_gm = WeakSet()
self._param_pack_thd = 10 * 1024 * 1024
self._reset()
if backend is None:
assert _group._sd, "please call init_process_group first"
backend = _group._sd.backend
if backend == "auto":
if group.is_single_machine and not _check_enable_p2p():
backend = "shm"
else:
backend = "nccl"
self._backend = backend
def _reset(self):
self._params = []
......@@ -231,9 +262,10 @@ class AllreduceCallback:
return
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
shapes = [p._tuple_shape for p in self._packing_list[dtype]]
reduced_grads = pack_allreduce_split(
grad_list, shapes, self._group, self._reduce_method
)
with override_backend(self._backend):
reduced_grads = pack_allreduce_split(
grad_list, shapes, self._group, self._reduce_method
)
for param, grad in zip(self._packing_list[dtype], reduced_grads):
self._gradients_dict[param] = grad
self._packing_list[dtype] = []
......
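For gradient synchronization, `AllreduceCallback` now accepts a `backend` argument; with "auto" it falls back to "shm" only when the group lives on a single machine and `nvidia-smi topo -p2p w` reports no working P2P link (fewer than two `OK` entries), otherwise it stays on "nccl", and the chosen backend is applied via `override_backend` around the packed all-reduce. A hedged sketch of forcing the backend on the callback, assuming `make_allreduce_cb` is the exported alias of `AllreduceCallback`:

```python
# Sketch: gradient all-reduce with an explicit backend on the callback.
import megengine as mge
import megengine.distributed as dist
import megengine.module as M
from megengine.autodiff import GradManager

@dist.launcher(n_gpus=2)
def worker():
    model = M.Linear(4, 1)
    gm = GradManager()
    gm.attach(
        model.parameters(),
        callbacks=[dist.make_allreduce_cb("mean", dist.WORLD, backend="shm")],
    )
    x = mge.tensor([[1.0, 2.0, 3.0, 4.0]])
    with gm:
        loss = model(x).sum()
        gm.backward(loss)   # gradients are packed and all-reduced over "shm"

worker()
```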
......@@ -14,7 +14,7 @@ import queue
from .. import _exit
from ..core._imperative_rt.core2 import full_sync
from ..logger import get_logger
from .group import group_barrier, init_process_group
from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized, get_device_count_by_fork
from .server import Client, Server
......@@ -34,7 +34,9 @@ def _run_wrapped(
device_type,
args,
kwargs,
backend,
queue: mp.Queue,
machine_ranks: list,
):
"""Init distributed process group and run wrapped function."""
_check_device_initialized(device_type)
......@@ -44,10 +46,12 @@ def _run_wrapped(
world_size=world_size,
rank=rank,
device=dev,
backend=backend,
device_type=device_type,
)
# set NCCL_LAUNCH_MODE to avoid deadlock
os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
_set_machine_ranks(machine_ranks)
if is_multimachine:
group_barrier()
ret = func(*args, **kwargs)
......@@ -67,6 +71,7 @@ class launcher:
:param rank_start: start number for rank.
:param master_ip: ip address for master node (where the rank 0 is).
:param port: server port for distributed server.
:param backend: set default collective communication backend.
"""
def __new__(cls, *args, **kwargs):
......@@ -83,6 +88,7 @@ class launcher:
master_ip="localhost",
port=0,
device_type="xpu",
backend="auto",
):
self.func = func
self.n_gpus = (
......@@ -93,6 +99,7 @@ class launcher:
self.master_ip = master_ip
self.port = port
self.device_type = device_type
self.backend = backend
# master node create server
if self.rank_start == 0:
self.server = Server(self.port)
......@@ -104,6 +111,7 @@ class launcher:
procs = []
queue = mp.Queue(self.n_gpus)
results = [None] * self.n_gpus
machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
for dev in range(self.n_gpus):
p = mp.Process(
target=_run_wrapped,
......@@ -118,7 +126,9 @@ class launcher:
self.device_type,
args,
kwargs,
self.backend,
queue,
machine_ranks,
),
)
p.start()
......
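`launcher` now forwards a `backend` argument into `init_process_group` and records `machine_ranks = [rank_start, ..., rank_start + n_gpus - 1]` for each node, which is what `Group.is_single_machine` consults. A hedged multi-node sketch (addresses and rank numbers are illustrative):

```python
# Sketch: the same script run on two nodes; only rank_start differs per node.
import megengine.distributed as dist

def main():
    # machine_ranks on this node is [rank_start, ..., rank_start + n_gpus - 1]
    print("rank", dist.get_rank(), "of", dist.get_world_size(), "backend", dist.backend)

trainer = dist.launcher(
    main,
    n_gpus=4,
    world_size=8,
    rank_start=4,          # 0 on node 0, 4 on node 1
    master_ip="10.0.0.1",  # illustrative master address
    port=23456,
    backend="auto",        # new default; forwarded to init_process_group
)
trainer()
```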
......@@ -11,6 +11,7 @@
#include "megbrain/opr/megray_helper.h"
#include "megbrain/comp_node_env.h"
#include "megray/common.h"
using namespace mgb;
using namespace opr;
......@@ -39,6 +40,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
return MegRay::MEGRAY_RCCL;
} else if (backend == "ucx") {
return MegRay::MEGRAY_UCX;
} else if (backend == "shm") {
return MegRay::MEGRAY_SHM;
} else {
mgb_throw(MegBrainError, "bad CollectiveComm backend");
}
......@@ -90,7 +93,7 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
if (rank == root) {
char* c = MegRay::get_host_ip();
master_ip = std::string(c);
delete c;
delete [] c;
port = MegRay::get_free_port();
auto ret = MegRay::create_server(size, port);
mgb_assert(ret == MegRay::Status::MEGRAY_OK);
......