Commit dee5a10a authored by Megvii Engine Team

feat(distributed): auto detect device and backend when init group

GitOrigin-RevId: 90be2d5b4d97f1379b70ffdcbac61269c1d44848
Parent 1bec737d
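With this change `init_process_group` no longer assumes a CUDA GPU: when `backend` is omitted it is derived from the detected physical device type (nccl for CUDA, rccl for ROCm), and the new `device_type` argument defaults to "xpu", which resolves via `what_is_xpu()`. A minimal sketch of the resulting call (host, port, rank and device values below are placeholders):

    import megengine.distributed as dist

    dist.init_process_group(
        master_ip="localhost",   # placeholder
        port=23456,              # placeholder
        world_size=2,
        rank=0,
        device=0,
        # backend omitted: auto-selected from the physical device type
        # device_type left at its "xpu" default
    )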
@@ -12,6 +12,7 @@ from typing import Optional
 from .core._imperative_rt.common import CompNode, DeviceType
 from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config
+from .core._imperative_rt.common import what_is_xpu as _what_is_xpu

 __all__ = [
     "is_cuda_available",
@@ -25,7 +26,7 @@ __all__ = [

 def _valid_device(inp):
-    if isinstance(inp, str) and re.match("^[cxg]pu(\d+|\d+:\d+|x)$", inp):
+    if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
         return True
     return False
@@ -40,21 +41,24 @@ def _str2device_type(type_str: str, allow_unspec: bool = True):
         return DeviceType.CAMBRICON
     elif type_str == "ATLAS":
         return DeviceType.ATLAS
+    elif type_str == "ROCM" or type_str == "AMDGPU":
+        return DeviceType.ROCM
     else:
         assert allow_unspec and str == "XPU", "device type can only be cpu, gpu or xpu"
         return DeviceType.UNSPEC

+_device_type_set = {"cpu", "gpu", "xpu", "rocm"}

 def get_device_count(device_type: str) -> int:
     """
     Gets number of devices installed on this system.

     :param device_type: device type, one of 'gpu' or 'cpu'
     """
-    device_type_set = ("cpu", "gpu")
-    assert device_type in device_type_set, "device must be one of {}".format(
-        device_type_set
+    assert device_type in _device_type_set, "device must be one of {}".format(
+        _device_type_set
     )
     device_type = _str2device_type(device_type)
     return CompNode._get_device_count(device_type, False)
@@ -87,6 +91,14 @@ def is_atlas_available() -> bool:
     return CompNode._get_device_count(t, False) > 0

+def is_rocm_available() -> bool:
+    """Returns whether rocm device is available on this system.
+    """
+    t = _str2device_type("rocm")
+    return CompNode._get_device_count(t, False) > 0

 def set_default_device(device: str = "xpux"):
     r"""
     Sets default computing node.
@@ -151,3 +163,7 @@ def set_prealloc_config(
     assert max_overhead >= 0
     assert growth_factor >= 1
     _set_prealloc_config(alignment, min_req, max_overhead, growth_factor, device_type)

+def what_is_xpu():
+    return _what_is_xpu().name.lower()
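The new helpers make the physical device type queryable from Python. A short usage sketch (printed values depend on the machine; "cuda" is only the expected result on an NVIDIA box):

    import megengine.device as device

    print(device.what_is_xpu())             # e.g. "cuda" on NVIDIA, "rocm" on AMD
    print(device.is_rocm_available())       # True only if a ROCm device is present
    print(device.get_device_count("rocm"))  # "rocm" is now an accepted device type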
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import List, Optional, Tuple

-from ..device import set_default_device
+from ..device import set_default_device, what_is_xpu
 from .server import Client, Server
@@ -23,6 +23,7 @@ class StaticData:
     device = None
     backend = None
     next_stream = None
+    device_type = None

 _sd = None
@@ -78,19 +79,29 @@ class Group:
     @property
     def comp_node(self):
         assert len(self.proc_ranks) > 0, "invalid group"
-        return "gpu{}:{}".format(_sd.device, self.stream)
+        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)

 WORLD = Group([])

+_device2backend = {
+    "gpu": "nccl",
+    "cuda": "nccl",
+    "rocm": "rccl",
+}
+
+_backends = {"nccl", "rccl", "ucx"}

 def init_process_group(
     master_ip: str,
     port: int,
     world_size: int,
     rank: int,
     device: int,
-    backend: Optional[str] = "nccl",
+    backend: Optional[str] = None,
+    device_type: str = "xpu",
 ) -> None:
     """
     Initialize the distributed process group and specify the device used in the current process
@@ -102,6 +113,8 @@ def init_process_group(
     :param device: the GPU device id to bind this process to.
     :param backend: communicator backend, currently support 'nccl' and 'ucx'.
     """
+    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
+    backend = _device2backend[physical_device_type] if backend is None else backend
     if not isinstance(master_ip, str):
         raise TypeError("Expect type str but got {}".format(type(master_ip)))
     if not isinstance(port, int):
@@ -112,8 +125,14 @@ def init_process_group(
         raise TypeError("Expect type int but got {}".format(type(rank)))
     if not isinstance(device, int):
         raise TypeError("Expect type int but got {}".format(type(backend)))
-    if not isinstance(backend, str):
-        raise TypeError("Expect type str but got {}".format(type(backend)))
+    if backend not in _backends:
+        raise ValueError(
+            "backend should be one of {} but got {}".format(_backends, backend)
+        )
+    if physical_device_type not in _device2backend:
+        raise ValueError(
+            "{} is not a valid distributed device type".format(device_type)
+        )

     global _sd
     assert _sd is None, "init_process_group should be called only once"
@@ -132,10 +151,11 @@ def init_process_group(
     _sd.device = device
     _sd.backend = backend
     _sd.next_stream = 1
+    _sd.device_type = device_type

     WORLD.reset(list(range(world_size)))

-    set_default_device("gpu{}".format(device))
+    set_default_device("{}{}".format(device_type, device))

 def is_distributed() -> bool:
@@ -182,7 +202,7 @@ def new_group(proc_ranks: List[int]) -> Group:
     return Group(proc_ranks)

-def group_barrier(group: Optional[Group] = WORLD) -> None:
+def group_barrier(group: Group = WORLD) -> None:
     """Block until all ranks in the group reach this barrier."""
     # if running with single node, skip it
     if _sd is None:
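The auto-selection and validation introduced above can be illustrated with a minimal, self-contained mirror of the same logic (pick_backend is a hypothetical name used only for this sketch, not part of the MegEngine API):

    _device2backend = {"gpu": "nccl", "cuda": "nccl", "rocm": "rccl"}
    _backends = {"nccl", "rccl", "ucx"}

    def pick_backend(physical_device_type, backend=None):
        # explicit backend wins; otherwise map from the physical device type
        if backend is None:
            backend = _device2backend[physical_device_type]
        if backend not in _backends:
            raise ValueError(
                "backend should be one of {} but got {}".format(_backends, backend)
            )
        return backend

    assert pick_backend("cuda") == "nccl"
    assert pick_backend("rocm") == "rccl"
    assert pick_backend("rocm", backend="ucx") == "ucx"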
@@ -29,13 +29,19 @@ def _run_wrapped(
     world_size,
     rank,
     dev,
+    device_type,
     args,
     kwargs,
     queue: mp.Queue,
 ):
     """Init distributed process group and run wrapped function."""
     init_process_group(
-        master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev
+        master_ip=master_ip,
+        port=port,
+        world_size=world_size,
+        rank=rank,
+        device=dev,
+        device_type=device_type,
     )
     if is_multimachine:
         group_barrier()
@@ -70,13 +76,17 @@ class launcher:
         rank_start=0,
         master_ip="localhost",
         port=0,
+        device_type="xpu",
     ):
         self.func = func
-        self.n_gpus = n_gpus if n_gpus is not None else get_device_count_by_fork("gpu")
+        self.n_gpus = (
+            n_gpus if n_gpus is not None else get_device_count_by_fork(device_type)
+        )
         self.world_size = world_size if world_size is not None else self.n_gpus
         self.rank_start = rank_start
         self.master_ip = master_ip
         self.port = port
+        self.device_type = device_type
         # master node create server
         if self.rank_start == 0:
             self.server = Server(self.port)
@@ -99,6 +109,7 @@ class launcher:
                 self.world_size,
                 dev + self.rank_start,
                 dev,
+                self.device_type,
                 args,
                 kwargs,
                 queue,
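Because `device_type` is forwarded to every spawned process, the launcher can drive non-CUDA devices as well. A hedged sketch of single-machine use (it assumes launcher instances are callable, as the process-spawning hunk suggests; "rocm" is only an example value):

    import megengine.distributed as dist
    from megengine.distributed.launcher import launcher

    def worker_fn():
        # runs in each child process after init_process_group(..., device_type="rocm")
        print("rank", dist.get_rank())

    launch = launcher(worker_fn, device_type="rocm")  # n_gpus defaults to the ROCm device count
    # launch()  # spawn one process per detected device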
@@ -62,8 +62,8 @@ void init_common(py::module m) {
                 return cn.get_mem_status_bytes();
             })
             .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
-            .def("_set_default_device", &set_default_device)
-            .def("_get_default_device", &get_default_device)
+            .def_static("_set_default_device", &set_default_device)
+            .def_static("_get_default_device", &get_default_device)
             .def("__str__", &CompNode::to_string_logical)
             .def("__repr__", [](const CompNode& cn) {
                 return py::str("\"" + cn.to_string() + "\" from \"" + cn.to_string_logical() + "\"");
@@ -179,6 +179,10 @@ void init_common(py::module m) {
     m.def("set_prealloc_config", &CompNode::set_prealloc_config,
           "specifies how to pre-allocate from raw dev allocator");

+    m.def("what_is_xpu", []{
+        return CompNode::Locator::parse("xpux").to_physical().type;
+    });

     init_npy_num_bfloat16(m);
     init_npy_num_intbx(m);
     init_dtypes(m);
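The `what_is_xpu` binding returns the physical `DeviceType` that the generic `xpux` locator maps to; the Python wrapper added in device.py just lowercases the enum name. A small sketch of the two layers (the CUDA value is an assumption for an NVIDIA machine):

    from megengine.core._imperative_rt.common import what_is_xpu as _what_is_xpu

    t = _what_is_xpu()       # a DeviceType enum value, e.g. DeviceType.CUDA
    print(t.name.lower())    # what megengine.device.what_is_xpu() returns, e.g. "cuda"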
@@ -16,6 +16,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit
+from megengine.device import get_default_device
 from megengine.distributed.helper import (
     get_device_count_by_fork,
     param_pack_concat,
@@ -87,7 +88,8 @@ def test_new_group():
         assert group.size == 2
         assert group.key == "2,0"
         assert group.rank == ranks.index(rank)
-        assert group.comp_node == "gpu{}:2".format(rank)
+        dt = get_default_device()[:-1]
+        assert group.comp_node == "{}{}:2".format(dt, rank)

     worker()
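The updated assertion derives the comp_node prefix from the default device instead of hard-coding "gpu". A worked example, assuming a single-digit device index (which the `[:-1]` slice relies on):

    dt = get_default_device()               # e.g. "xpu1" after init with device=1, device_type="xpu"
    prefix = dt[:-1]                        # "xpu"
    expected = "{}{}:2".format(prefix, 1)   # "xpu1:2", matching Group.comp_node on stream 2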
@@ -236,12 +236,12 @@ def test_io_remote(shape):
     def worker(val, shape):
         rank = dist.get_rank()
         if rank == 0:  # remote send
-            x = tensor(val, device="gpu0")
+            x = tensor(val, device="xpu0")
             remote_send(x, 1)
             sync()
         else:  # remote recv
             y = remote_recv(0, shape, np.float32)
-            assert y.device == "gpu1"
+            assert y.device == get_default_device()
             np.testing.assert_almost_equal(val, y.numpy())

     val = np.random.random_sample(shape).astype("float32")