MegEngine 天元 / MegEngine
Commit aea5ac13
Authored on Apr 15, 2021 by Megvii Engine Team

fix(mge/distributed): fix gather scatter reduce broadcast autodiff

GitOrigin-RevId: 1c2250a0795276b696c29d82b68c49eae4653078

Parent commit: b1bf193e
3 changed files with 214 additions and 62 deletions (+214 -62)
imperative/python/megengine/distributed/functional.py  +200 -62
imperative/python/megengine/distributed/launcher.py      +3  -0
imperative/python/megengine/distributed/server.py       +11  -0
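Before the per-file diffs, a minimal sketch of the workflow this commit targets: differentiating through a collective op under GradManager on two ranks. The sketch is not part of the diff and assumes MegEngine's 1.x public API (dist.launcher, autodiff.GradManager, dist.functional.broadcast); exact entry points may differ between releases.

# Illustrative only (not from this commit). Assumes MegEngine ~1.x APIs.
import megengine as mge
import megengine.distributed as dist
from megengine.autodiff import GradManager


@dist.launcher(n_gpus=2)                      # one process per device
def worker():
    rank = dist.get_rank()
    x = mge.tensor([float(rank + 1)])
    gm = GradManager()
    gm.attach([x])
    with gm:
        # every rank receives rank 0's value; the backward of broadcast is a
        # reduce_sum of the incoming gradients (see the _Broadcast class below)
        y = dist.functional.broadcast(x)
        gm.backward(y)
    # on rank 0, x.grad should accumulate one unit per rank (2.0 here)
    print(rank, x.grad)


if __name__ == "__main__":
    worker()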
imperative/python/megengine/distributed/functional.py
@@ -8,9 +8,11 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Optional, Tuple

+import numpy as np
+
 from ..core._imperative_rt.core2 import apply
-from ..core.autodiff.grad import _grad_manager_dict
-from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
+from ..core.autodiff.grad import Function, _grad_manager_dict
+from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.utils import isscalar, setscalar
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -65,6 +67,77 @@ def collective_comm(inp, mode, group, device):
     return result


+def _save_output_for_autodiff(inp, out):
+    for g in _grad_manager_dict.values():
+        if g._is_attached_to(inp):
+            g._refkeeper.append(out)
+
+
+def _bcast_has_grad(group, grad):
+    if group.rank == 0:
+        has_grad = grad is not None
+        get_client().bcast_val(has_grad, group.key, group.size)
+    else:
+        has_grad = get_client().bcast_val(None, group.key, group.size)
+
+    return has_grad
+
+
+def _bcast_shape_dtype(group, inp):
+    if group.rank == 0:
+        # FIXME in some cases, shape is not available(output of condtake)
+        shape = inp._tuple_shape
+        dtype = np.dtype(inp.dtype).name
+        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
+    else:
+        val = get_client().bcast_val(None, group.key, group.size)
+        shape = val["shape"]
+        dtype = val["dtype"]
+
+    return shape, dtype
+
+
+def _bcast_tracer_state(group, inp):
+    if group.rank == 0:
+        tracer_keys = []
+        for n, g in _grad_manager_dict.items():
+            if g._is_attached_to(inp):
+                tracer_keys.append(n)
+        get_client().bcast_val(tracer_keys, group.key, group.size)
+    else:
+        tracer_keys = get_client().bcast_val(None, group.key, group.size)
+        for n in tracer_keys:
+            g = _grad_manager_dict.get(n)
+            if g is not None:
+                g.wrt(inp)
+                g._refkeeper.append(inp)
+
+
+def _dummy_input(shape, dtype, device=""):
+    if device == "":
+        device = get_default_device()
+
+    inp = Tensor(0, dtype=dtype, device=device)
+    if len(shape) > 0:
+        inp = inp._broadcast(shape)
+    return inp
+
+
+class _ReduceSum(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        has_grad = _bcast_has_grad(self.group, grad)
+        if has_grad:
+            return broadcast(grad, self.group, self.in_device)
+
+
 def reduce_sum(
     inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
 ) -> Tensor:
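The helpers above let a non-root rank learn the collective input's shape and dtype from rank 0 and then fabricate a placeholder tensor with _dummy_input, so that every rank can issue the same op. A small illustration of what that placeholder looks like, reusing the same calls as the diff; the shape and dtype values are made up for the example.

# Illustration only: what _dummy_input builds on a non-root rank once
# shape/dtype have arrived via bcast_val (values below are made up).
import megengine as mge

shape, dtype = (2, 3), "float32"                 # as advertised by rank 0
placeholder = mge.Tensor(0, dtype=dtype, device=mge.get_default_device())
if len(shape) > 0:
    placeholder = placeholder._broadcast(shape)  # same private call the diff uses
print(placeholder.shape)                         # (2, 3)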
@@ -75,8 +148,30 @@ def reduce_sum(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.REDUCE_SUM
-    return collective_comm(inp, mode, group, device)
+    op = _ReduceSum(group, device)
+    (out,) = apply(op, inp)
+
+    if group.rank == 0:
+        return out
+    else:
+        _save_output_for_autodiff(inp, out)
+
+
+class _Broadcast(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        # TODO backward with a part of grad
+        if grad is not None:
+            return reduce_sum(grad, self.group, self.in_device)


 def broadcast(
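After this hunk, reduce_sum only returns a tensor on rank 0; other ranks get None, but _save_output_for_autodiff keeps the op's output alive in every attached GradManager so the collective backward (a broadcast of rank 0's gradient) still runs everywhere. A hedged sketch of that asymmetric usage, assuming GradManager.backward can be invoked without outputs on the ranks that received None:

# Illustrative only (not from this commit). Assumes MegEngine ~1.x APIs and that
# GradManager.backward() may be called with no outputs on non-root ranks.
import megengine as mge
import megengine.distributed as dist
from megengine.autodiff import GradManager


@dist.launcher(n_gpus=2)
def worker():
    rank = dist.get_rank()
    x = mge.tensor([1.0, 2.0])
    gm = GradManager()
    gm.attach([x])
    with gm:
        y = dist.functional.reduce_sum(x)   # tensor on rank 0, None elsewhere
        if rank == 0:
            gm.backward(y.sum())
        else:
            gm.backward()                   # still joins the collective backward
    # the backward of REDUCE_SUM broadcasts rank 0's gradient to every rank,
    # so each rank's x.grad should come back filled with ones
    print(rank, x.grad)


if __name__ == "__main__":
    worker()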
@@ -89,8 +184,16 @@ def broadcast(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.BROADCAST
-    return collective_comm(inp, mode, group, device)
+    shape, dtype = _bcast_shape_dtype(group, inp)
+    if group.rank != 0:
+        # dummy input to infer shape
+        inp = _dummy_input(shape, dtype, device)
+
+    _bcast_tracer_state(group, inp)
+
+    op = _Broadcast(group, device)
+    (out,) = apply(op, inp)
+    return out


 def all_gather(
@@ -163,6 +266,23 @@ def all_reduce_min(
     return collective_comm(inp, mode, group, device)


+class _Gather(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        has_grad = _bcast_has_grad(self.group, grad)
+        if has_grad:
+            return scatter(grad, self.group, self.in_device)
+
+
 def gather(
     inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
 ) -> Tensor:
@@ -173,8 +293,31 @@ def gather(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.GATHER
-    return collective_comm(inp, mode, group, device)
+    op = _Gather(group, device)
+    (out,) = apply(op, inp)
+
+    if group.rank == 0:
+        return out
+    else:
+        _save_output_for_autodiff(inp, out)
+
+
+class _Scatter(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        # TODO backward with a part of grad
+        if grad is not None:
+            return gather(grad, self.group, self.in_device)


 def scatter(
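_Gather and _Scatter are adjoint in the same way: gathering per-rank tensors onto rank 0 in the forward pass means scattering the gradient back to each contributing rank in the backward pass. A compact hedged sketch, under the same assumptions as the reduce_sum example above:

# Illustrative only; same API assumptions as the reduce_sum sketch above.
import megengine as mge
import megengine.distributed as dist
from megengine.autodiff import GradManager


@dist.launcher(n_gpus=2)
def worker():
    rank = dist.get_rank()
    x = mge.tensor([float(rank)] * 4)       # each rank contributes 4 elements
    gm = GradManager()
    gm.attach([x])
    with gm:
        y = dist.functional.gather(x)       # rank 0: 8 elements, others: None
        if rank == 0:
            gm.backward((y * y).sum())
        else:
            gm.backward()                   # participates in the scatter backward
    print(rank, x.grad)                     # expected: 2 * x on every rank


if __name__ == "__main__":
    worker()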
@@ -187,8 +330,16 @@ def scatter(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.SCATTER
-    return collective_comm(inp, mode, group, device)
+    shape, dtype = _bcast_shape_dtype(group, inp)
+    if group.rank != 0:
+        # dummy input to infer shape
+        inp = _dummy_input(shape, dtype, device)
+
+    _bcast_tracer_state(group, inp)
+
+    op = _Scatter(group, device)
+    (out,) = apply(op, inp)
+    return out


 def all_to_all(
@@ -205,44 +356,46 @@ def all_to_all(
     return collective_comm(inp, mode, group, device)


-class _RemoteSend(PyOpBase):
+class _SendRecvGroup:
+    def __init__(self, rank_from, rank_to):
+        self.key = "{}->{}".format(rank_from, rank_to)
+        self.rank_from = rank_from
+        self.rank_to = rank_to
+        self.size = 2
+
+    @property
+    def rank(self):
+        if get_rank() == self.rank_from:
+            return 0
+        else:
+            return 1
+
+
+class _RemoteSend(Function):
     def __init__(self, op: RemoteSend):
         self.op = op

-    def _default_rule(self, data):
-        return apply(self.op, data)
-
-    def _grad_rule(self, data):
-        self.dtype = data.dtype
-        self.shape = data.shape
-        self.device = data.device
-        (self.dummy,) = self._default_rule(data)
-        return self.dummy, self.backward
+    def forward(self, data):
+        self.device = str(data.device)
+        (self.dummy,) = apply(self.op, data)
+        return self.dummy

     def backward(self, grad):
         assert grad is None
-        if get_client().check_is_grad(self.op.key):
-            return remote_recv(
-                self.op.rank_to,
-                self.shape,
-                self.dtype,
-                device=str(self.device),
-                inp=self.dummy,
-            )
+        has_grad = get_client().bcast_val(None, self.op.key, 2)
+        if has_grad:
+            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)


-class _RemoteRecv(PyOpBase):
+class _RemoteRecv(Function):
     def __init__(self, op: RemoteRecv):
         self.op = op

-    def _default_rule(self, dummy):
+    def forward(self, dummy):
         return apply(self.op, dummy)

-    def _grad_rule(self, dummy):
-        return self._default_rule(dummy), self.backward
-
     def backward(self, grad):
-        get_client().set_is_grad(self.op.key, grad is not None)
+        get_client().bcast_val(grad is not None, self.op.key, 2)
         if grad is not None:
             remote_send(grad, self.op.rank_from)
@@ -254,53 +407,38 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     :param inp: tensor to send.
     :param dest_rank: destination process rank.
     """
-    key = "{}->{}".format(get_rank(), dest_rank)
-    grad_keys = {}
-    for n, g in _grad_manager_dict.items():
-        if g._is_attached_to(inp):
-            grad_keys[n] = g
-    get_client().set_remote_tracer(key, grad_keys)
+    group = _SendRecvGroup(get_rank(), dest_rank)
+    _bcast_shape_dtype(group, inp)
+
+    _bcast_tracer_state(group, inp)

     op = RemoteSend()
-    op.key = key
+    op.key = group.key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
     op.backend = get_backend()
-    (dummy,) = apply(_RemoteSend(op), inp)
-
-    for g in grad_keys.values():
-        g._refkeeper.append(dummy)
+    (out,) = apply(_RemoteSend(op), inp)
+    _save_output_for_autodiff(inp, out)


 def remote_recv(
-    src_rank: int,
-    shape: Tuple[int],
-    dtype: type,
-    device: Optional[str] = None,
-    inp=None,
+    src_rank: int, device: Optional[str] = None, inp=None,
 ) -> Tensor:
     """
     Receive a Tensor from a remote process.

     :param src_rank: source process rank.
-    :param shape: the shape of the tensor to receive.
-    :param dtype: the data type of the tensor to receive.
     :param device: the device to place the received tensor.
     :param inp: dummy input to determine recved tensor type
     """
-    key = "{}->{}".format(src_rank, get_rank())
+    group = _SendRecvGroup(src_rank, get_rank())
+    shape, dtype = _bcast_shape_dtype(group, None)

     if device is None:
         device = get_default_device()
     # dummy input
     if inp is None:
-        inp = Tensor([0], device=device)
-    tracer_set = get_client().check_remote_tracer(key)
-    for n in tracer_set:
-        g = _grad_manager_dict.get(n)
-        if g is not None:
-            g.wrt(inp)
-            g._refkeeper.append(inp)
+        inp = Tensor(0, device=device)
+    _bcast_tracer_state(group, inp)

     _isscalar = False
     if len(shape) == 0:
@@ -308,7 +446,7 @@ def remote_recv(
         _isscalar = True

     op = RemoteRecv()
-    op.key = key
+    op.key = group.key
     op.cn = device
     op.shape = shape
     op.dtype = dtype
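With _SendRecvGroup, the point-to-point pair reuses the same bcast_val machinery as the collectives, so remote_recv no longer takes shape and dtype at the call site; they are fetched from the sender. A hedged sketch of the send/recv autodiff round trip under the new signature (remote_send returns no tensor, so the sending rank runs backward without outputs, under the same API assumptions as the earlier sketches):

# Illustrative only (not from this commit). Assumes MegEngine ~1.x APIs.
import megengine as mge
import megengine.distributed as dist
from megengine.autodiff import GradManager


@dist.launcher(n_gpus=2)
def worker():
    rank = dist.get_rank()
    gm = GradManager()                      # created symmetrically on both ranks
    if rank == 0:
        x = mge.tensor([[1.0, 2.0], [3.0, 4.0]])
        gm.attach([x])
        with gm:
            dist.functional.remote_send(x, dest_rank=1)
            gm.backward()                   # waits for rank 1's gradient
        print(rank, x.grad)                 # expected: filled with 2.0
    else:
        with gm:
            # shape/dtype are no longer passed here; they come from rank 0
            y = dist.functional.remote_recv(src_rank=0)
            gm.backward((y * 2).sum())


if __name__ == "__main__":
    worker()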
imperative/python/megengine/distributed/launcher.py
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import multiprocessing as mp
+import os
 import queue

 from ..core._imperative_rt.core2 import sync
@@ -43,6 +44,8 @@ def _run_wrapped(
         device=dev,
         device_type=device_type,
     )
+    # set NCCL_LAUNCH_MODE to avoid deadlock
+    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
     if is_multimachine:
         group_barrier()
     ret = func(*args, **kwargs)
imperative/python/megengine/distributed/server.py
@@ -253,6 +253,17 @@ class Client:
         """Get user defined key-value pairs across processes."""
         return self.proxy.user_get(key)

+    def bcast_val(self, val, key, size):
+        if val is not None:
+            self.user_set(key + "_sync", val)
+            self.group_barrier(key, size)
+            self.group_barrier(key, size)
+        else:
+            self.group_barrier(key, size)
+            val = self.user_get(key + "_sync")
+            self.group_barrier(key, size)
+        return val
+

 def main(port=0, verbose=True):
     mm_server_port = create_mm_server("0.0.0.0", 0)
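Client.bcast_val turns the key-value store plus group_barrier into a one-shot broadcast: the rank holding the value publishes it under key + "_sync", the first barrier guarantees the value is visible before anyone reads, and the second guarantees everyone has read before the key can be reused. A toy, single-process model of that protocol with threads; purely illustrative, not MegEngine's RPC client.

# Toy model of the bcast_val protocol: a dict plus a reusable barrier stand in
# for the distributed key-value server. Illustrative only.
import threading

store = {}
barrier = threading.Barrier(2)          # plays the role of group_barrier(key, size)


def bcast_val(val, key):
    if val is not None:                 # publishing side (e.g. rank 0)
        store[key + "_sync"] = val
        barrier.wait()                  # 1st barrier: value is now visible
        barrier.wait()                  # 2nd barrier: all readers have consumed it
    else:                               # consuming side
        barrier.wait()
        val = store[key + "_sync"]
        barrier.wait()
    return val


def rank0():
    bcast_val({"shape": (2, 3), "dtype": "float32"}, "grp")


def rank1():
    print("rank 1 received:", bcast_val(None, "grp"))


t0 = threading.Thread(target=rank0)
t1 = threading.Thread(target=rank1)
t0.start()
t1.start()
t0.join()
t1.join()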