diff --git a/CMakeLists.txt b/CMakeLists.txt
index a580b5f486a2163510fc29af62abb8070116b1e8..3af836e381c80e1787e9901e8828879e0eb48480 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1018,6 +1018,7 @@ endif()
 
 if(MGE_WITH_DISTRIBUTED)
   set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
+  set(MEGRAY_WITH_SHM ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
   set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
   add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
 endif()
diff --git a/imperative/python/megengine/distributed/__init__.py b/imperative/python/megengine/distributed/__init__.py
index b6e8359325c1ca74d4d796bbcba16e546e850e34..19e3b9a36192290dffe73f3eba4e565354abf6e1 100644
--- a/imperative/python/megengine/distributed/__init__.py
+++ b/imperative/python/megengine/distributed/__init__.py
@@ -6,6 +6,9 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from mprop import mproperty
+
+from . import group
 from .group import (
     WORLD,
     Group,
@@ -19,7 +22,20 @@ from .group import (
     init_process_group,
     is_distributed,
     new_group,
+    override_backend,
 )
 from .helper import bcast_list_, make_allreduce_cb, synchronized
 from .launcher import launcher
 from .server import Client, Server
+
+
+@mproperty
+def backend(mod):
+    assert group._sd, "please call init_process_group first"
+    return group._sd.backend
+
+
+@backend.setter
+def backend(mod, val):
+    assert group._sd, "please call init_process_group first"
+    group._sd.backend = val
diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index d832ae86c1eb711f748fa01d433ec3a465b2217f..9171b64e06f6d2568709b888fc805e0aaf9ed0fc 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -14,9 +14,10 @@ from ..core._imperative_rt.core2 import apply
 from ..core.autodiff.grad import Function, _grad_manager_dict
 from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.utils import isscalar, setscalar
-from ..device import get_default_device
+from ..device import get_default_device, what_is_xpu
 from ..tensor import Tensor
-from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
+from . import group
+from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank
 
 __all__ = [
     "reduce_sum",
@@ -34,14 +35,30 @@ __all__ = [
 ]
 
 
+_device2backend = {
+    "gpu": "nccl",
+    "cuda": "nccl",
+    "rocm": "rccl",
+}
+
+
+def _backend():
+    if group._sd.backend == "auto":
+        return _device2backend[what_is_xpu()]
+    else:
+        return group._sd.backend
+
+
 def collective_comm(inp, mode, group, device):
     """Helper function for applying collective communication functions."""
     assert isinstance(group, Group)
     if group is None:
         return inp
+    if device is None:
+        device = ""
     addr, port = get_mm_server_addr()
     op = CollectiveComm(
-        key=group.key,
+        key=group.key + _backend(),
         nr_devices=group.size,
         rank=group.rank,
         is_root=(group.rank == 0),
@@ -50,7 +67,7 @@ def collective_comm(inp, mode, group, device):
         port=port,
         mode=mode,
         dtype=inp.dtype,
-        backend=get_backend(),
+        backend=_backend(),
         comp_node=device,
     )
     (result,) = apply(op, inp)
@@ -112,8 +129,8 @@ def _bcast_tracer_state(group, inp):
         g._refkeeper.append(inp)
 
 
-def _dummy_input(shape, dtype, device=""):
-    if device == "":
+def _dummy_input(shape, dtype, device=None):
+    if device is None:
         device = get_default_device()
     inp = Tensor(0, dtype=dtype, device=device)
     if len(shape) > 0:
@@ -122,14 +139,14 @@
 
 
 class _ReduceSum(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
 
     def forward(self, data):
         self.in_device = str(data.device)
         return collective_comm(
-            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
+            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
         )
 
     def backward(self, grad):
@@ -139,7 +156,7 @@
 
 
 def reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create reduce_sum operator for collective communication.
@@ -158,14 +175,14 @@
 
 
 class _Broadcast(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
 
     def forward(self, data):
         self.in_device = str(data.device)
         return collective_comm(
-            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
+            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
         )
 
     def backward(self, grad):
@@ -175,7 +192,7 @@
 
 
 def broadcast(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create broadcast operator for collective communication.
@@ -197,14 +214,14 @@
 
 
 def _bcast_param(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
 ) -> Tensor:
     mode = CollectiveComm.Mode.BROADCAST
     return collective_comm(inp, mode, group, device)
 
 
 def all_gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_gather operator for collective communication.
@@ -218,7 +235,7 @@ def all_gather(
 
 
 def reduce_scatter_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create reduce_scatter_sum operator for collective communication.
@@ -232,7 +249,7 @@ def reduce_scatter_sum(
 
 
 def all_reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_sum operator for collective communication.
@@ -246,7 +263,7 @@ def all_reduce_sum(
 
 
 def all_reduce_max(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_max operator for collective communication.
@@ -260,7 +277,7 @@ def all_reduce_max(
 
 
 def all_reduce_min(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_min operator for collective communication.
@@ -274,7 +291,7 @@ def all_reduce_min(
 
 
 class _Gather(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
 
@@ -291,7 +308,7 @@ class _Gather(Function):
 
 
 def gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create gather operator for collective communication.
@@ -311,7 +328,7 @@ def gather(
 
 
 class _Scatter(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
 
@@ -328,7 +345,7 @@ class _Scatter(Function):
 
 
 def scatter(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create scatter operator for collective communication.
@@ -350,7 +367,7 @@ def scatter(
 
 
 def all_to_all(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_to_all operator for collective communication.
@@ -407,7 +424,7 @@ class _RemoteRecv(Function):
         remote_send(grad, self.op.rank_from)
 
 
-def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
+def remote_send(inp: Tensor, dest_rank: int):
     """
     Send a Tensor to a remote process.
 
@@ -423,13 +440,13 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     op.key = group.key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
-    op.backend = get_backend()
+    op.backend = _backend()
     (out,) = apply(_RemoteSend(op), inp)
 
     _save_output_for_autodiff(inp, out)
 
 
-def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
+def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
     """
     Receive a Tensor from a remote process.
 
@@ -459,7 +476,7 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tenso
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
     op.rank_from = src_rank
-    op.backend = get_backend()
+    op.backend = _backend()
     (ret,) = apply(_RemoteRecv(op), inp)
 
     if _isscalar:
diff --git a/imperative/python/megengine/distributed/group.py b/imperative/python/megengine/distributed/group.py
index c00933e99f2ecb4f3664fda8280ffd75ee6e29bd..fb0f3f11317d4de28debe77568ad0913e460a802 100644
--- a/imperative/python/megengine/distributed/group.py
+++ b/imperative/python/megengine/distributed/group.py
@@ -7,8 +7,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple
 
+from mprop import mproperty
+
 from ..device import set_default_device, what_is_xpu
 from ..random import seed
 from .server import Client, Server
@@ -26,6 +29,7 @@ class StaticData:
     backend = None
     next_stream = None
     device_type = None
+    machine_ranks = None
 
 
 _sd = None
@@ -55,6 +59,7 @@ class Group:
         self.proc_ranks = proc_ranks
         self.stream = _sd.next_stream
         _sd.next_stream += 1
+        self.is_single_machine_cache = None
 
     def check(self, proc_ranks):
         assert _sd is not None, "please call init_process_group first"
@@ -83,17 +88,23 @@ class Group:
         assert len(self.proc_ranks) > 0, "invalid group"
         return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)
 
-
-WORLD = Group([])
+    @property
+    def is_single_machine(self):
+        if self.is_single_machine_cache is not None:
+            return self.is_single_machine_cache
+        assert _sd is not None, "please call init_process_group first"
+        for rank in self.proc_ranks:
+            if rank not in _sd.machine_ranks:
+                self.is_single_machine_cache = False
+                return False
+        self.is_single_machine_cache = True
+        return True
 
-_device2backend = {
-    "gpu": "nccl",
-    "cuda": "nccl",
-    "rocm": "rccl",
-}
+WORLD = Group([])
 
-_backends = {"nccl", "rccl", "ucx"}
+_devices = {"gpu", "cuda", "rocm"}
+_backends = {"nccl", "rccl", "ucx", "auto"}
 
 
 def init_process_group(
@@ -102,7 +113,7 @@
     master_ip: str,
     port: int,
     world_size: int,
     rank: int,
     device: int,
-    backend: Optional[str] = None,
+    backend: Optional[str] = "auto",
     device_type: str = "xpu",
 ) -> None:
     """
@@ -113,10 +124,9 @@
     :param world_size: total number of processes participating in the job.
     :param rank: rank of the current process.
     :param device: the GPU device id to bind this process to.
-    :param backend: communicator backend, currently support 'nccl' and 'ucx'.
+    :param backend: communicator backend, currently support 'nccl' and 'shm'.
     """
     physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
-    backend = _device2backend[physical_device_type] if backend is None else backend
     if not isinstance(master_ip, str):
         raise TypeError("Expect type str but got {}".format(type(master_ip)))
     if not isinstance(port, int):
@@ -131,7 +141,7 @@
         raise ValueError(
             "backend should be one of {} but got {}".format(_backends, backend)
         )
-    if physical_device_type not in _device2backend:
+    if physical_device_type not in _devices:
         raise ValueError(
             "{} is not a valid distributed device type".format(device_type)
         )
@@ -161,6 +171,30 @@
     seed(int(time.time()) + rank)
 
 
+def _set_machine_ranks(ranks) -> None:
+    global _sd
+    assert _sd is not None
+
+    _sd.machine_ranks = ranks
+
+
+@contextmanager
+def override_backend(new_backend: str):
+    """
+    Override distributed backend
+
+    :param new_backend: communicator backend set in this context.
+    """
+    global _sd
+    assert _sd, "please call init_process_group first"
+    old_backend = _sd.backend
+    _sd.backend = new_backend
+    try:
+        yield
+    finally:
+        _sd.backend = old_backend
+
+
 def is_distributed() -> bool:
     """Return True if the distributed process group has been initialized."""
     return _sd is not None
diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py
index c91c7e01af431baa9413240e9818b3402111577b..5ad1b50f3915d910c700ead939b1e8e4e205ea94 100644
--- a/imperative/python/megengine/distributed/helper.py
+++ b/imperative/python/megengine/distributed/helper.py
@@ -22,8 +22,9 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
 from ..functional.tensor import copy
 from ..tensor import Tensor
 from ..utils.future import Future
+from . import group as _group
 from .functional import _bcast_param, all_reduce_sum, broadcast
-from .group import WORLD, Group, group_barrier, is_distributed
+from .group import WORLD, Group, group_barrier, is_distributed, override_backend
 
 
 def param_pack_split(inp: Tensor, offsets: list, shapes: list):
@@ -118,10 +119,30 @@ def get_offsets(shapes):
     return offsets
 
 
+_enable_p2p_cache = None
+
+
+def _check_enable_p2p():
+    global _enable_p2p_cache
+    if _enable_p2p_cache is not None:
+        return _enable_p2p_cache
+    cmd = ["nvidia-smi", "topo", "-p2p", "w"]
+    import subprocess
+
+    output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
+    if output.count(b"OK") > 1:
+        _enable_p2p_cache = True
+        return True
+    else:
+        _enable_p2p_cache = False
+        return False
+
+
 def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     offsets_val = get_offsets(shapes)
     offsets = Tensor(offsets_val)
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
+
     packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
     if reduce_method == "mean":
         packed_grads /= group.size
@@ -207,9 +228,10 @@ class AllreduceCallback:
 
     :param reduce_method: the method to reduce gradiants.
     :param group: communication group.
+    :param backend: override distributed backend in allreduce
     """
 
-    def __init__(self, reduce_method: str, group: Group = WORLD):
+    def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
         reduce_method = reduce_method.lower()
         assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
         self._reduce_method = reduce_method
@@ -217,6 +239,15 @@ class AllreduceCallback:
         self._marked_gm = WeakSet()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
+        if backend is None:
+            assert _group._sd, "please call init_process_group first"
+            backend = _group._sd.backend
+        if backend == "auto":
+            if group.is_single_machine and not _check_enable_p2p():
+                backend = "shm"
+            else:
+                backend = "nccl"
+        self._backend = backend
 
     def _reset(self):
         self._params = []
@@ -231,9 +262,10 @@ class AllreduceCallback:
             return
         grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
         shapes = [p._tuple_shape for p in self._packing_list[dtype]]
-        reduced_grads = pack_allreduce_split(
-            grad_list, shapes, self._group, self._reduce_method
-        )
+        with override_backend(self._backend):
+            reduced_grads = pack_allreduce_split(
+                grad_list, shapes, self._group, self._reduce_method
+            )
         for param, grad in zip(self._packing_list[dtype], reduced_grads):
             self._gradients_dict[param] = grad
         self._packing_list[dtype] = []
diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py
index 2b78be40025f21dad790dfc0193015ebd5324b98..b043705c48a4e3a12e1d2e326e08fa84416c9977 100644
--- a/imperative/python/megengine/distributed/launcher.py
+++ b/imperative/python/megengine/distributed/launcher.py
@@ -14,7 +14,7 @@ import queue
 from .. import _exit
 from ..core._imperative_rt.core2 import full_sync
 from ..logger import get_logger
-from .group import group_barrier, init_process_group
+from .group import _set_machine_ranks, group_barrier, init_process_group
 from .helper import _check_device_initialized, get_device_count_by_fork
 from .server import Client, Server
 
@@ -34,7 +34,9 @@ def _run_wrapped(
     device_type,
     args,
     kwargs,
+    backend,
     queue: mp.Queue,
+    machine_ranks: list,
 ):
     """Init distributed process group and run wrapped function."""
     _check_device_initialized(device_type)
@@ -44,10 +46,12 @@ def _run_wrapped(
         world_size=world_size,
         rank=rank,
         device=dev,
+        backend=backend,
         device_type=device_type,
     )
     # set NCCL_LAUNCH_MODE to avoid deadlock
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
+    _set_machine_ranks(machine_ranks)
     if is_multimachine:
         group_barrier()
     ret = func(*args, **kwargs)
@@ -67,6 +71,7 @@ class launcher:
     :param rank_start: start number for rank.
     :param master_ip: ip address for master node (where the rank 0 is).
     :param port: server port for distributed server.
+    :param backend: set default collective communication backend.
     """
 
     def __new__(cls, *args, **kwargs):
@@ -83,6 +88,7 @@
         master_ip="localhost",
         port=0,
         device_type="xpu",
+        backend="auto",
     ):
         self.func = func
         self.n_gpus = (
@@ -93,6 +99,7 @@
         self.master_ip = master_ip
         self.port = port
         self.device_type = device_type
+        self.backend = backend
         # master node create server
         if self.rank_start == 0:
             self.server = Server(self.port)
@@ -104,6 +111,7 @@
         procs = []
         queue = mp.Queue(self.n_gpus)
         results = [None] * self.n_gpus
+        machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
         for dev in range(self.n_gpus):
             p = mp.Process(
                 target=_run_wrapped,
@@ -118,7 +126,9 @@
                     self.device_type,
                     args,
                     kwargs,
+                    self.backend,
                     queue,
+                    machine_ranks,
                 ),
             )
             p.start()
diff --git a/src/opr-mm/impl/megray_helper.cpp b/src/opr-mm/impl/megray_helper.cpp
index ba124ecb88ca494c4df5c4cc8a3188a16caa0137..c7d0c0653b0b5165e712e395b9270bd83326ab02 100644
--- a/src/opr-mm/impl/megray_helper.cpp
+++ b/src/opr-mm/impl/megray_helper.cpp
@@ -11,6 +11,7 @@
 #include "megbrain/opr/megray_helper.h"
 
 #include "megbrain/comp_node_env.h"
+#include "megray/common.h"
 
 using namespace mgb;
 using namespace opr;
@@ -39,6 +40,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
         return MegRay::MEGRAY_RCCL;
     } else if (backend == "ucx") {
         return MegRay::MEGRAY_UCX;
+    } else if (backend == "shm") {
+        return MegRay::MEGRAY_SHM;
     } else {
         mgb_throw(MegBrainError, "bad CollectiveComm backend");
     }
@@ -90,7 +93,7 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
     if (rank == root) {
        char* c = MegRay::get_host_ip();
        master_ip = std::string(c);
-        delete c;
+        delete [] c;
        port = MegRay::get_free_port();
        auto ret = MegRay::create_server(size, port);
        mgb_assert(ret == MegRay::Status::MEGRAY_OK);