Commit 3f2eac2f

Authored Sep 04, 2020 by Megvii Engine Team

fix(mge/imperative): move functional/distributed.py to distributed/functional.py

GitOrigin-RevId: 30cf2f514b9abc5e863e1fb26382008391cd607a

Parent: b3889938
Showing 2 changed files with 310 additions and 294 deletions (+310 −294)
imperative/python/megengine/distributed/functional.py    +295 −0
imperative/python/megengine/functional/distributed.py    +15 −294
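The practical effect of the move is that the implementation now lives in megengine/distributed/functional.py, while megengine/functional/distributed.py is reduced to a re-export of the same names (see the second file below), so existing imports keep working. A minimal sketch of what that means for user code, assuming the package is importable; the identity check is only illustrative:

# After this commit, both import paths resolve to the same function objects,
# because functional/distributed.py simply re-imports from distributed/functional.py.
from megengine.distributed.functional import all_reduce_sum as from_new_path
from megengine.functional.distributed import all_reduce_sum as from_old_path

assert from_new_path is from_old_path  # the old path is a thin compatibility shim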
imperative/python/megengine/distributed/functional.py    0 → 100644
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
    Tracer,
    check_backward_allow_noinput,
    get_grad_managers,
    get_op_has_grad_fn,
    tracer_apply,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply
from ..tensor import tensor
from ..device import get_default_device
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]


@apply.add
def _(op: RemoteSend, *args: Tensor):
    ret = tensor_apply(op, *args)

    # set extra information
    tracer_set = dict()
    for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
        tracer_set[k.name] = True

    # check tracer_set in remote_recv
    get_client().set_remote_tracer(op.key, tracer_set)
    return ret


@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    def backward(*args):
        return [
            remote_recv(
                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
            )
        ]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    def has_grad(opnode, reached):
        return get_client().check_is_grad(op.key)

    return has_grad


@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    return True


@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    def backward(*output_grads):
        return [remote_send(output_grads[0], op.rank_from)]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    def has_grad(opnode, reached):
        ret = False
        for v in opnode.outputs:
            if v() in reached:
                ret = True
                break
        get_client().set_is_grad(op.key, ret)
        return ret

    return has_grad


def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions"""
    assert isinstance(group, Group)
    if group is None:
        return inp
    op = CollectiveComm()
    op.key = group.key
    op.nr_devices = group.size
    op.rank = group.rank
    op.is_root = op.rank == 0
    op.local_grad = False
    op.addr, op.port = get_mm_server_addr()
    op.mode = mode
    op.dtype = inp.dtype
    op.backend = get_backend()
    op.comp_node = device
    return apply(op, inp)[0]


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create reduce_sum operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create broadcast operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_gather operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_GATHER
    return collective_comm(inp, mode, group, device)


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create reduce_scatter_sum operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_sum operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_max operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_min operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create gather operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.GATHER
    return collective_comm(inp, mode, group, device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create scatter operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.SCATTER
    return collective_comm(inp, mode, group, device)


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_to_all operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)


def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """Send a Tensor to a remote process

    :param inp: tensor to send
    :param dest_rank: destination process rank
    """
    op = RemoteSend()
    op.key = "{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    return apply(op, inp)[0]


def remote_recv(
    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
    """Receive a Tensor from a remote process

    :param src_rank: source process rank
    :param shape: the shape of the tensor to receive
    :param dtype: the data type of the tensor to receive
    :param device: the device to place the received tensor
    """
    key = "{}->{}".format(src_rank, get_rank())

    if device is None:
        device = get_default_device()
    # dummpy input
    inp = tensor([0])
    tracer_set = get_client().check_remote_tracer(key)
    for grad_manager in get_grad_managers():
        if grad_manager.name in tracer_set:
            grad_manager.wrt(inp)

    op = RemoteRecv()
    op.key = key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank

    return apply(op, inp)[0]
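For orientation, the functions above are meant to be called symmetrically, once per process, inside an already established communication group (WORLD by default). The sketch below only illustrates the call pattern; it assumes each process has already joined the group through MegEngine's distributed setup (rank assignment, master server address), which is elided here, so it is not a complete launch script.

# Illustrative sketch, not a runnable launch script: assumes every process has
# already joined the default WORLD group via MegEngine's distributed setup.
import numpy as np

from megengine.tensor import tensor
from megengine.distributed.functional import all_reduce_sum, remote_recv, remote_send
from megengine.distributed.group import get_rank

x = tensor([[1.0, 1.0], [1.0, 1.0]])

# Collective: every rank contributes x and receives the elementwise sum.
summed = all_reduce_sum(x)

# Point-to-point: rank 0 sends, rank 1 posts a matching receive with the same
# shape/dtype; the transfer key is derived from the (sender, receiver) rank pair.
if get_rank() == 0:
    remote_send(x, dest_rank=1)
elif get_rank() == 1:
    received = remote_recv(src_rank=0, shape=(2, 2), dtype=np.float32)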
imperative/python/megengine/functional/distributed.py

@@ -6,298 +6,19 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from typing import Optional, Tuple
-
-from ..core._imperative_rt.ops import CollectiveCommMode
-from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
-from ..core.autodiff.grad import (
-    Tracer,
-    check_backward_allow_noinput,
-    get_grad_managers,
-    get_op_has_grad_fn,
-    tracer_apply,
+# pylint: disable=redefined-builtin
+from ..distributed.functional import (
+    all_gather,
+    all_reduce_max,
+    all_reduce_min,
+    all_reduce_sum,
+    all_to_all,
+    broadcast,
+    collective_comm,
+    gather,
+    reduce_scatter_sum,
+    reduce_sum,
+    remote_recv,
+    remote_send,
+    scatter,
 )
-from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
-from ..core.tensor.core import apply
-from ..core.tensor.tensor import Tensor
-from ..device import get_default_device
-from ..distributed.group import (
-    WORLD,
-    Group,
-    get_backend,
-    get_client,
-    get_mm_server_addr,
-    get_rank,
-)
-from ..tensor import tensor
-
-__all__ = [
-    "reduce_sum",
-    "broadcast",
-    "all_gather",
-    "reduce_scatter_sum",
-    "all_reduce_sum",
-    "all_reduce_max",
-    "all_reduce_min",
-    "gather",
-    "scatter",
-    "all_to_all",
-    "remote_send",
-    "remote_recv",
-]
-
-
-@apply.register()
-def _(op: RemoteSend, *args: Tensor):
-    ret = apply.super(op, *args)
-
-    # set extra information
-    tracer_set = dict()
-    for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
-        tracer_set[k.name] = True
-
-    # check tracer_set in remote_recv
-    get_client().set_remote_tracer(op.key, tracer_set)
-    return ret
-
-
-@builtin_op_get_backward_fn.register(RemoteSend)
-def _(op: RemoteSend, inputs, outputs, input_requires_grad):
-    def backward(*args):
-        return [
-            remote_recv(
-                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
-            )
-        ]
-
-    return backward, [True]
-
-
-@get_op_has_grad_fn.register(RemoteSend)
-def _(op: RemoteSend):
-    def has_grad(opnode, reached):
-        return get_client().check_is_grad(op.key)
-
-    return has_grad
-
-
-@check_backward_allow_noinput.register(RemoteSend)
-def _(op: RemoteSend):
-    return True
-
-
-@builtin_op_get_backward_fn.register(RemoteRecv)
-def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
-    def backward(*output_grads):
-        return [remote_send(output_grads[0], op.rank_from)]
-
-    return backward, [True]
-
-
-@get_op_has_grad_fn.register(RemoteRecv)
-def _(op: RemoteRecv):
-    def has_grad(opnode, reached):
-        ret = False
-        for v in opnode.outputs:
-            if v() in reached:
-                ret = True
-                break
-        get_client().set_is_grad(op.key, ret)
-        return ret
-
-    return has_grad
-
-
-def collective_comm(inp, mode, group, device):
-    """Helper function for applying collective communication functions"""
-    assert isinstance(group, Group)
-    if group is None:
-        return inp
-    op = CollectiveComm()
-    op.key = group.key
-    op.nr_devices = group.size
-    op.rank = group.rank
-    op.is_root = op.rank == 0
-    op.local_grad = False
-    op.addr, op.port = get_mm_server_addr()
-    op.mode = mode
-    op.dtype = inp.dtype
-    op.backend = get_backend()
-    op.comp_node = device
-    return apply(op, inp)[0]
-
-
-def reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create reduce_sum operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.REDUCE_SUM
-    return collective_comm(inp, mode, group, device)
-
-
-def broadcast(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create broadcast operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.BROADCAST
-    return collective_comm(inp, mode, group, device)
-
-
-def all_gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create all_gather operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.ALL_GATHER
-    return collective_comm(inp, mode, group, device)
-
-
-def reduce_scatter_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create reduce_scatter_sum operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.REDUCE_SCATTER_SUM
-    return collective_comm(inp, mode, group, device)
-
-
-def all_reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create all_reduce_sum operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.ALL_REDUCE_SUM
-    return collective_comm(inp, mode, group, device)
-
-
-def all_reduce_max(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create all_reduce_max operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.ALL_REDUCE_MAX
-    return collective_comm(inp, mode, group, device)
-
-
-def all_reduce_min(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create all_reduce_min operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.ALL_REDUCE_MIN
-    return collective_comm(inp, mode, group, device)
-
-
-def gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create gather operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.GATHER
-    return collective_comm(inp, mode, group, device)
-
-
-def scatter(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create scatter operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.SCATTER
-    return collective_comm(inp, mode, group, device)
-
-
-def all_to_all(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
-) -> Tensor:
-    """Create all_to_all operator for collective communication
-
-    :param inp: input tensor
-    :param group: communication group
-    :param device: execute placement
-    """
-    mode = CollectiveCommMode.ALL_TO_ALL
-    return collective_comm(inp, mode, group, device)
-
-
-def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
-    """Send a Tensor to a remote process
-
-    :param inp: tensor to send
-    :param dest_rank: destination process rank
-    """
-    op = RemoteSend()
-    op.key = "{}->{}".format(get_rank(), dest_rank)
-    op.addr, op.port = get_mm_server_addr()
-    op.rank_to = dest_rank
-    return apply(op, inp)[0]
-
-
-def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
-) -> Tensor:
-    """Receive a Tensor from a remote process
-
-    :param src_rank: source process rank
-    :param shape: the shape of the tensor to receive
-    :param dtype: the data type of the tensor to receive
-    :param device: the device to place the received tensor,
-        if None, use default device
-    """
-    key = "{}->{}".format(src_rank, get_rank())
-
-    if device is None:
-        device = get_default_device()
-    # dummpy input
-    inp = tensor([0])
-    tracer_set = get_client().check_remote_tracer(key)
-    for grad_manager in get_grad_managers():
-        if grad_manager.name in tracer_set:
-            grad_manager.wrt(inp)
-
-    op = RemoteRecv()
-    op.key = key
-    op.cn = device
-    op.shape = shape
-    op.dtype = dtype
-    op.addr, op.port = get_mm_server_addr()
-    op.rank_from = src_rank
-
-    return apply(op, inp)[0]
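A side note on the structure this diff preserves: every public wrapper (reduce_sum, broadcast, ..., all_to_all) just selects a CollectiveCommMode value and delegates to the collective_comm helper, which fills in the group key, rank, server address, dtype and backend on a CollectiveComm op. The sketch below shows what one more wrapper would look like under that pattern; reduce_max and CollectiveCommMode.REDUCE_MAX are hypothetical names used only to illustrate the pattern and are not part of this commit.

# Hypothetical sketch of the wrapper pattern used throughout this module.
# "reduce_max" / CollectiveCommMode.REDUCE_MAX are illustrative names only;
# verify the enum member exists before relying on it.
from typing import Optional

from megengine.core._imperative_rt.ops import CollectiveCommMode
from megengine.core.tensor.tensor import Tensor
from megengine.distributed.functional import collective_comm
from megengine.distributed.group import WORLD, Group


def reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create reduce_max operator for collective communication (sketch)

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.REDUCE_MAX  # hypothetical mode member
    return collective_comm(inp, mode, group, device)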