提交 3aef5224 编写于 作者: M Megvii Engine Team

refactor(distributed): remove the shm backend for distributed training

GitOrigin-RevId: ab76f23f9dc6a4452fcde58fac6078f4c24af352
上级 21849d79
...@@ -26,7 +26,7 @@ from .server import Client, Server ...@@ -26,7 +26,7 @@ from .server import Client, Server
@mproperty @mproperty
def backend(mod): def backend(mod):
r"""Get or set backend of collective communication. r"""Get or set backend of collective communication.
Available backends are ['nccl', 'shm', 'rccl'] Available backends are ['nccl', 'rccl']
Examples: Examples:
......
...@@ -95,7 +95,7 @@ class Group: ...@@ -95,7 +95,7 @@ class Group:
WORLD = Group([]) WORLD = Group([])
_devices = {"gpu", "cuda", "rocm"} _devices = {"gpu", "cuda", "rocm"}
_backends = {"nccl", "rccl", "shm", "auto"} _backends = {"nccl", "rccl", "auto"}
def init_process_group( def init_process_group(
...@@ -115,7 +115,7 @@ def init_process_group( ...@@ -115,7 +115,7 @@ def init_process_group(
world_size: total number of processes participating in the job. world_size: total number of processes participating in the job.
rank: rank of the current process. rank: rank of the current process.
device: the GPU device id to bind this process to. device: the GPU device id to bind this process to.
backend: communicator backend, currently support 'nccl' and 'shm'. backend: communicator backend, currently support 'nccl' and 'rccl'.
""" """
physical_device_type = what_is_xpu() if device_type == "xpu" else device_type physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
if not isinstance(master_ip, str): if not isinstance(master_ip, str):
......
...@@ -205,10 +205,7 @@ class AllreduceCallback: ...@@ -205,10 +205,7 @@ class AllreduceCallback:
assert _group._sd, "please call init_process_group first" assert _group._sd, "please call init_process_group first"
backend = _group._sd.backend backend = _group._sd.backend
if backend == "auto": if backend == "auto":
if group.is_single_machine and not _check_enable_p2p(): backend = "nccl"
backend = "shm"
else:
backend = "nccl"
self._backend = backend self._backend = backend
def _reset(self): def _reset(self):
......
...@@ -31,10 +31,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) { ...@@ -31,10 +31,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
return MegRay::MEGRAY_RCCL; return MegRay::MEGRAY_RCCL;
} else if (backend == "ucx") { } else if (backend == "ucx") {
return MegRay::MEGRAY_UCX; return MegRay::MEGRAY_UCX;
} else if (backend == "shm") {
return MegRay::MEGRAY_SHM;
} else { } else {
mgb_throw(MegBrainError, "back CollectiveComm backend"); mgb_throw(MegBrainError, "bad CollectiveComm backend");
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册