Commit b5ec9dfe
Authored on Apr 15, 2021 by Megvii Engine Team

fix(mge/distributed): fix gather scatter reduce broadcast autodiff

GitOrigin-RevId: 1c2250a0795276b696c29d82b68c49eae4653078
Parent: a49e202b

3 changed files with 214 additions and 62 deletions:
    imperative/python/megengine/distributed/functional.py   +200  -62
    imperative/python/megengine/distributed/launcher.py       +3   -0
    imperative/python/megengine/distributed/server.py        +11   -0
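For context, the kind of program this fix targets is a collective op recorded under GradManager, where gradients must flow back across ranks. Below is a minimal usage sketch (not part of this commit) using broadcast; it assumes a machine with at least two GPUs, and the toy values are illustrative only.

# Minimal sketch of differentiating through a collective (assumes 2 GPUs).
import megengine.distributed as dist
from megengine import Tensor
from megengine.autodiff import GradManager
from megengine.distributed.functional import broadcast


@dist.launcher(n_gpus=2)
def worker():
    rank = dist.get_rank()
    x = Tensor([rank + 1.0, rank + 2.0])  # only rank 0's value is actually broadcast
    gm = GradManager()
    gm.attach(x)
    with gm:
        y = broadcast(x)      # every rank now holds rank 0's tensor
        gm.backward(y.sum())  # backward reduces the per-rank grads back to rank 0
    print(rank, x.grad)       # rank 0: summed gradient; other ranks: None


if __name__ == "__main__":
    worker()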
imperative/python/megengine/distributed/functional.py (view file @ b5ec9dfe)

@@ -8,9 +8,11 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Optional, Tuple
 
+import numpy as np
+
 from ..core._imperative_rt.core2 import apply
-from ..core.autodiff.grad import _grad_manager_dict
-from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
+from ..core.autodiff.grad import Function, _grad_manager_dict
+from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.utils import isscalar, setscalar
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -65,6 +67,77 @@ def collective_comm(inp, mode, group, device):
     return result
 
 
+def _save_output_for_autodiff(inp, out):
+    for g in _grad_manager_dict.values():
+        if g._is_attached_to(inp):
+            g._refkeeper.append(out)
+
+
+def _bcast_has_grad(group, grad):
+    if group.rank == 0:
+        has_grad = grad is not None
+        get_client().bcast_val(has_grad, group.key, group.size)
+    else:
+        has_grad = get_client().bcast_val(None, group.key, group.size)
+
+    return has_grad
+
+
+def _bcast_shape_dtype(group, inp):
+    if group.rank == 0:
+        # FIXME in some cases, shape is not available(output of condtake)
+        shape = inp._tuple_shape
+        dtype = np.dtype(inp.dtype).name
+        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
+    else:
+        val = get_client().bcast_val(None, group.key, group.size)
+        shape = val["shape"]
+        dtype = val["dtype"]
+
+    return shape, dtype
+
+
+def _bcast_tracer_state(group, inp):
+    if group.rank == 0:
+        tracer_keys = []
+        for n, g in _grad_manager_dict.items():
+            if g._is_attached_to(inp):
+                tracer_keys.append(n)
+        get_client().bcast_val(tracer_keys, group.key, group.size)
+    else:
+        tracer_keys = get_client().bcast_val(None, group.key, group.size)
+        for n in tracer_keys:
+            g = _grad_manager_dict.get(n)
+            if g is not None:
+                g.wrt(inp)
+                g._refkeeper.append(inp)
+
+
+def _dummy_input(shape, dtype, device=""):
+    if device == "":
+        device = get_default_device()
+
+    inp = Tensor(0, dtype=dtype, device=device)
+    if len(shape) > 0:
+        inp = inp._broadcast(shape)
+    return inp
+
+
+class _ReduceSum(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        has_grad = _bcast_has_grad(self.group, grad)
+        if has_grad:
+            return broadcast(grad, self.group, self.in_device)
+
+
 def reduce_sum(
     inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
 ) -> Tensor:
@@ -75,8 +148,30 @@ def reduce_sum(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.REDUCE_SUM
-    return collective_comm(inp, mode, group, device)
+    op = _ReduceSum(group, device)
+    (out,) = apply(op, inp)
+
+    if group.rank == 0:
+        return out
+    else:
+        _save_output_for_autodiff(inp, out)
+
+
+class _Broadcast(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        # TODO backward with a part of grad
+        if grad is not None:
+            return reduce_sum(grad, self.group, self.in_device)
 
 
 def broadcast(
@@ -89,8 +184,16 @@ def broadcast(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.BROADCAST
-    return collective_comm(inp, mode, group, device)
+    shape, dtype = _bcast_shape_dtype(group, inp)
+    if group.rank != 0:
+        # dummy input to infer shape
+        inp = _dummy_input(shape, dtype, device)
+
+    _bcast_tracer_state(group, inp)
+
+    op = _Broadcast(group, device)
+    (out,) = apply(op, inp)
+    return out
 
 
 def all_gather(
@@ -163,6 +266,23 @@ def all_reduce_min(
     return collective_comm(inp, mode, group, device)
 
 
+class _Gather(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        has_grad = _bcast_has_grad(self.group, grad)
+        if has_grad:
+            return scatter(grad, self.group, self.in_device)
+
+
 def gather(
     inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
 ) -> Tensor:
@@ -173,8 +293,31 @@ def gather(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.GATHER
-    return collective_comm(inp, mode, group, device)
+    op = _Gather(group, device)
+    (out,) = apply(op, inp)
+
+    if group.rank == 0:
+        return out
+    else:
+        _save_output_for_autodiff(inp, out)
+
+
+class _Scatter(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        # TODO backward with a part of grad
+        if grad is not None:
+            return gather(grad, self.group, self.in_device)
 
 
 def scatter(
@@ -187,8 +330,16 @@ def scatter(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.SCATTER
-    return collective_comm(inp, mode, group, device)
+    shape, dtype = _bcast_shape_dtype(group, inp)
+    if group.rank != 0:
+        # dummy input to infer shape
+        inp = _dummy_input(shape, dtype, device)
+
+    _bcast_tracer_state(group, inp)
+
+    op = _Scatter(group, device)
+    (out,) = apply(op, inp)
+    return out
 
 
 def all_to_all(
@@ -205,44 +356,46 @@ def all_to_all(
     return collective_comm(inp, mode, group, device)
 
 
-class _RemoteSend(PyOpBase):
+class _SendRecvGroup:
+    def __init__(self, rank_from, rank_to):
+        self.key = "{}->{}".format(rank_from, rank_to)
+        self.rank_from = rank_from
+        self.rank_to = rank_to
+        self.size = 2
+
+    @property
+    def rank(self):
+        if get_rank() == self.rank_from:
+            return 0
+        else:
+            return 1
+
+
+class _RemoteSend(Function):
     def __init__(self, op: RemoteSend):
         self.op = op
 
-    def _default_rule(self, data):
-        return apply(self.op, data)
-
-    def _grad_rule(self, data):
-        self.dtype = data.dtype
-        self.shape = data.shape
-        self.device = data.device
-        (self.dummy,) = self._default_rule(data)
-        return self.dummy, self.backward
+    def forward(self, data):
+        self.device = str(data.device)
+        (self.dummy,) = apply(self.op, data)
+        return self.dummy
 
     def backward(self, grad):
        assert grad is None
-        if get_client().check_is_grad(self.op.key):
-            return remote_recv(
-                self.op.rank_to,
-                self.shape,
-                self.dtype,
-                device=str(self.device),
-                inp=self.dummy,
-            )
+        has_grad = get_client().bcast_val(None, self.op.key, 2)
+        if has_grad:
+            return remote_recv(
+                self.op.rank_to, device=self.device, inp=self.dummy,
+            )
 
 
-class _RemoteRecv(PyOpBase):
+class _RemoteRecv(Function):
     def __init__(self, op: RemoteRecv):
         self.op = op
 
-    def _default_rule(self, dummy):
+    def forward(self, dummy):
         return apply(self.op, dummy)
 
-    def _grad_rule(self, dummy):
-        return self._default_rule(dummy), self.backward
-
     def backward(self, grad):
-        get_client().set_is_grad(self.op.key, grad is not None)
+        get_client().bcast_val(grad is not None, self.op.key, 2)
         if grad is not None:
             remote_send(grad, self.op.rank_from)
@@ -254,53 +407,38 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     :param inp: tensor to send.
     :param dest_rank: destination process rank.
     """
-    key = "{}->{}".format(get_rank(), dest_rank)
-    grad_keys = {}
-    for n, g in _grad_manager_dict.items():
-        if g._is_attached_to(inp):
-            grad_keys[n] = g
-    get_client().set_remote_tracer(key, grad_keys)
+    group = _SendRecvGroup(get_rank(), dest_rank)
+    _bcast_shape_dtype(group, inp)
+
+    _bcast_tracer_state(group, inp)
+
     op = RemoteSend()
-    op.key = key
+    op.key = group.key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
     op.backend = get_backend()
-    (dummy,) = apply(_RemoteSend(op), inp)
-    for g in grad_keys.values():
-        g._refkeeper.append(dummy)
+    (out,) = apply(_RemoteSend(op), inp)
+    _save_output_for_autodiff(inp, out)
 
 
 def remote_recv(
-    src_rank: int,
-    shape: Tuple[int],
-    dtype: type,
-    device: Optional[str] = None,
-    inp=None,
+    src_rank: int, device: Optional[str] = None, inp=None,
 ) -> Tensor:
     """
     Receive a Tensor from a remote process.
 
     :param src_rank: source process rank.
-    :param shape: the shape of the tensor to receive.
-    :param dtype: the data type of the tensor to receive.
     :param device: the device to place the received tensor.
     :param inp: dummy input to determine recved tensor type
     """
-    key = "{}->{}".format(src_rank, get_rank())
+    group = _SendRecvGroup(src_rank, get_rank())
+    shape, dtype = _bcast_shape_dtype(group, None)
+
     if device is None:
         device = get_default_device()
     # dummy input
     if inp is None:
-        inp = Tensor([0], device=device)
-    tracer_set = get_client().check_remote_tracer(key)
-    for n in tracer_set:
-        g = _grad_manager_dict.get(n)
-        if g is not None:
-            g.wrt(inp)
-            g._refkeeper.append(inp)
+        inp = Tensor(0, device=device)
+    _bcast_tracer_state(group, inp)
+
     _isscalar = False
     if len(shape) == 0:
@@ -308,7 +446,7 @@ def remote_recv(
         _isscalar = True
 
     op = RemoteRecv()
-    op.key = key
+    op.key = group.key
     op.cn = device
     op.shape = shape
     op.dtype = dtype
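Note on the point-to-point changes above: remote_recv no longer takes shape and dtype, because _bcast_shape_dtype now negotiates them between sender and receiver. A minimal sketch of the new call pattern (not part of this commit; assumes two processes launched with dist.launcher):

# Minimal sketch of the simplified remote_send / remote_recv API (assumes 2 GPUs).
import megengine.distributed as dist
from megengine import Tensor
from megengine.distributed.functional import remote_recv, remote_send


@dist.launcher(n_gpus=2)
def worker():
    if dist.get_rank() == 0:
        x = Tensor([[1.0, 2.0], [3.0, 4.0]])
        remote_send(x, dest_rank=1)
    else:
        y = remote_recv(src_rank=0)  # shape/dtype are negotiated, not passed
        print(y.numpy())


if __name__ == "__main__":
    worker()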
imperative/python/megengine/distributed/launcher.py (view file @ b5ec9dfe)

@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import multiprocessing as mp
+import os
 import queue
 
 from ..core._imperative_rt.core2 import sync
@@ -43,6 +44,8 @@ def _run_wrapped(
         device=dev,
         device_type=device_type,
     )
+    # set NCCL_LAUNCH_MODE to avoid deadlock
+    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
     if is_multimachine:
         group_barrier()
     ret = func(*args, **kwargs)
imperative/python/megengine/distributed/server.py (view file @ b5ec9dfe)

@@ -253,6 +253,17 @@ class Client:
         """Get user defined key-value pairs across processes."""
         return self.proxy.user_get(key)
 
+    def bcast_val(self, val, key, size):
+        if val is not None:
+            self.user_set(key + "_sync", val)
+            self.group_barrier(key, size)
+            self.group_barrier(key, size)
+        else:
+            self.group_barrier(key, size)
+            val = self.user_get(key + "_sync")
+            self.group_barrier(key, size)
+        return val
+
 
 def main(port=0, verbose=True):
     mm_server_port = create_mm_server("0.0.0.0", 0)
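The new Client.bcast_val above is a small rendezvous on the key-value server: the rank that owns the value passes it, every other participant passes None, and the paired barriers ensure the value is written before it is read and read before it can be overwritten. A hypothetical standalone sketch of the calling convention (the get_client import path and the demo key are assumptions, not from this commit):

# Hypothetical sketch of the Client.bcast_val calling convention (assumes 2 GPUs).
import megengine.distributed as dist
from megengine.distributed.group import get_client  # assumed import path


@dist.launcher(n_gpus=2)
def worker():
    client = get_client()
    key, size = "demo_bcast", 2  # arbitrary demo key shared by both ranks
    if dist.get_rank() == 0:
        val = client.bcast_val({"shape": (4, 8), "dtype": "float32"}, key, size)
    else:
        val = client.bcast_val(None, key, size)
    print(dist.get_rank(), val)  # every rank prints the same dict


if __name__ == "__main__":
    worker()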