Commit 809d5056 authored by Megvii Engine Team, committed by huangxinda

feat(mge/distributed): enable pt shm allreduce

GitOrigin-RevId: 1dd5a02a512b210f2c75afd0062e4bfad1fcdddc
Parent 02455941
@@ -1018,6 +1018,7 @@ endif()
 if(MGE_WITH_DISTRIBUTED)
   set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
+  set(MEGRAY_WITH_SHM ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
   set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
   add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
 endif()
......
@@ -6,6 +6,9 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from mprop import mproperty
+
+from . import group
 from .group import (
     WORLD,
     Group,
@@ -19,7 +22,20 @@ from .group import (
     init_process_group,
     is_distributed,
     new_group,
+    override_backend,
 )
 from .helper import bcast_list_, make_allreduce_cb, synchronized
 from .launcher import launcher
 from .server import Client, Server
+
+
+@mproperty
+def backend(mod):
+    assert group._sd, "please call init_process_group first"
+    return group._sd.backend
+
+
+@backend.setter
+def backend(mod, val):
+    assert group._sd, "please call init_process_group first"
+    group._sd.backend = val
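With the mproperty hooks above, megengine.distributed.backend behaves like a module-level property backed by the process-group state. A minimal usage sketch (assuming a working MegEngine install and a process group initialized in this process; the address, port, and device values are placeholders):

import megengine.distributed as dist

dist.init_process_group("localhost", 23333, world_size=1, rank=0, device=0)
print(dist.backend)    # reads group._sd.backend; defaults to "auto"
dist.backend = "nccl"  # the setter writes group._sd.backend for later collectives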
@@ -14,9 +14,10 @@ from ..core._imperative_rt.core2 import apply
 from ..core.autodiff.grad import Function, _grad_manager_dict
 from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.utils import isscalar, setscalar
-from ..device import get_default_device
+from ..device import get_default_device, what_is_xpu
 from ..tensor import Tensor
-from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
+from . import group
+from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank
 
 __all__ = [
     "reduce_sum",
@@ -34,14 +35,30 @@ __all__ = [
 ]
 
+_device2backend = {
+    "gpu": "nccl",
+    "cuda": "nccl",
+    "rocm": "rccl",
+}
+
+
+def _backend():
+    if group._sd.backend == "auto":
+        return _device2backend[what_is_xpu()]
+    else:
+        return group._sd.backend
+
+
 def collective_comm(inp, mode, group, device):
     """Helper function for applying collective communication functions."""
     assert isinstance(group, Group)
     if group is None:
         return inp
+    if device is None:
+        device = ""
     addr, port = get_mm_server_addr()
     op = CollectiveComm(
-        key=group.key,
+        key=group.key + _backend(),
         nr_devices=group.size,
         rank=group.rank,
         is_root=(group.rank == 0),
@@ -50,7 +67,7 @@ def collective_comm(inp, mode, group, device):
         port=port,
         mode=mode,
         dtype=inp.dtype,
-        backend=get_backend(),
+        backend=_backend(),
         comp_node=device,
     )
     (result,) = apply(op, inp)
@@ -112,8 +129,8 @@ def _bcast_tracer_state(group, inp):
         g._refkeeper.append(inp)
 
 
-def _dummy_input(shape, dtype, device=""):
-    if device == "":
+def _dummy_input(shape, dtype, device=None):
+    if device is None:
         device = get_default_device()
     inp = Tensor(0, dtype=dtype, device=device)
     if len(shape) > 0:
@@ -122,14 +139,14 @@ def _dummy_input(shape, dtype, device=""):
 
 
 class _ReduceSum(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
 
     def forward(self, data):
         self.in_device = str(data.device)
         return collective_comm(
-            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
+            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
         )
 
     def backward(self, grad):
@@ -139,7 +156,7 @@ class _ReduceSum(Function):
 
 def reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create reduce_sum operator for collective communication.
@@ -158,14 +175,14 @@ def reduce_sum(
 
 class _Broadcast(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
 
     def forward(self, data):
         self.in_device = str(data.device)
         return collective_comm(
-            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
+            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
         )
 
     def backward(self, grad):
@@ -175,7 +192,7 @@ class _Broadcast(Function):
 
 def broadcast(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create broadcast operator for collective communication.
@@ -197,14 +214,14 @@ def broadcast(
 
 def _bcast_param(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
 ) -> Tensor:
     mode = CollectiveComm.Mode.BROADCAST
     return collective_comm(inp, mode, group, device)
 
 
 def all_gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_gather operator for collective communication.
@@ -218,7 +235,7 @@ def all_gather(
 
 def reduce_scatter_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create reduce_scatter_sum operator for collective communication.
@@ -232,7 +249,7 @@ def reduce_scatter_sum(
 
 def all_reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_sum operator for collective communication.
@@ -246,7 +263,7 @@ def all_reduce_sum(
 
 def all_reduce_max(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_max operator for collective communication.
@@ -260,7 +277,7 @@ def all_reduce_max(
 
 def all_reduce_min(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_min operator for collective communication.
@@ -274,7 +291,7 @@ def all_reduce_min(
 class _Gather(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
@@ -291,7 +308,7 @@ class _Gather(Function):
 
 def gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create gather operator for collective communication.
@@ -311,7 +328,7 @@ def gather(
 
 class _Scatter(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
@@ -328,7 +345,7 @@ class _Scatter(Function):
 
 def scatter(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create scatter operator for collective communication.
@@ -350,7 +367,7 @@ def scatter(
 
 def all_to_all(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_to_all operator for collective communication.
@@ -407,7 +424,7 @@ class _RemoteRecv(Function):
         remote_send(grad, self.op.rank_from)
 
 
-def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
+def remote_send(inp: Tensor, dest_rank: int):
     """
     Send a Tensor to a remote process.
@@ -423,13 +440,13 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     op.key = group.key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
-    op.backend = get_backend()
+    op.backend = _backend()
     (out,) = apply(_RemoteSend(op), inp)
 
     _save_output_for_autodiff(inp, out)
 
 
-def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
+def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
     """
     Receive a Tensor from a remote process.
@@ -459,7 +476,7 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
     op.rank_from = src_rank
-    op.backend = get_backend()
+    op.backend = _backend()
     (ret,) = apply(_RemoteRecv(op), inp)
     if _isscalar:
......
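The key change in this file is that the communicator backend is now resolved lazily from group._sd.backend each time a collective op is built, with "auto" mapped to a device-specific default. A standalone sketch of that resolution logic (resolve_backend is a hypothetical helper name; the mapping is the _device2backend table added above):

_device2backend = {"gpu": "nccl", "cuda": "nccl", "rocm": "rccl"}

def resolve_backend(configured_backend: str, physical_device: str) -> str:
    # mirrors _backend(): "auto" defers the choice to the device type,
    # anything else (e.g. "shm", "ucx") is passed through unchanged
    if configured_backend == "auto":
        return _device2backend[physical_device]
    return configured_backend

assert resolve_backend("auto", "cuda") == "nccl"
assert resolve_backend("auto", "rocm") == "rccl"
assert resolve_backend("shm", "cuda") == "shm"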
@@ -7,8 +7,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple
 
+from mprop import mproperty
+
 from ..device import set_default_device, what_is_xpu
 from ..random import seed
 from .server import Client, Server
@@ -26,6 +29,7 @@ class StaticData:
     backend = None
     next_stream = None
     device_type = None
+    machine_ranks = None
 
 
 _sd = None
@@ -55,6 +59,7 @@ class Group:
         self.proc_ranks = proc_ranks
         self.stream = _sd.next_stream
         _sd.next_stream += 1
+        self.is_single_machine_cache = None
 
     def check(self, proc_ranks):
         assert _sd is not None, "please call init_process_group first"
@@ -83,17 +88,23 @@ class Group:
         assert len(self.proc_ranks) > 0, "invalid group"
         return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)
 
+    @property
+    def is_single_machine(self):
+        if self.is_single_machine_cache is not None:
+            return self.is_single_machine_cache
+        assert _sd is not None, "please call init_process_group first"
+        for rank in self.proc_ranks:
+            if rank not in _sd.machine_ranks:
+                self.is_single_machine_cache = False
+                return False
+        self.is_single_machine_cache = True
+        return True
+
 
 WORLD = Group([])
 
-_device2backend = {
-    "gpu": "nccl",
-    "cuda": "nccl",
-    "rocm": "rccl",
-}
-
-_backends = {"nccl", "rccl", "ucx"}
+_devices = {"gpu", "cuda", "rocm"}
+_backends = {"nccl", "rccl", "ucx", "auto"}
 
 
 def init_process_group(
@@ -102,7 +113,7 @@ def init_process_group(
     world_size: int,
     rank: int,
     device: int,
-    backend: Optional[str] = None,
+    backend: Optional[str] = "auto",
     device_type: str = "xpu",
 ) -> None:
     """
@@ -113,10 +124,9 @@ def init_process_group(
     :param world_size: total number of processes participating in the job.
     :param rank: rank of the current process.
     :param device: the GPU device id to bind this process to.
-    :param backend: communicator backend, currently support 'nccl' and 'ucx'.
+    :param backend: communicator backend, currently support 'nccl' and 'shm'.
     """
     physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
-    backend = _device2backend[physical_device_type] if backend is None else backend
     if not isinstance(master_ip, str):
         raise TypeError("Expect type str but got {}".format(type(master_ip)))
     if not isinstance(port, int):
@@ -131,7 +141,7 @@ def init_process_group(
         raise ValueError(
             "backend should be one of {} but got {}".format(_backends, backend)
         )
-    if physical_device_type not in _device2backend:
+    if physical_device_type not in _devices:
         raise ValueError(
             "{} is not a valid distributed device type".format(device_type)
         )
@@ -161,6 +171,30 @@ def init_process_group(
     seed(int(time.time()) + rank)
 
 
+def _set_machine_ranks(ranks) -> None:
+    global _sd
+    assert _sd is not None
+
+    _sd.machine_ranks = ranks
+
+
+@contextmanager
+def override_backend(new_backend: str):
+    """
+    Override distributed backend
+
+    :param new_backend: communicator backend set in this context.
+    """
+    global _sd
+    assert _sd, "please call init_process_group first"
+    old_backend = _sd.backend
+    _sd.backend = new_backend
+    try:
+        yield
+    finally:
+        _sd.backend = old_backend
+
+
 def is_distributed() -> bool:
     """Return True if the distributed process group has been initialized."""
     return _sd is not None
......
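The new override_backend context manager swaps group._sd.backend for the duration of a block and restores the previous value afterwards, even if an exception is raised. A hedged usage sketch (assumes this runs inside a launched worker with an initialized process group and a MegEngine tensor x already on the bound device):

import megengine.distributed as dist
from megengine.distributed.functional import all_reduce_sum

with dist.override_backend("shm"):
    y = all_reduce_sum(x)   # this collective is built with the shm backend
# outside the block, the previously configured backend is back in effect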
@@ -22,8 +22,9 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
 from ..functional.tensor import copy
 from ..tensor import Tensor
 from ..utils.future import Future
+from . import group as _group
 from .functional import _bcast_param, all_reduce_sum, broadcast
-from .group import WORLD, Group, group_barrier, is_distributed
+from .group import WORLD, Group, group_barrier, is_distributed, override_backend
 
 
 def param_pack_split(inp: Tensor, offsets: list, shapes: list):
@@ -118,10 +119,30 @@ def get_offsets(shapes):
     return offsets
 
 
+_enable_p2p_cache = None
+
+
+def _check_enable_p2p():
+    global _enable_p2p_cache
+    if _enable_p2p_cache is not None:
+        return _enable_p2p_cache
+    cmd = ["nvidia-smi", "topo", "-p2p", "w"]
+    import subprocess
+
+    output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
+    if output.count(b"OK") > 1:
+        _enable_p2p_cache = True
+        return True
+    else:
+        _enable_p2p_cache = False
+        return False
+
+
 def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     offsets_val = get_offsets(shapes)
     offsets = Tensor(offsets_val)
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
 
     packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
     if reduce_method == "mean":
         packed_grads /= group.size
@@ -207,9 +228,10 @@ class AllreduceCallback:
     :param reduce_method: the method to reduce gradiants.
     :param group: communication group.
+    :param backend: override distributed backend in allreduce
     """
 
-    def __init__(self, reduce_method: str, group: Group = WORLD):
+    def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
         reduce_method = reduce_method.lower()
         assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
         self._reduce_method = reduce_method
@@ -217,6 +239,15 @@ class AllreduceCallback:
         self._marked_gm = WeakSet()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
+        if backend is None:
+            assert _group._sd, "please call init_process_group first"
+            backend = _group._sd.backend
+        if backend == "auto":
+            if group.is_single_machine and not _check_enable_p2p():
+                backend = "shm"
+            else:
+                backend = "nccl"
+        self._backend = backend
 
     def _reset(self):
         self._params = []
@@ -231,9 +262,10 @@ class AllreduceCallback:
             return
         grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
         shapes = [p._tuple_shape for p in self._packing_list[dtype]]
-        reduced_grads = pack_allreduce_split(
-            grad_list, shapes, self._group, self._reduce_method
-        )
+        with override_backend(self._backend):
+            reduced_grads = pack_allreduce_split(
+                grad_list, shapes, self._group, self._reduce_method
+            )
         for param, grad in zip(self._packing_list[dtype], reduced_grads):
             self._gradients_dict[param] = grad
         self._packing_list[dtype] = []
......
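AllreduceCallback now accepts a backend argument; with the default "auto" it picks "shm" only when the group stays on one machine and nvidia-smi reports no peer-to-peer links, otherwise "nccl". A hedged sketch of constructing the callback inside a launched worker (model and the rest of the training loop are assumed from the usual GradManager setup):

import megengine.distributed as dist
from megengine.autodiff import GradManager
from megengine.distributed.helper import AllreduceCallback

gm = GradManager()
cb = AllreduceCallback("mean", dist.WORLD, backend="shm")  # force the shm allreduce
gm.attach(model.parameters(), callbacks=[cb])              # model: assumed defined elsewhere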
@@ -14,7 +14,7 @@ import queue
 from .. import _exit
 from ..core._imperative_rt.core2 import full_sync
 from ..logger import get_logger
-from .group import group_barrier, init_process_group
+from .group import _set_machine_ranks, group_barrier, init_process_group
 from .helper import _check_device_initialized, get_device_count_by_fork
 from .server import Client, Server
@@ -34,7 +34,9 @@ def _run_wrapped(
     device_type,
     args,
     kwargs,
+    backend,
     queue: mp.Queue,
+    machine_ranks: list,
 ):
     """Init distributed process group and run wrapped function."""
     _check_device_initialized(device_type)
@@ -44,10 +46,12 @@ def _run_wrapped(
         world_size=world_size,
         rank=rank,
         device=dev,
+        backend=backend,
         device_type=device_type,
     )
     # set NCCL_LAUNCH_MODE to avoid deadlock
     os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
+    _set_machine_ranks(machine_ranks)
     if is_multimachine:
         group_barrier()
     ret = func(*args, **kwargs)
@@ -67,6 +71,7 @@ class launcher:
     :param rank_start: start number for rank.
     :param master_ip: ip address for master node (where the rank 0 is).
     :param port: server port for distributed server.
+    :param backend: set default collective communication backend.
     """
 
     def __new__(cls, *args, **kwargs):
@@ -83,6 +88,7 @@
         master_ip="localhost",
         port=0,
         device_type="xpu",
+        backend="auto",
     ):
         self.func = func
         self.n_gpus = (
@@ -93,6 +99,7 @@
         self.master_ip = master_ip
         self.port = port
         self.device_type = device_type
+        self.backend = backend
         # master node create server
         if self.rank_start == 0:
             self.server = Server(self.port)
@@ -104,6 +111,7 @@
         procs = []
         queue = mp.Queue(self.n_gpus)
         results = [None] * self.n_gpus
+        machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
         for dev in range(self.n_gpus):
             p = mp.Process(
                 target=_run_wrapped,
@@ -118,7 +126,9 @@
                     self.device_type,
                     args,
                     kwargs,
+                    self.backend,
                     queue,
+                    machine_ranks,
                 ),
             )
             p.start()
......
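Putting it together, the launcher now threads a backend argument through to every worker's init_process_group and records which ranks live on the local machine. A hedged end-to-end sketch (the n_gpus value is a placeholder; by default the launcher probes the local device count):

import megengine.distributed as dist

@dist.launcher(n_gpus=2, backend="auto")
def worker():
    # "auto" is resolved per collective: nccl/rccl normally, and shm for the
    # gradient allreduce on a single machine without P2P links
    print("rank", dist.get_rank(), "configured backend", dist.backend)

worker()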
@@ -11,6 +11,7 @@
 
 #include "megbrain/opr/megray_helper.h"
 #include "megbrain/comp_node_env.h"
+#include "megray/common.h"
 
 using namespace mgb;
 using namespace opr;
@@ -39,6 +40,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
         return MegRay::MEGRAY_RCCL;
     } else if (backend == "ucx") {
         return MegRay::MEGRAY_UCX;
+    } else if (backend == "shm") {
+        return MegRay::MEGRAY_SHM;
     } else {
         mgb_throw(MegBrainError, "back CollectiveComm backend");
     }
@@ -90,7 +93,7 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
     if (rank == root) {
         char* c = MegRay::get_host_ip();
         master_ip = std::string(c);
-        delete c;
+        delete [] c;
         port = MegRay::get_free_port();
         auto ret = MegRay::create_server(size, port);
         mgb_assert(ret == MegRay::Status::MEGRAY_OK);
......