Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
809d5056
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
809d5056
编写于
3月 10, 2021
作者:
M
Megvii Engine Team
提交者:
huangxinda
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/distributed): enable pt shm allreduce
GitOrigin-RevId: 1dd5a02a512b210f2c75afd0062e4bfad1fcdddc
上级
02455941
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
159 addition
and
46 deletion
+159
-46
CMakeLists.txt
CMakeLists.txt
+1
-0
imperative/python/megengine/distributed/__init__.py
imperative/python/megengine/distributed/__init__.py
+16
-0
imperative/python/megengine/distributed/functional.py
imperative/python/megengine/distributed/functional.py
+44
-27
imperative/python/megengine/distributed/group.py
imperative/python/megengine/distributed/group.py
+46
-12
imperative/python/megengine/distributed/helper.py
imperative/python/megengine/distributed/helper.py
+37
-5
imperative/python/megengine/distributed/launcher.py
imperative/python/megengine/distributed/launcher.py
+11
-1
src/opr-mm/impl/megray_helper.cpp
src/opr-mm/impl/megray_helper.cpp
+4
-1
未找到文件。
CMakeLists.txt
浏览文件 @
809d5056
...
@@ -1018,6 +1018,7 @@ endif()
...
@@ -1018,6 +1018,7 @@ endif()
if
(
MGE_WITH_DISTRIBUTED
)
if
(
MGE_WITH_DISTRIBUTED
)
set
(
MEGRAY_WITH_NCCL
${
MGE_WITH_CUDA
}
CACHE BOOL
"Override MegRay option"
FORCE
)
set
(
MEGRAY_WITH_NCCL
${
MGE_WITH_CUDA
}
CACHE BOOL
"Override MegRay option"
FORCE
)
set
(
MEGRAY_WITH_SHM
${
MGE_WITH_CUDA
}
CACHE BOOL
"Override MegRay option"
FORCE
)
set
(
MEGRAY_WITH_RCCL
${
MGE_WITH_ROCM
}
CACHE BOOL
"Override MegRay option"
FORCE
)
set
(
MEGRAY_WITH_RCCL
${
MGE_WITH_ROCM
}
CACHE BOOL
"Override MegRay option"
FORCE
)
add_subdirectory
(
${
PROJECT_SOURCE_DIR
}
/third_party/MegRay
)
add_subdirectory
(
${
PROJECT_SOURCE_DIR
}
/third_party/MegRay
)
endif
()
endif
()
...
...
imperative/python/megengine/distributed/__init__.py
浏览文件 @
809d5056
...
@@ -6,6 +6,9 @@
...
@@ -6,6 +6,9 @@
# Unless required by applicable law or agreed to in writing,
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
mprop
import
mproperty
from
.
import
group
from
.group
import
(
from
.group
import
(
WORLD
,
WORLD
,
Group
,
Group
,
...
@@ -19,7 +22,20 @@ from .group import (
...
@@ -19,7 +22,20 @@ from .group import (
init_process_group
,
init_process_group
,
is_distributed
,
is_distributed
,
new_group
,
new_group
,
override_backend
,
)
)
from
.helper
import
bcast_list_
,
make_allreduce_cb
,
synchronized
from
.helper
import
bcast_list_
,
make_allreduce_cb
,
synchronized
from
.launcher
import
launcher
from
.launcher
import
launcher
from
.server
import
Client
,
Server
from
.server
import
Client
,
Server
@
mproperty
def
backend
(
mod
):
assert
group
.
_sd
,
"please call init_process_group first"
return
group
.
_sd
.
backend
@
backend
.
setter
def
backend
(
mod
,
val
):
assert
group
.
_sd
,
"please call init_process_group first"
group
.
_sd
.
backend
=
val
imperative/python/megengine/distributed/functional.py
浏览文件 @
809d5056
...
@@ -14,9 +14,10 @@ from ..core._imperative_rt.core2 import apply
...
@@ -14,9 +14,10 @@ from ..core._imperative_rt.core2 import apply
from
..core.autodiff.grad
import
Function
,
_grad_manager_dict
from
..core.autodiff.grad
import
Function
,
_grad_manager_dict
from
..core.ops.builtin
import
CollectiveComm
,
Copy
,
RemoteRecv
,
RemoteSend
from
..core.ops.builtin
import
CollectiveComm
,
Copy
,
RemoteRecv
,
RemoteSend
from
..core.tensor.utils
import
isscalar
,
setscalar
from
..core.tensor.utils
import
isscalar
,
setscalar
from
..device
import
get_default_device
from
..device
import
get_default_device
,
what_is_xpu
from
..tensor
import
Tensor
from
..tensor
import
Tensor
from
.group
import
WORLD
,
Group
,
get_backend
,
get_client
,
get_mm_server_addr
,
get_rank
from
.
import
group
from
.group
import
WORLD
,
Group
,
get_client
,
get_mm_server_addr
,
get_rank
__all__
=
[
__all__
=
[
"reduce_sum"
,
"reduce_sum"
,
...
@@ -34,14 +35,30 @@ __all__ = [
...
@@ -34,14 +35,30 @@ __all__ = [
]
]
_device2backend
=
{
"gpu"
:
"nccl"
,
"cuda"
:
"nccl"
,
"rocm"
:
"rccl"
,
}
def
_backend
():
if
group
.
_sd
.
backend
==
"auto"
:
return
_device2backend
[
what_is_xpu
()]
else
:
return
group
.
_sd
.
backend
def
collective_comm
(
inp
,
mode
,
group
,
device
):
def
collective_comm
(
inp
,
mode
,
group
,
device
):
"""Helper function for applying collective communication functions."""
"""Helper function for applying collective communication functions."""
assert
isinstance
(
group
,
Group
)
assert
isinstance
(
group
,
Group
)
if
group
is
None
:
if
group
is
None
:
return
inp
return
inp
if
device
is
None
:
device
=
""
addr
,
port
=
get_mm_server_addr
()
addr
,
port
=
get_mm_server_addr
()
op
=
CollectiveComm
(
op
=
CollectiveComm
(
key
=
group
.
key
,
key
=
group
.
key
+
_backend
()
,
nr_devices
=
group
.
size
,
nr_devices
=
group
.
size
,
rank
=
group
.
rank
,
rank
=
group
.
rank
,
is_root
=
(
group
.
rank
==
0
),
is_root
=
(
group
.
rank
==
0
),
...
@@ -50,7 +67,7 @@ def collective_comm(inp, mode, group, device):
...
@@ -50,7 +67,7 @@ def collective_comm(inp, mode, group, device):
port
=
port
,
port
=
port
,
mode
=
mode
,
mode
=
mode
,
dtype
=
inp
.
dtype
,
dtype
=
inp
.
dtype
,
backend
=
get
_backend
(),
backend
=
_backend
(),
comp_node
=
device
,
comp_node
=
device
,
)
)
(
result
,)
=
apply
(
op
,
inp
)
(
result
,)
=
apply
(
op
,
inp
)
...
@@ -112,8 +129,8 @@ def _bcast_tracer_state(group, inp):
...
@@ -112,8 +129,8 @@ def _bcast_tracer_state(group, inp):
g
.
_refkeeper
.
append
(
inp
)
g
.
_refkeeper
.
append
(
inp
)
def
_dummy_input
(
shape
,
dtype
,
device
=
""
):
def
_dummy_input
(
shape
,
dtype
,
device
=
None
):
if
device
==
""
:
if
device
is
None
:
device
=
get_default_device
()
device
=
get_default_device
()
inp
=
Tensor
(
0
,
dtype
=
dtype
,
device
=
device
)
inp
=
Tensor
(
0
,
dtype
=
dtype
,
device
=
device
)
if
len
(
shape
)
>
0
:
if
len
(
shape
)
>
0
:
...
@@ -122,14 +139,14 @@ def _dummy_input(shape, dtype, device=""):
...
@@ -122,14 +139,14 @@ def _dummy_input(shape, dtype, device=""):
class
_ReduceSum
(
Function
):
class
_ReduceSum
(
Function
):
def
__init__
(
self
,
group
=
WORLD
,
device
=
""
):
def
__init__
(
self
,
group
=
WORLD
,
device
=
None
):
self
.
group
=
group
self
.
group
=
group
self
.
out_device
=
device
self
.
out_device
=
device
def
forward
(
self
,
data
):
def
forward
(
self
,
data
):
self
.
in_device
=
str
(
data
.
device
)
self
.
in_device
=
str
(
data
.
device
)
return
collective_comm
(
return
collective_comm
(
data
,
CollectiveComm
.
Mode
.
REDUCE_SUM
,
self
.
group
,
self
.
out_device
data
,
CollectiveComm
.
Mode
.
REDUCE_SUM
,
self
.
group
,
self
.
out_device
,
)
)
def
backward
(
self
,
grad
):
def
backward
(
self
,
grad
):
...
@@ -139,7 +156,7 @@ class _ReduceSum(Function):
...
@@ -139,7 +156,7 @@ class _ReduceSum(Function):
def
reduce_sum
(
def
reduce_sum
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create reduce_sum operator for collective communication.
Create reduce_sum operator for collective communication.
...
@@ -158,14 +175,14 @@ def reduce_sum(
...
@@ -158,14 +175,14 @@ def reduce_sum(
class
_Broadcast
(
Function
):
class
_Broadcast
(
Function
):
def
__init__
(
self
,
group
=
WORLD
,
device
=
""
):
def
__init__
(
self
,
group
=
WORLD
,
device
=
None
):
self
.
group
=
group
self
.
group
=
group
self
.
out_device
=
device
self
.
out_device
=
device
def
forward
(
self
,
data
):
def
forward
(
self
,
data
):
self
.
in_device
=
str
(
data
.
device
)
self
.
in_device
=
str
(
data
.
device
)
return
collective_comm
(
return
collective_comm
(
data
,
CollectiveComm
.
Mode
.
BROADCAST
,
self
.
group
,
self
.
out_device
data
,
CollectiveComm
.
Mode
.
BROADCAST
,
self
.
group
,
self
.
out_device
,
)
)
def
backward
(
self
,
grad
):
def
backward
(
self
,
grad
):
...
@@ -175,7 +192,7 @@ class _Broadcast(Function):
...
@@ -175,7 +192,7 @@ class _Broadcast(Function):
def
broadcast
(
def
broadcast
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create broadcast operator for collective communication.
Create broadcast operator for collective communication.
...
@@ -197,14 +214,14 @@ def broadcast(
...
@@ -197,14 +214,14 @@ def broadcast(
def
_bcast_param
(
def
_bcast_param
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
)
->
Tensor
:
)
->
Tensor
:
mode
=
CollectiveComm
.
Mode
.
BROADCAST
mode
=
CollectiveComm
.
Mode
.
BROADCAST
return
collective_comm
(
inp
,
mode
,
group
,
device
)
return
collective_comm
(
inp
,
mode
,
group
,
device
)
def
all_gather
(
def
all_gather
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create all_gather operator for collective communication.
Create all_gather operator for collective communication.
...
@@ -218,7 +235,7 @@ def all_gather(
...
@@ -218,7 +235,7 @@ def all_gather(
def
reduce_scatter_sum
(
def
reduce_scatter_sum
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create reduce_scatter_sum operator for collective communication.
Create reduce_scatter_sum operator for collective communication.
...
@@ -232,7 +249,7 @@ def reduce_scatter_sum(
...
@@ -232,7 +249,7 @@ def reduce_scatter_sum(
def
all_reduce_sum
(
def
all_reduce_sum
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create all_reduce_sum operator for collective communication.
Create all_reduce_sum operator for collective communication.
...
@@ -246,7 +263,7 @@ def all_reduce_sum(
...
@@ -246,7 +263,7 @@ def all_reduce_sum(
def
all_reduce_max
(
def
all_reduce_max
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create all_reduce_max operator for collective communication.
Create all_reduce_max operator for collective communication.
...
@@ -260,7 +277,7 @@ def all_reduce_max(
...
@@ -260,7 +277,7 @@ def all_reduce_max(
def
all_reduce_min
(
def
all_reduce_min
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create all_reduce_min operator for collective communication.
Create all_reduce_min operator for collective communication.
...
@@ -274,7 +291,7 @@ def all_reduce_min(
...
@@ -274,7 +291,7 @@ def all_reduce_min(
class
_Gather
(
Function
):
class
_Gather
(
Function
):
def
__init__
(
self
,
group
=
WORLD
,
device
=
""
):
def
__init__
(
self
,
group
=
WORLD
,
device
=
None
):
self
.
group
=
group
self
.
group
=
group
self
.
out_device
=
device
self
.
out_device
=
device
...
@@ -291,7 +308,7 @@ class _Gather(Function):
...
@@ -291,7 +308,7 @@ class _Gather(Function):
def
gather
(
def
gather
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create gather operator for collective communication.
Create gather operator for collective communication.
...
@@ -311,7 +328,7 @@ def gather(
...
@@ -311,7 +328,7 @@ def gather(
class
_Scatter
(
Function
):
class
_Scatter
(
Function
):
def
__init__
(
self
,
group
=
WORLD
,
device
=
""
):
def
__init__
(
self
,
group
=
WORLD
,
device
=
None
):
self
.
group
=
group
self
.
group
=
group
self
.
out_device
=
device
self
.
out_device
=
device
...
@@ -328,7 +345,7 @@ class _Scatter(Function):
...
@@ -328,7 +345,7 @@ class _Scatter(Function):
def
scatter
(
def
scatter
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create scatter operator for collective communication.
Create scatter operator for collective communication.
...
@@ -350,7 +367,7 @@ def scatter(
...
@@ -350,7 +367,7 @@ def scatter(
def
all_to_all
(
def
all_to_all
(
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
""
inp
:
Tensor
,
group
:
Optional
[
Group
]
=
WORLD
,
device
:
Optional
[
str
]
=
None
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Create all_to_all operator for collective communication.
Create all_to_all operator for collective communication.
...
@@ -407,7 +424,7 @@ class _RemoteRecv(Function):
...
@@ -407,7 +424,7 @@ class _RemoteRecv(Function):
remote_send
(
grad
,
self
.
op
.
rank_from
)
remote_send
(
grad
,
self
.
op
.
rank_from
)
def
remote_send
(
inp
:
Tensor
,
dest_rank
:
int
)
->
Tensor
:
def
remote_send
(
inp
:
Tensor
,
dest_rank
:
int
):
"""
"""
Send a Tensor to a remote process.
Send a Tensor to a remote process.
...
@@ -423,13 +440,13 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
...
@@ -423,13 +440,13 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
op
.
key
=
group
.
key
op
.
key
=
group
.
key
op
.
addr
,
op
.
port
=
get_mm_server_addr
()
op
.
addr
,
op
.
port
=
get_mm_server_addr
()
op
.
rank_to
=
dest_rank
op
.
rank_to
=
dest_rank
op
.
backend
=
get
_backend
()
op
.
backend
=
_backend
()
(
out
,)
=
apply
(
_RemoteSend
(
op
),
inp
)
(
out
,)
=
apply
(
_RemoteSend
(
op
),
inp
)
_save_output_for_autodiff
(
inp
,
out
)
_save_output_for_autodiff
(
inp
,
out
)
def
remote_recv
(
src_rank
:
int
,
device
:
Optional
[
str
]
=
None
,
inp
=
None
,
)
->
Tensor
:
def
remote_recv
(
src_rank
:
int
,
device
:
Optional
[
str
]
=
None
,
inp
=
None
)
->
Tensor
:
"""
"""
Receive a Tensor from a remote process.
Receive a Tensor from a remote process.
...
@@ -459,7 +476,7 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tenso
...
@@ -459,7 +476,7 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tenso
op
.
dtype
=
dtype
op
.
dtype
=
dtype
op
.
addr
,
op
.
port
=
get_mm_server_addr
()
op
.
addr
,
op
.
port
=
get_mm_server_addr
()
op
.
rank_from
=
src_rank
op
.
rank_from
=
src_rank
op
.
backend
=
get
_backend
()
op
.
backend
=
_backend
()
(
ret
,)
=
apply
(
_RemoteRecv
(
op
),
inp
)
(
ret
,)
=
apply
(
_RemoteRecv
(
op
),
inp
)
if
_isscalar
:
if
_isscalar
:
...
...
imperative/python/megengine/distributed/group.py
浏览文件 @
809d5056
...
@@ -7,8 +7,11 @@
...
@@ -7,8 +7,11 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
time
import
time
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
from
mprop
import
mproperty
from
..device
import
set_default_device
,
what_is_xpu
from
..device
import
set_default_device
,
what_is_xpu
from
..random
import
seed
from
..random
import
seed
from
.server
import
Client
,
Server
from
.server
import
Client
,
Server
...
@@ -26,6 +29,7 @@ class StaticData:
...
@@ -26,6 +29,7 @@ class StaticData:
backend
=
None
backend
=
None
next_stream
=
None
next_stream
=
None
device_type
=
None
device_type
=
None
machine_ranks
=
None
_sd
=
None
_sd
=
None
...
@@ -55,6 +59,7 @@ class Group:
...
@@ -55,6 +59,7 @@ class Group:
self
.
proc_ranks
=
proc_ranks
self
.
proc_ranks
=
proc_ranks
self
.
stream
=
_sd
.
next_stream
self
.
stream
=
_sd
.
next_stream
_sd
.
next_stream
+=
1
_sd
.
next_stream
+=
1
self
.
is_single_machine_cache
=
None
def
check
(
self
,
proc_ranks
):
def
check
(
self
,
proc_ranks
):
assert
_sd
is
not
None
,
"please call init_process_group first"
assert
_sd
is
not
None
,
"please call init_process_group first"
...
@@ -83,17 +88,23 @@ class Group:
...
@@ -83,17 +88,23 @@ class Group:
assert
len
(
self
.
proc_ranks
)
>
0
,
"invalid group"
assert
len
(
self
.
proc_ranks
)
>
0
,
"invalid group"
return
"{}{}:{}"
.
format
(
_sd
.
device_type
,
_sd
.
device
,
self
.
stream
)
return
"{}{}:{}"
.
format
(
_sd
.
device_type
,
_sd
.
device
,
self
.
stream
)
@
property
WORLD
=
Group
([])
def
is_single_machine
(
self
):
if
self
.
is_single_machine_cache
is
not
None
:
return
self
.
is_single_machine_cache
assert
_sd
is
not
None
,
"please call init_process_group first"
for
rank
in
self
.
proc_ranks
:
if
rank
not
in
_sd
.
machine_ranks
:
self
.
is_single_machine_cache
=
False
return
False
self
.
is_single_machine_cache
=
True
return
True
_device2backend
=
{
WORLD
=
Group
([])
"gpu"
:
"nccl"
,
"cuda"
:
"nccl"
,
"rocm"
:
"rccl"
,
}
_backends
=
{
"nccl"
,
"rccl"
,
"ucx"
}
_devices
=
{
"gpu"
,
"cuda"
,
"rocm"
}
_backends
=
{
"nccl"
,
"rccl"
,
"ucx"
,
"auto"
}
def
init_process_group
(
def
init_process_group
(
...
@@ -102,7 +113,7 @@ def init_process_group(
...
@@ -102,7 +113,7 @@ def init_process_group(
world_size
:
int
,
world_size
:
int
,
rank
:
int
,
rank
:
int
,
device
:
int
,
device
:
int
,
backend
:
Optional
[
str
]
=
None
,
backend
:
Optional
[
str
]
=
"auto"
,
device_type
:
str
=
"xpu"
,
device_type
:
str
=
"xpu"
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -113,10 +124,9 @@ def init_process_group(
...
@@ -113,10 +124,9 @@ def init_process_group(
:param world_size: total number of processes participating in the job.
:param world_size: total number of processes participating in the job.
:param rank: rank of the current process.
:param rank: rank of the current process.
:param device: the GPU device id to bind this process to.
:param device: the GPU device id to bind this process to.
:param backend: communicator backend, currently support 'nccl' and '
ucx
'.
:param backend: communicator backend, currently support 'nccl' and '
shm
'.
"""
"""
physical_device_type
=
what_is_xpu
()
if
device_type
==
"xpu"
else
device_type
physical_device_type
=
what_is_xpu
()
if
device_type
==
"xpu"
else
device_type
backend
=
_device2backend
[
physical_device_type
]
if
backend
is
None
else
backend
if
not
isinstance
(
master_ip
,
str
):
if
not
isinstance
(
master_ip
,
str
):
raise
TypeError
(
"Expect type str but got {}"
.
format
(
type
(
master_ip
)))
raise
TypeError
(
"Expect type str but got {}"
.
format
(
type
(
master_ip
)))
if
not
isinstance
(
port
,
int
):
if
not
isinstance
(
port
,
int
):
...
@@ -131,7 +141,7 @@ def init_process_group(
...
@@ -131,7 +141,7 @@ def init_process_group(
raise
ValueError
(
raise
ValueError
(
"backend should be one of {} but got {}"
.
format
(
_backends
,
backend
)
"backend should be one of {} but got {}"
.
format
(
_backends
,
backend
)
)
)
if
physical_device_type
not
in
_device
2backend
:
if
physical_device_type
not
in
_device
s
:
raise
ValueError
(
raise
ValueError
(
"{} is not a valid distributed device type"
.
format
(
device_type
)
"{} is not a valid distributed device type"
.
format
(
device_type
)
)
)
...
@@ -161,6 +171,30 @@ def init_process_group(
...
@@ -161,6 +171,30 @@ def init_process_group(
seed
(
int
(
time
.
time
())
+
rank
)
seed
(
int
(
time
.
time
())
+
rank
)
def
_set_machine_ranks
(
ranks
)
->
None
:
global
_sd
assert
_sd
is
not
None
_sd
.
machine_ranks
=
ranks
@
contextmanager
def
override_backend
(
new_backend
:
str
):
"""
Override distributed backend
:param new_backend: communicator backend set in this context.
"""
global
_sd
assert
_sd
,
"please call init_process_group first"
old_backend
=
_sd
.
backend
_sd
.
backend
=
new_backend
try
:
yield
finally
:
_sd
.
backend
=
old_backend
def
is_distributed
()
->
bool
:
def
is_distributed
()
->
bool
:
"""Return True if the distributed process group has been initialized."""
"""Return True if the distributed process group has been initialized."""
return
_sd
is
not
None
return
_sd
is
not
None
...
...
imperative/python/megengine/distributed/helper.py
浏览文件 @
809d5056
...
@@ -22,8 +22,9 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
...
@@ -22,8 +22,9 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from
..functional.tensor
import
copy
from
..functional.tensor
import
copy
from
..tensor
import
Tensor
from
..tensor
import
Tensor
from
..utils.future
import
Future
from
..utils.future
import
Future
from
.
import
group
as
_group
from
.functional
import
_bcast_param
,
all_reduce_sum
,
broadcast
from
.functional
import
_bcast_param
,
all_reduce_sum
,
broadcast
from
.group
import
WORLD
,
Group
,
group_barrier
,
is_distributed
from
.group
import
WORLD
,
Group
,
group_barrier
,
is_distributed
,
override_backend
def
param_pack_split
(
inp
:
Tensor
,
offsets
:
list
,
shapes
:
list
):
def
param_pack_split
(
inp
:
Tensor
,
offsets
:
list
,
shapes
:
list
):
...
@@ -118,10 +119,30 @@ def get_offsets(shapes):
...
@@ -118,10 +119,30 @@ def get_offsets(shapes):
return
offsets
return
offsets
_enable_p2p_cache
=
None
def
_check_enable_p2p
():
global
_enable_p2p_cache
if
_enable_p2p_cache
is
not
None
:
return
_enable_p2p_cache
cmd
=
[
"nvidia-smi"
,
"topo"
,
"-p2p"
,
"w"
]
import
subprocess
output
=
subprocess
.
run
(
cmd
,
stdout
=
subprocess
.
PIPE
).
stdout
if
output
.
count
(
b
"OK"
)
>
1
:
_enable_p2p_cache
=
True
return
True
else
:
_enable_p2p_cache
=
False
return
False
def
pack_allreduce_split
(
pack_list
,
shapes
,
group
,
reduce_method
):
def
pack_allreduce_split
(
pack_list
,
shapes
,
group
,
reduce_method
):
offsets_val
=
get_offsets
(
shapes
)
offsets_val
=
get_offsets
(
shapes
)
offsets
=
Tensor
(
offsets_val
)
offsets
=
Tensor
(
offsets_val
)
packed_grads
=
param_pack_concat
(
pack_list
,
offsets
,
offsets_val
)
packed_grads
=
param_pack_concat
(
pack_list
,
offsets
,
offsets_val
)
packed_grads
=
all_reduce_sum
(
packed_grads
,
group
,
group
.
comp_node
)
packed_grads
=
all_reduce_sum
(
packed_grads
,
group
,
group
.
comp_node
)
if
reduce_method
==
"mean"
:
if
reduce_method
==
"mean"
:
packed_grads
/=
group
.
size
packed_grads
/=
group
.
size
...
@@ -207,9 +228,10 @@ class AllreduceCallback:
...
@@ -207,9 +228,10 @@ class AllreduceCallback:
:param reduce_method: the method to reduce gradiants.
:param reduce_method: the method to reduce gradiants.
:param group: communication group.
:param group: communication group.
:param backend: override distributed backend in allreduce
"""
"""
def
__init__
(
self
,
reduce_method
:
str
,
group
:
Group
=
WORLD
):
def
__init__
(
self
,
reduce_method
:
str
,
group
:
Group
=
WORLD
,
backend
:
str
=
None
):
reduce_method
=
reduce_method
.
lower
()
reduce_method
=
reduce_method
.
lower
()
assert
reduce_method
in
[
"sum"
,
"mean"
],
"reduce_method should be sum or mean"
assert
reduce_method
in
[
"sum"
,
"mean"
],
"reduce_method should be sum or mean"
self
.
_reduce_method
=
reduce_method
self
.
_reduce_method
=
reduce_method
...
@@ -217,6 +239,15 @@ class AllreduceCallback:
...
@@ -217,6 +239,15 @@ class AllreduceCallback:
self
.
_marked_gm
=
WeakSet
()
self
.
_marked_gm
=
WeakSet
()
self
.
_param_pack_thd
=
10
*
1024
*
1024
self
.
_param_pack_thd
=
10
*
1024
*
1024
self
.
_reset
()
self
.
_reset
()
if
backend
is
None
:
assert
_group
.
_sd
,
"please call init_process_group first"
backend
=
_group
.
_sd
.
backend
if
backend
==
"auto"
:
if
group
.
is_single_machine
and
not
_check_enable_p2p
():
backend
=
"shm"
else
:
backend
=
"nccl"
self
.
_backend
=
backend
def
_reset
(
self
):
def
_reset
(
self
):
self
.
_params
=
[]
self
.
_params
=
[]
...
@@ -231,9 +262,10 @@ class AllreduceCallback:
...
@@ -231,9 +262,10 @@ class AllreduceCallback:
return
return
grad_list
=
[
self
.
_gradients_dict
[
p
]
for
p
in
self
.
_packing_list
[
dtype
]]
grad_list
=
[
self
.
_gradients_dict
[
p
]
for
p
in
self
.
_packing_list
[
dtype
]]
shapes
=
[
p
.
_tuple_shape
for
p
in
self
.
_packing_list
[
dtype
]]
shapes
=
[
p
.
_tuple_shape
for
p
in
self
.
_packing_list
[
dtype
]]
reduced_grads
=
pack_allreduce_split
(
with
override_backend
(
self
.
_backend
):
grad_list
,
shapes
,
self
.
_group
,
self
.
_reduce_method
reduced_grads
=
pack_allreduce_split
(
)
grad_list
,
shapes
,
self
.
_group
,
self
.
_reduce_method
)
for
param
,
grad
in
zip
(
self
.
_packing_list
[
dtype
],
reduced_grads
):
for
param
,
grad
in
zip
(
self
.
_packing_list
[
dtype
],
reduced_grads
):
self
.
_gradients_dict
[
param
]
=
grad
self
.
_gradients_dict
[
param
]
=
grad
self
.
_packing_list
[
dtype
]
=
[]
self
.
_packing_list
[
dtype
]
=
[]
...
...
imperative/python/megengine/distributed/launcher.py
浏览文件 @
809d5056
...
@@ -14,7 +14,7 @@ import queue
...
@@ -14,7 +14,7 @@ import queue
from
..
import
_exit
from
..
import
_exit
from
..core._imperative_rt.core2
import
full_sync
from
..core._imperative_rt.core2
import
full_sync
from
..logger
import
get_logger
from
..logger
import
get_logger
from
.group
import
group_barrier
,
init_process_group
from
.group
import
_set_machine_ranks
,
group_barrier
,
init_process_group
from
.helper
import
_check_device_initialized
,
get_device_count_by_fork
from
.helper
import
_check_device_initialized
,
get_device_count_by_fork
from
.server
import
Client
,
Server
from
.server
import
Client
,
Server
...
@@ -34,7 +34,9 @@ def _run_wrapped(
...
@@ -34,7 +34,9 @@ def _run_wrapped(
device_type
,
device_type
,
args
,
args
,
kwargs
,
kwargs
,
backend
,
queue
:
mp
.
Queue
,
queue
:
mp
.
Queue
,
machine_ranks
:
list
,
):
):
"""Init distributed process group and run wrapped function."""
"""Init distributed process group and run wrapped function."""
_check_device_initialized
(
device_type
)
_check_device_initialized
(
device_type
)
...
@@ -44,10 +46,12 @@ def _run_wrapped(
...
@@ -44,10 +46,12 @@ def _run_wrapped(
world_size
=
world_size
,
world_size
=
world_size
,
rank
=
rank
,
rank
=
rank
,
device
=
dev
,
device
=
dev
,
backend
=
backend
,
device_type
=
device_type
,
device_type
=
device_type
,
)
)
# set NCCL_LAUNCH_MODE to avoid deadlock
# set NCCL_LAUNCH_MODE to avoid deadlock
os
.
environ
[
"NCCL_LAUNCH_MODE"
]
=
"PARALLEL"
os
.
environ
[
"NCCL_LAUNCH_MODE"
]
=
"PARALLEL"
_set_machine_ranks
(
machine_ranks
)
if
is_multimachine
:
if
is_multimachine
:
group_barrier
()
group_barrier
()
ret
=
func
(
*
args
,
**
kwargs
)
ret
=
func
(
*
args
,
**
kwargs
)
...
@@ -67,6 +71,7 @@ class launcher:
...
@@ -67,6 +71,7 @@ class launcher:
:param rank_start: start number for rank.
:param rank_start: start number for rank.
:param master_ip: ip address for master node (where the rank 0 is).
:param master_ip: ip address for master node (where the rank 0 is).
:param port: server port for distributed server.
:param port: server port for distributed server.
:param backend: set default collective communication backend.
"""
"""
def
__new__
(
cls
,
*
args
,
**
kwargs
):
def
__new__
(
cls
,
*
args
,
**
kwargs
):
...
@@ -83,6 +88,7 @@ class launcher:
...
@@ -83,6 +88,7 @@ class launcher:
master_ip
=
"localhost"
,
master_ip
=
"localhost"
,
port
=
0
,
port
=
0
,
device_type
=
"xpu"
,
device_type
=
"xpu"
,
backend
=
"auto"
,
):
):
self
.
func
=
func
self
.
func
=
func
self
.
n_gpus
=
(
self
.
n_gpus
=
(
...
@@ -93,6 +99,7 @@ class launcher:
...
@@ -93,6 +99,7 @@ class launcher:
self
.
master_ip
=
master_ip
self
.
master_ip
=
master_ip
self
.
port
=
port
self
.
port
=
port
self
.
device_type
=
device_type
self
.
device_type
=
device_type
self
.
backend
=
backend
# master node create server
# master node create server
if
self
.
rank_start
==
0
:
if
self
.
rank_start
==
0
:
self
.
server
=
Server
(
self
.
port
)
self
.
server
=
Server
(
self
.
port
)
...
@@ -104,6 +111,7 @@ class launcher:
...
@@ -104,6 +111,7 @@ class launcher:
procs
=
[]
procs
=
[]
queue
=
mp
.
Queue
(
self
.
n_gpus
)
queue
=
mp
.
Queue
(
self
.
n_gpus
)
results
=
[
None
]
*
self
.
n_gpus
results
=
[
None
]
*
self
.
n_gpus
machine_ranks
=
[
i
+
self
.
rank_start
for
i
in
range
(
self
.
n_gpus
)]
for
dev
in
range
(
self
.
n_gpus
):
for
dev
in
range
(
self
.
n_gpus
):
p
=
mp
.
Process
(
p
=
mp
.
Process
(
target
=
_run_wrapped
,
target
=
_run_wrapped
,
...
@@ -118,7 +126,9 @@ class launcher:
...
@@ -118,7 +126,9 @@ class launcher:
self
.
device_type
,
self
.
device_type
,
args
,
args
,
kwargs
,
kwargs
,
self
.
backend
,
queue
,
queue
,
machine_ranks
,
),
),
)
)
p
.
start
()
p
.
start
()
...
...
src/opr-mm/impl/megray_helper.cpp
浏览文件 @
809d5056
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "megbrain/opr/megray_helper.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/comp_node_env.h"
#include "megray/common.h"
using
namespace
mgb
;
using
namespace
mgb
;
using
namespace
opr
;
using
namespace
opr
;
...
@@ -39,6 +40,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
...
@@ -39,6 +40,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
return
MegRay
::
MEGRAY_RCCL
;
return
MegRay
::
MEGRAY_RCCL
;
}
else
if
(
backend
==
"ucx"
)
{
}
else
if
(
backend
==
"ucx"
)
{
return
MegRay
::
MEGRAY_UCX
;
return
MegRay
::
MEGRAY_UCX
;
}
else
if
(
backend
==
"shm"
)
{
return
MegRay
::
MEGRAY_SHM
;
}
else
{
}
else
{
mgb_throw
(
MegBrainError
,
"back CollectiveComm backend"
);
mgb_throw
(
MegBrainError
,
"back CollectiveComm backend"
);
}
}
...
@@ -90,7 +93,7 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
...
@@ -90,7 +93,7 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
if
(
rank
==
root
)
{
if
(
rank
==
root
)
{
char
*
c
=
MegRay
::
get_host_ip
();
char
*
c
=
MegRay
::
get_host_ip
();
master_ip
=
std
::
string
(
c
);
master_ip
=
std
::
string
(
c
);
delete
c
;
delete
[]
c
;
port
=
MegRay
::
get_free_port
();
port
=
MegRay
::
get_free_port
();
auto
ret
=
MegRay
::
create_server
(
size
,
port
);
auto
ret
=
MegRay
::
create_server
(
size
,
port
);
mgb_assert
(
ret
==
MegRay
::
Status
::
MEGRAY_OK
);
mgb_assert
(
ret
==
MegRay
::
Status
::
MEGRAY_OK
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录