Commit 6d367454
Authored Jun 18, 2020 by Megvii Engine Team

feat(mge/opr-mm): add param local_grad for collective_comm opr

GitOrigin-RevId: cc120cfb55d67a48dc126d1fd8773fa08a860d32
Parent: 0ccb965c
Changes: 14 changed files with 999 additions and 410 deletions (+999 −410)
python_module/megengine/distributed/__init__.py           +3    -0
python_module/megengine/distributed/functional.py         +105  -31
python_module/megengine/distributed/helper.py             +16   -5
python_module/megengine/optimizer/optimizer.py            +4    -2
python_module/src/cpp/opr_defs.cpp                        +14   -14
python_module/src/cpp/opr_defs.h                          +8    -8
python_module/test/unit/distributed/test_functional.py    +11   -11
src/gopt/impl/misc.cpp                                    +2    -1
src/gopt/test/misc.cpp                                    +30   -23
src/opr-mm/impl/collective_comm.cpp                       +150  -99
src/opr-mm/impl/collective_comm.oprdecl                   +3    -2
src/opr-mm/include/megbrain/opr/collective_comm.h         +9    -4
src/opr-mm/test/collective_comm.cpp                       +641  -208
tools/param_defs/mgb_opr_param_defs.py                    +3    -2
python_module/megengine/distributed/__init__.py

@@ -11,10 +11,13 @@ from .functional import (
     all_reduce_max,
     all_reduce_min,
     all_reduce_sum,
     all_to_all,
     bcast_param,
     broadcast,
     gather,
     reduce_scatter_sum,
     reduce_sum,
     scatter,
 )
 from .util import (
     get_backend,
 ...
python_module/megengine/distributed/functional.py

@@ -9,7 +9,7 @@
 from typing import Optional, Union

 import megengine._internal as mgb
-from megengine._internal.opr_param_defs import CollectiveComm as CollParam
+from megengine._internal.opr_param_defs import CollectiveComm as Param
 from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
 from ..functional import add_update

@@ -22,9 +22,16 @@ def _collective_comm(*args, **kargs):
     return collective_comm_symvar(*args, **kargs)


+def _group_check(*args):
+    """Return True when arguments are all None or all not None"""
+    l = [val is None for val in args]
+    return len(set(l)) <= 1
+
+
 def reduce_sum(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
 ) -> Tensor:

@@ -35,14 +42,17 @@ def reduce_sum(
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param is_root: whether this is a root node
     """
+    assert _group_check(
+        key, nr_ranks, is_root
+    ), "key, nr_ranks, is_root should be set at the same time"
     return _collective_comm(
-        tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
+        tensor, key, Param.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
     )


 def gather(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
     rank: Optional[int] = None,

@@ -55,20 +65,17 @@ def gather(
     :param is_root: whether this is a root node
     :param rank: rank of this node
     """
+    assert _group_check(
+        key, nr_ranks, is_root, rank
+    ), "key, nr_ranks, is_root, rank should be set at the same time"
     return _collective_comm(
-        tensor, key, CollParam.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device,
+        tensor, key, Param.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device,
     )


 def broadcast(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
 ) -> Tensor:

@@ -79,11 +86,12 @@ def broadcast(
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param is_root: whether this is a root node
     """
-    if key is None:
-        key = tensor._symvar.name
+    assert _group_check(
+        key, nr_ranks, is_root
+    ), "key, nr_ranks, is_root should be set at the same time"
     if is_root is None:
         is_root = get_rank() == 0
     if is_root:
         inp = tensor
     else:

@@ -92,7 +100,7 @@ def broadcast(
     return _collective_comm(
         inp,
         key,
-        CollParam.Mode.BROADCAST,
+        Param.Mode.BROADCAST,
         nr_ranks,
         is_root,
         dtype=tensor.dtype,

@@ -102,7 +110,7 @@ def broadcast(
 def scatter(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
     rank: Optional[int] = None,

@@ -115,6 +123,9 @@ def scatter(
     :param is_root: whether this is a root node
     :param rank: rank of this node
     """
+    assert _group_check(
+        key, nr_ranks, is_root, rank
+    ), "key, nr_ranks, is_root, rank should be set at the same time"
     if key is None:
         key = tensor._symvar.name
     if is_root is None:

@@ -128,7 +139,7 @@ def scatter(
     return _collective_comm(
         inp,
         key,
-        CollParam.Mode.SCATTER,
+        Param.Mode.SCATTER,
         nr_ranks,
         is_root,
         rank,

@@ -138,7 +149,11 @@ def scatter(
 def all_to_all(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
 ) -> Tensor:
     """Create all_to_all operator for collective communication

@@ -146,12 +161,22 @@ def all_to_all(
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param rank: rank of this node
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(
-        tensor, key, CollParam.Mode.ALL_TO_ALL, nr_ranks, rank=rank
-    )
+    assert _group_check(
+        key, nr_ranks, rank
+    ), "key, nr_ranks, rank should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_TO_ALL, nr_ranks, rank=rank, local_grad=local_grad,
+    )


 def all_gather(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
 ) -> Tensor:
     """Create all_gather operator for collective communication

@@ -159,12 +184,22 @@ def all_gather(
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param rank: rank of this node
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(
-        tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank
-    )
+    assert _group_check(
+        key, nr_ranks, rank
+    ), "key, nr_ranks, rank should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_GATHER, nr_ranks, rank=rank, local_grad=local_grad
+    )


 def reduce_scatter_sum(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
 ) -> Tensor:
     """Create reduce_scatter_sum operator for collective communication

@@ -172,45 +207,81 @@ def reduce_scatter_sum(
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param rank: rank of this node
+    :param local_grad: whether use local grad
     """
+    assert _group_check(
+        key, nr_ranks, rank
+    ), "key, nr_ranks, rank should be set at the same time"
     return _collective_comm(
-        tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank,
+        tensor, key, Param.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank, local_grad=local_grad,
     )


-def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
+def all_reduce_sum(
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    local_grad: Optional[bool] = False,
+) -> Tensor:
     """Create all_reduce_sum operator for collective communication

     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks)
+    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_REDUCE_SUM, nr_ranks, local_grad=local_grad
+    )


-def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
+def all_reduce_max(
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    local_grad: Optional[bool] = False,
+) -> Tensor:
     """Create all_reduce_max operator for collective communication

     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks)
+    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_REDUCE_MAX, nr_ranks, local_grad=local_grad
+    )


-def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
+def all_reduce_min(
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    local_grad: Optional[bool] = False,
+) -> Tensor:
     """Create all_reduce_min operator for collective communication

     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks)
+    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_REDUCE_MIN, nr_ranks, local_grad=local_grad
+    )


 def bcast_param(
     inp: Union[Buffer, Parameter],
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
 ) -> None:

@@ -223,6 +294,9 @@ def bcast_param(
     """
     if not is_distributed():
         return
+    assert _group_check(
+        key, nr_ranks, is_root
+    ), "key, nr_ranks, is_root should be set at the same time"
     assert isinstance(inp, (Buffer, Parameter))
     bcast_res = broadcast(inp, key, nr_ranks, is_root)
     add_update(inp, bcast_res, alpha=0)
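The practical upshot of the functional.py changes above: key, nr_ranks, is_root and rank may now be omitted together (defaults are filled in by helper.py below), and the reduce-style collectives take a local_grad flag. A minimal per-rank sketch of the new call style, assuming a process group has already been initialized the way the tests in this commit do it (the worker function and its arguments are illustrative, not part of this commit):

    import megengine.distributed as dist
    from megengine import tensor

    def worker(data):
        # runs inside each process of an already-initialized 2-rank group
        inp = tensor(data)
        # key/nr_ranks/rank can all be omitted now; defaults come from
        # get_group_id(), get_world_size() and get_rank()
        summed = dist.functional.all_reduce_sum(inp)
        # local_grad=True keeps the backward pass local instead of inserting
        # another collective_comm operator into the gradient graph
        gathered = dist.functional.all_gather(inp, local_grad=True)
        return summed, gathered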
python_module/megengine/distributed/helper.py

@@ -11,16 +11,24 @@ from typing import Optional, Union
 import megengine._internal as mgb
 from megengine._internal.opr_param_defs import CollectiveComm as CollParam

-from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size
+from .util import (
+    get_backend,
+    get_group_id,
+    get_master_ip,
+    get_master_port,
+    get_rank,
+    get_world_size,
+)


 def collective_comm_symvar(
     inp: Union[mgb.SymbolVar, mgb.CompGraph],
-    key: str,
-    op: CollParam.Mode,
+    key: Optional[str] = None,
+    op: CollParam.Mode = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
     rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
     dtype: Optional[type] = None,
     device: Optional[mgb.CompNode] = None,
     comp_graph: Optional[mgb.CompGraph] = None,

@@ -32,16 +40,19 @@ def collective_comm_symvar(
     :param op: mode of collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param is_root: whether this node is root node
+    :param rank: rank of this node
+    :param local_grad: whether use local grad
     :param dtype: output data type, use dtype of inp as default
     :param device: output comp node, use comp node of inp as default
     :param comp_graph: output comp graph, use comp graph of inp as default
     """
     return mgb.opr.collective_comm(
         inp,
-        key=str(key),
+        key=key if key is not None else ("collective_comm_" + str(get_group_id())),
         nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
         is_root=is_root if is_root is not None else (get_rank() == 0),
-        rank=rank if rank is not None else -1,
+        rank=rank if rank is not None else get_rank(),
+        local_grad=local_grad,
         server_addr=get_master_ip(),
         port=get_master_port(),
         param=CollParam(mode=op),
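The default-filling logic added to collective_comm_symvar is easy to miss in the diff noise; the sketch below restates it as a standalone helper (names follow the diff, the function itself is hypothetical and only for illustration):

    def resolve_defaults(key, nr_ranks, rank, get_group_id, get_world_size, get_rank):
        """Mirror of the defaults chosen in collective_comm_symvar above."""
        key = key if key is not None else "collective_comm_" + str(get_group_id())
        nr_devices = nr_ranks if nr_ranks is not None else get_world_size()
        rank = rank if rank is not None else get_rank()  # previously defaulted to -1
        return key, nr_devices, rank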
python_module/megengine/optimizer/optimizer.py

@@ -182,7 +182,9 @@ class Optimizer(metaclass=ABCMeta):
                 with opr_priority_scope(cg, -(2 ** 30)):
                     # always run all_reduce_mean first except add_update
                     grad = (
-                        all_reduce_sum(grad, "grad_" + str(get_group_id()))
+                        all_reduce_sum(
+                            grad, "grad_" + str(get_group_id()), get_world_size()
+                        )
                         / get_world_size()
                     )
                 with opr_priority_scope(cg, -(2 ** 31)):

@@ -229,7 +231,7 @@ class Optimizer(metaclass=ABCMeta):
         for group in self.param_groups:
             for param in group["params"]:
                 bcast_param(
-                    param, "bcast_param_" + str(key), is_root=(get_rank() == 0),
+                    param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0,
                 )
                 key += 1
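The optimizer change above only passes the world size explicitly; the averaging scheme itself is unchanged: sum the gradients across ranks with ALL_REDUCE_SUM, then divide by the world size. A toy NumPy check of that arithmetic (illustration only, no MegEngine involved):

    import numpy as np

    world_size = 2
    per_rank_grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]

    allreduced = sum(per_rank_grads)      # what every rank sees after ALL_REDUCE_SUM
    mean_grad = allreduced / world_size   # the "/ get_world_size()" step above
    assert np.allclose(mean_grad, [2.0, 3.0])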
python_module/src/cpp/opr_defs.cpp

@@ -94,9 +94,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,

 SymbolVar _Opr::collective_comm_with_input(
         SymbolVar inpvar, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
     SymbolVarArray inputs(1, inpvar);
     ComputingGraph* graph = inpvar.node()->owner_graph();

@@ -111,15 +111,15 @@ SymbolVar _Opr::collective_comm_with_input(
         _dtype = npy::dtype_np2mgb(dtype);
     }
     return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank,
-                                group_mgr, dev_buffer_arr, param, _dtype,
-                                backend, config, disable.get_val())[0];
+                                local_grad, group_mgr, dev_buffer_arr, param,
+                                _dtype, backend, config, disable.get_val())[0];
 }

 SymbolVar _Opr::collective_comm_without_input(
         CompGraph& cg, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
     SymbolVarArray inputs;
     auto& graph = cg.get();

@@ -134,8 +134,8 @@ SymbolVar _Opr::collective_comm_without_input(
         _dtype = npy::dtype_np2mgb(dtype);
     }
     return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank,
-                                group_mgr, dev_buffer_arr, param, _dtype,
-                                backend, config, disable.get_val())[0];
+                                local_grad, group_mgr, dev_buffer_arr, param,
+                                _dtype, backend, config, disable.get_val())[0];
 }

 #else

@@ -171,8 +171,8 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
 }

 SymbolVar _Opr::collective_comm_with_input(
-        SymbolVar inpvar, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank,
+        SymbolVar inpvar, const std::string& key, const size_t nr_devices,
+        const bool is_root, const int rank, const bool local_grad,
         const std::string& server_addr, const int port, PyObject* params,
         PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {

@@ -180,8 +180,8 @@ SymbolVar _Opr::collective_comm_with_input(
 }

 SymbolVar _Opr::collective_comm_without_input(
-        CompGraph& cg, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank,
+        CompGraph& cg, const std::string& key, const size_t nr_devices,
+        const bool is_root, const int rank, const bool local_grad,
         const std::string& server_addr, const int port, PyObject* params,
         PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
python_module/src/cpp/opr_defs.h

@@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port,

 static SymbolVar collective_comm_with_input(
         SymbolVar inpvar, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
-        const OperatorNodeConfig& config, const SharedScalar& disable);
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
+        const OperatorNodeConfig& config, const SharedScalar& disable);

 static SymbolVar collective_comm_without_input(
         CompGraph& graph, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
-        const OperatorNodeConfig& config, const SharedScalar& disable);
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
+        const OperatorNodeConfig& config, const SharedScalar& disable);

 // misc
 static SymbolVarArray extern_c_opr_placeholder(
python_module/test/unit/distributed/test_functional.py

@@ -34,7 +34,7 @@ def test_reduce_sum():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.reduce_sum(inp, "x")
+        output = dist.functional.reduce_sum(inp)
         if rank == 0:
             assert np.allclose(output.numpy(), expect)
         else:

@@ -70,7 +70,7 @@ def test_gather():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.gather(inp, "x", is_root=(rank == 0), rank=rank)
+        output = dist.functional.gather(inp)
         if rank == 0:
             assert np.allclose(output.numpy(), expect)
         else:

@@ -106,7 +106,7 @@ def test_broadcast():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.broadcast(inp, "x")
+        output = dist.functional.broadcast(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):

@@ -138,7 +138,7 @@ def test_scatter():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.scatter(inp, "x", is_root=(rank == 0), rank=rank)
+        output = dist.functional.scatter(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):

@@ -174,7 +174,7 @@ def test_all_to_all():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_to_all(inp, "x", rank=rank)
+        output = dist.functional.all_to_all(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):

@@ -208,7 +208,7 @@ def test_all_gather():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_gather(inp, "x", rank=rank)
+        output = dist.functional.all_gather(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):

@@ -241,7 +241,7 @@ def test_reduce_scatter_sum():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank)
+        output = dist.functional.reduce_scatter_sum(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):

@@ -278,7 +278,7 @@ def test_all_reduce_sum():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_reduce_sum(inp, "x")
+        output = dist.functional.all_reduce_sum(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):

@@ -311,7 +311,7 @@ def test_all_reduce_max():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_reduce_max(inp, "x")
+        output = dist.functional.all_reduce_max(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):

@@ -344,7 +344,7 @@ def test_all_reduce_min():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_reduce_min(inp, "x")
+        output = dist.functional.all_reduce_min(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):

@@ -377,7 +377,7 @@ def test_bcast_param():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = Parameter(data)
-        dist.functional.bcast_param(inp, "x")
+        dist.functional.bcast_param(inp)
         assert np.allclose(inp.numpy(), expect)

     def check(shape, backend):
src/gopt/impl/misc.cpp

@@ -688,6 +688,7 @@ bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {
     if (!opr->same_type<opr::CollectiveComm>()) return false;
     auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
     if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false;
+    if (comm.local_grad()) return false;
     if (comm.input().size() != 1) return false;

     auto grad = comm.input(0)->owner_opr();

@@ -839,7 +840,7 @@ void PackAllReduceReplacePass::insert_packed_oprs(
     std::string key = ssprintf("grad_pack_%zu", pack_id);
     auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
     SymbolVar allreduce = opr::CollectiveComm::make(
-            {concat}, graph, key, info->nr_devices, info->is_root, info->rank,
+            {concat}, graph, key, info->nr_devices, info->is_root, info->rank, false,
             info->group_client, param, info->dtype, info->backend)[0];

     // split according to recorded partition
src/gopt/test/misc.cpp

@@ -438,14 +438,14 @@ TEST_PASS(PackAllReduceScanPass, Basic) {
     auto grad3 = opr::VirtualGrad::make(y1, x1);

     auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
-    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), "grad0", 2, 0, 0, client, mode)[0];
-    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), "grad1", 2, 0, 0, client, mode)[0];
-    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), "grad2", 2, 0, 0, client, mode)[0];
-    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), "grad3", 2, 0, 0, client, mode)[0];
+    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), "grad0", 2, false, 0, false, client, mode)[0];
+    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), "grad1", 2, false, 0, false, client, mode)[0];
+    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), "grad2", 2, false, 0, false, client, mode)[0];
+    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), "grad3", 2, false, 0, false, client, mode)[0];

     gopt::GraphOptimizer()
             .add_pass<gopt::PackAllReduceScanPass>()

@@ -488,10 +488,12 @@ TEST_PASS(PackAllReduceReplacePass, CollectGroups) {
     auto grad = opr::VirtualGrad::make(target, wrt);

-    auto comm = opr::CollectiveComm::make(
-            {grad}, graph.get(), "key", 2, 0, 0, client,
-            opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
-            .node()->owner_opr();
+    auto comm = opr::CollectiveComm::make(
+            {grad}, graph.get(), "key", 2, false, 0, false, client,
+            opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
+            .node()->owner_opr();

     comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash);

@@ -543,8 +545,8 @@ TEST_PASS(PackAllReduceReplacePass, DividePacks) {
     auto insert_opr = [&] (size_t size) {
         auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)});
         auto sd = opr::SharedDeviceTensor::make(*graph, dev);
-        auto symvar = opr::CollectiveComm::make(
-                {sd}, graph.get(), "key", 2, 0, 0, client, mode)[0];
+        auto symvar = opr::CollectiveComm::make(
+                {sd}, graph.get(), "key", 2, false, 0, false, client, mode)[0];
         auto opr = symvar.node()->owner_opr();
         auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
         comm.set_pack_hash(1);

@@ -596,7 +598,6 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
     size_t nr_devices = 2;
     uint32_t rank = 0;
-    uint32_t root = 0;

     using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
     ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;

@@ -605,8 +606,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
     auto insert_opr = [&] (const TensorShape& shape) {
         auto dev = std::make_shared<DeviceTensorND>(cn, shape);
         auto sd = opr::SharedDeviceTensor::make(*graph, dev);
-        auto symvar = opr::CollectiveComm::make({sd}, graph.get(), "key",
-                nr_devices, rank, root, client, mode)[0];
+        auto symvar = opr::CollectiveComm::make({sd}, graph.get(), "key",
+                nr_devices, false, rank, false, client, mode)[0];
         auto opr = symvar.node()->owner_opr();
         auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
         comm.set_pack_hash(1);

@@ -634,8 +636,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
     auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0);

     std::string key = ssprintf("grad_pack_%zu", pack_id);
-    auto allreduce = opr::CollectiveComm::make({concat}, graph.get(), key,
-            nr_devices, rank, root, client, mode)[0];
+    auto allreduce = opr::CollectiveComm::make({concat}, graph.get(), key,
+            nr_devices, false, rank, false, client, mode)[0];

     std::vector<size_t> partition;
     partition.push_back(shape_x.total_nr_elems());

@@ -683,10 +686,14 @@ TEST_PASS(PackAllReduceReplacePass, Equivalence) {
         using Mode = opr::CollectiveComm::Param::Mode;
         bool is_root = (rank == 0);
-        auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(), "x", 2,
-                is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
-        auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(), "y", 2,
-                is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
+        auto reduced_x = opr::CollectiveComm::make(
+                {grad_x}, graph.get(), "x", 2, is_root, rank, false, client,
+                Mode::ALL_REDUCE_SUM)[0] / 2;
+        auto reduced_y = opr::CollectiveComm::make(
+                {grad_y}, graph.get(), "y", 2, is_root, rank, false, client,
+                Mode::ALL_REDUCE_SUM)[0] / 2;

         graph->options().allreduce_pack_max_size = 5000;
         graph->options().allreduce_pack_ignore_first = 0;
src/opr-mm/impl/collective_comm.cpp

@@ -14,6 +14,8 @@
 #include "megbrain/comp_node_env.h"
 #include "megbrain/graph/event.h"
 #include "megbrain/graph/grad_impl.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/megray_helper.h"
 #include "megbrain/opr/group_manager.h"

@@ -77,6 +79,8 @@ cudaStream_t get_stream(VarNode* var) {
     }
 }  // anonymous namespace

+/* ================= ModeTrait ================= */
+
 class CollectiveComm::ModeTrait {
     class BROADCAST;
     class REDUCE_SUM;

@@ -132,6 +136,42 @@ public:
         return None;
     }

+    VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const {
+        auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode();
+        SymbolVarArray og_syms;
+        og_syms.push_back(out_grad);
+
+        auto&& cn = opr->output(0)->comp_node();
+
+        auto gvar = CollectiveComm::make(
+                og_syms, opr->owner_graph(), opr->key() + ":grad",
+                opr->nr_devices(), opr->is_root(), opr->rank(), false,
+                opr->group_client(), mode, opr->dtype(), opr->backend(), {cn});
+
+        return gvar[0].node();
+    }
+
+    virtual VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const {
+        mgb_throw(MegBrainError,
+                  "only all_reduce all_to_all all_gather reduce_scatter "
+                  "support local_grad");
+    }
+
+    virtual VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const {
+        if (opr->local_grad()) {
+            return local_grad(out_grad, opr);
+        } else {
+            return full_grad(out_grad, opr);
+        }
+    }
+
+    VarNode* zeros(mgb::cg::ComputingGraph& graph, CompNode node,
+                   const SymbolVar& shape, DType dtype) const {
+        auto zero = SymbolVar::make_scalar(0, graph, node);
+        auto zero_tensor = opr::TypeCvt::make(zero, dtype).broadcast(shape);
+        return zero_tensor.node();
+    }
+
     virtual void get_output_var_shape(const CollectiveComm* opr,
                                       const TensorShapeArray& ishp,
                                       TensorShapeArray& oshp) = 0;

@@ -174,6 +214,17 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
     }

     Mode grad_mode() override { return Mode::REDUCE_SCATTER_SUM; }
+
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        auto nr_devices = opr->nr_devices();
+        auto rank = opr->rank();
+        opr::Subtensor::IndexDesc axis;
+        auto shape0 = opr::GetVarShape::make(out_grad, 0);
+        axis.push_back({0, shape0 * rank / (int)nr_devices,
+                        shape0 * (rank + 1) / (int)nr_devices});
+        auto grad = opr::Subtensor::make(out_grad, axis);
+        return grad.node();
+    }
 };

 class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {

@@ -211,9 +262,23 @@ class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
     }

     Mode grad_mode() override { return Mode::ALL_GATHER; }
-};

-/* ================= ModeTrait impls ================= */
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNodeArray grads;
+        auto zeros_tensor =
+                zeros(*out_grad->owner_graph(), out_grad->comp_node(),
+                      opr::GetVarShape::make(out_grad), out_grad->dtype());
+        for (size_t i = 0; i < opr->nr_devices(); i++) {
+            if (i == opr->rank()) {
+                grads.push_back(out_grad);
+            } else {
+                grads.push_back(zeros_tensor);
+            }
+        }
+        auto grad = opr::Concat::make(grads, 0);
+        return grad.node();
+    }
+};

 class CollectiveComm::ModeTrait::ReducedBasedTrait {
 protected:

@@ -250,6 +315,12 @@ class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
     }

     Mode grad_mode() override { return Mode::ALL_REDUCE_SUM; }
+
+public:
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        return out_grad;
+    }
 };

 class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {

@@ -258,10 +329,38 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {
 };

 class CollectiveComm::ModeTrait::ALL_REDUCE_MAX final : public AllReduceBase {
     MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MAX; }
+
+    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNode* grad;
+        if (opr->local_grad()) {
+            grad = local_grad(out_grad, opr);
+        } else {
+            grad = full_grad(out_grad, opr);
+        }
+        grad = opr::Elemwise::make({opr->output(0), opr->input(0), grad},
+                                   Elemwise::Mode::COND_LEQ_MOV)
+                       .node();
+        return grad;
+    }
 };

 class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {
     MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MIN; }
+
+    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNode* grad;
+        if (opr->local_grad()) {
+            grad = local_grad(out_grad, opr);
+        } else {
+            grad = full_grad(out_grad, opr);
+        }
+        grad = opr::Elemwise::make({opr->input(0), opr->output(0), grad},
+                                   Elemwise::Mode::COND_LEQ_MOV)
+                       .node();
+        return grad;
+    }
 };

 class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,

@@ -448,6 +547,24 @@ class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
     }

     Mode grad_mode() override { return Mode::ALL_TO_ALL; }
+
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNodeArray grads;
+        auto grad_shape = opr::GetVarShape::make(out_grad);
+        auto zeros_tensor =
+                zeros(*out_grad->owner_graph(), out_grad->comp_node(),
+                      grad_shape, out_grad->dtype());
+
+        auto nr_devices = opr->nr_devices();
+        auto rank = opr->rank();
+        opr::Subtensor::IndexDesc axis;
+        auto shape0 = opr::GetVarShape::make(out_grad, 0);
+        axis.push_back({0, shape0 * rank / (int)nr_devices,
+                        shape0 * (rank + 1) / (int)nr_devices});
+        auto sub_grad = opr::Subtensor::make(out_grad, axis);
+
+        return opr::SetSubtensor::make(zeros_tensor, sub_grad, axis).node();
+    }
 };

 CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {

@@ -469,8 +586,9 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
 CollectiveComm::CollectiveComm(
         VarNodeArray inputs, ComputingGraph* const graph,
         const std::string& key, const size_t nr_devices, const bool is_root,
-        const int rank, std::shared_ptr<GroupClient> group_client,
-        const Param& param, const DType& dtype, const std::string& backend,
+        const int rank, const bool local_grad,
+        std::shared_ptr<GroupClient> group_client, const Param& param,
+        const DType& dtype, const std::string& backend,
         const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable)

@@ -482,6 +600,7 @@ CollectiveComm::CollectiveComm(
           m_nr_devices(nr_devices),
           m_is_root(is_root),
           m_rank(rank),
+          m_local_grad(local_grad),
           m_key(key),
           m_dev_buffers(dev_buffer_arr),
           m_disable{disable} {

@@ -523,28 +642,31 @@ CollectiveComm::CollectiveComm(
 SymbolVarArray CollectiveComm::make(
         const SymbolVarArray& inputs, ComputingGraph* const graph,
         const std::string& key, const size_t nr_devices, const bool is_root,
-        const int rank, std::shared_ptr<GroupClient> group_client,
-        const Param& param, const DType& dtype, const std::string& backend,
+        const int rank, const bool local_grad,
+        std::shared_ptr<GroupClient> group_client, const Param& param,
+        const DType& dtype, const std::string& backend,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable) {
     SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices, nullptr);
-    return make(inputs, graph, key, nr_devices, is_root, rank, group_client,
-                dev_buffer_arr, param, dtype, backend, config);
+    return make(inputs, graph, key, nr_devices, is_root, rank, local_grad,
+                group_client, dev_buffer_arr, param, dtype, backend, config);
 }

 SymbolVarArray CollectiveComm::make(
         const SymbolVarArray& inputs, ComputingGraph* const graph,
         const std::string& key, const size_t nr_devices, const bool is_root,
-        const int rank, std::shared_ptr<GroupClient> group_client,
+        const int rank, const bool local_grad,
+        std::shared_ptr<GroupClient> group_client,
         const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
         const Param& param, const DType& dtype, const std::string& backend,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable) {
     auto inpvars = cg::to_var_node_array(inputs);
     auto opr = graph->insert_opr(std::make_unique<CollectiveComm>(
-            inpvars, graph, key, nr_devices, is_root, rank,
-            std::move(group_client), param, dtype, backend, dev_buffer_arr,
-            config, disable));
+            inpvars, graph, key, nr_devices, is_root, rank, local_grad,
+            std::move(group_client), param, dtype, backend, dev_buffer_arr,
+            config, disable));
     mgb_assert(!opr->output().empty());
     return cg::to_symbol_var_array(opr->output());
 }

@@ -647,93 +769,12 @@ void CollectiveComm::do_execute(ExecEnv& env) {
         owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn);
         trait.exec(this);
         owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn);
 #if CUDART_VERSION < 9000
 #pragma message "legacy CUDA; use sync to avoid blocking"
         // nccl hangs occasionally without this sync()
         cn.sync();
 #endif
     };
     env.dispatch_on_comp_node(cn, runner);
 }

 void CollectiveComm::on_output_comp_node_stream_changed() {}

-VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const {
-    auto mode = ModeTrait::from_mode(m_param.mode).grad_mode();
-    SymbolVarArray og_syms;
-
-    if (m_param.mode == Param::Mode::REDUCE_SUM) {
-        for (size_t i = 0; i < output().size(); i++) {
-            if (out_grads[i])
-                og_syms.push_back(out_grads[i]);
-        }
-        mgb_assert(og_syms.size() == 1);
-    } else {
-        for (size_t i = 0; i < output().size(); i++) {
-            if (!out_grads[i]) {
-                mgb_assert(m_param.mode != Param::Mode::REDUCE_SCATTER_SUM,
-                           "null out grad in CollctiveCommMM currently "
-                           "unsupported when the forward mode is "
-                           "Reduce_Scatter_Sum.");
-                DTypeScalar dval{output(i)->dtype()};
-                dval.set_retain_dtype(0);
-                auto zeros = SymbolVar::make_scalar(
-                        dval, *output(i)->owner_graph(), output(i)->comp_node())
-                        .broadcast(SymbolVar(output(i)).symshape());
-                og_syms.push_back(zeros);
-            } else {
-                og_syms.push_back(out_grads[i]);
-            }
-        }
-    }
-
-    OperatorNodeConfig::CompNodeArray cn_arr;
-    if (m_param.mode == Param::Mode::REDUCE_SUM) {
-        for (auto i : input()) {
-            cn_arr.push_back(i->comp_node());
-        }
-    } else if (m_param.mode == Param::Mode::BROADCAST) {
-        if (!input().empty()) {
-            cn_arr.push_back(input(0)->comp_node());
-        }
-    }
-
-    auto gvar = CollectiveComm::make(og_syms, owner_graph(), m_key + ":grad",
-                                     m_nr_devices, m_is_root, m_rank,
-                                     m_group_client, mode, m_dtype, m_backend,
-                                     OperatorNodeConfig{}.comp_node_arr(cn_arr));
-
-    if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) {
-        for (size_t i = 0; i < input().size(); ++i) {
-            gvar[i] = Elemwise::make({output(i), input(i), gvar[i]},
-                                     Elemwise::Mode::COND_LEQ_MOV);
-        }
-    } else if (m_param.mode == Param::Mode::ALL_REDUCE_MIN) {
-        for (size_t i = 0; i < input().size(); ++i) {
-            gvar[i] = Elemwise::make({input(i), output(i), gvar[i]},
-                                     Elemwise::Mode::COND_LEQ_MOV);
-        }
-    } else if (m_param.mode == Param::Mode::BROADCAST) {
-        if (!input().empty()) {
-            CompNode&& master_out_cn = input(0)->comp_node();
-            SymbolVarArray rst;
-            for (auto i : gvar) {
-                if (i.node()->comp_node() == master_out_cn) {
-                    mgb_assert(rst.empty());
-                    rst.push_back(i);
-                }
-            }
-            gvar = rst;
-        }
-    }
-    return cg::to_var_node_array(gvar);
-}
-
-MGB_IMPL_OPR_GRAD(CollectiveComm) {
-    return opr.grad(out_grad);
-}
-
 void CollectiveComm::init_output_dtype() {
     if (m_dtype.valid()) {
         for (size_t i = 0; i < input().size(); ++i) {

@@ -797,6 +838,15 @@ void CollectiveComm::init_output_static_infer_desc() {
     }
 }

+VarNode* CollectiveComm::grad(VarNode* out_grad) const {
+    return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
+}
+
+MGB_IMPL_OPR_GRAD(CollectiveComm) {
+    mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad");
+    return opr.grad(out_grad[0]);
+}
+
 /* ===================== shallow copy ===================== */

 namespace mgb {

@@ -807,13 +857,14 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>();
-    auto new_opr = CollectiveComm::make(
-            to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
-            opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
-            opr.group_client(), opr.dev_buffers(), opr.param(), opr.dtype(),
-            opr.backend(), config)[0]
-            .node()->owner_opr();
+    auto new_opr = CollectiveComm::make(
+            to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
+            opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
+            opr.local_grad(), opr.group_client(), opr.dev_buffers(),
+            opr.param(), opr.dtype(), opr.backend(), config)[0]
+            .node()->owner_opr();
     new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash());
     return new_opr;
 }
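To make the new local_grad traits above concrete: with local_grad enabled, each mode computes a purely local backward instead of launching another collective. all_reduce passes the output gradient through unchanged, while all_gather (and the gather half of all_to_all) just keeps the chunk this rank contributed. A rough NumPy picture of the ALL_GATHER case (illustration only; values and shapes are made up):

    import numpy as np

    nr_devices, rank = 2, 1
    out_grad = np.arange(8.0)                 # gradient w.r.t. the gathered output
    chunk = out_grad.shape[0] // nr_devices   # same split the Subtensor above performs
    local = out_grad[rank * chunk:(rank + 1) * chunk]
    print(local)                              # [4. 5. 6. 7.] -- this rank's local grad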
src/opr-mm/impl/collective_comm.oprdecl

@@ -8,6 +8,7 @@ decl_raw_opr(
             'operation to which this operator belongs.', 'int'),
         Doc('is_root', 'whether this node is root node', 'bool'),
         Doc('rank', 'rank of this node, if is -1, generate one', 'int'),
+        Doc('local_grad', 'whether use local grad', 'bool'),
         Doc('server_addr', 'rpc server ip address'),
         Doc('port', 'server rpc listening port'),
         Doc('param', 'The only component of *param* is *mode*, which refers to '

@@ -28,12 +29,12 @@ decl_raw_opr(
     body=[
         'if isinstance(input, _mgb.SymbolVar):',
         ('    output = _mgb._Opr.collective_comm_with_input(input, key, '
-         'nr_devices, is_root, rank, server_addr, port, '
+         'nr_devices, is_root, rank, local_grad, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)'),
         'else:',
        '    assert isinstance(input, _mgb.CompGraph)',
         ('    output = _mgb._Opr.collective_comm_without_input(input, key, '
-         'nr_devices, is_root, rank, server_addr, port, '
+         'nr_devices, is_root, rank, local_grad, server_addr, port, '
        '[param.serialize()], dtype, backend, output_buffer, config, disable)')
     ],
     desc=('collective communication between multiple CompNodes on multiple '
src/opr-mm/include/megbrain/opr/collective_comm.h

@@ -29,8 +29,9 @@ public:
     CollectiveComm(
             VarNodeArray inputs, ComputingGraph* const graph,
             const std::string& key, const size_t nr_devices, const bool is_root,
-            const int rank, std::shared_ptr<GroupClient> group_client,
-            const Param& param, const DType& dtype, const std::string& backend,
+            const int rank, const bool local_grad,
+            std::shared_ptr<GroupClient> group_client, const Param& param,
+            const DType& dtype, const std::string& backend,
             const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
             const OperatorNodeConfig& config,
             const std::shared_ptr<DTypeScalar>& disable);

@@ -38,7 +39,8 @@ public:
     static SymbolVarArray make(
             const SymbolVarArray& inputs, ComputingGraph* const graph,
             const std::string& key, const size_t nr_devices, const bool is_root,
-            const int rank, std::shared_ptr<GroupClient> group_client,
+            const int rank, const bool local_grad,
+            std::shared_ptr<GroupClient> group_client,
             const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
             const Param& param, const DType& dtype = {},
             const std::string& backend = "nccl",

@@ -50,6 +52,7 @@ public:
             ComputingGraph* const graph, const std::string& key,
             const size_t nr_devices, const bool is_root, const int rank,
+            const bool local_grad,
             std::shared_ptr<GroupClient> group_client, const Param& param,
             const DType& dtype = {}, const std::string& backend = "nccl",

@@ -72,6 +75,7 @@ public:
     int rank() const { return m_rank; }
     int root() const { return m_root; }
     bool is_root() const { return m_is_root; }
+    bool local_grad() const { return m_local_grad; }

     //! The key that identifies an NCCL clique.
     //! Operators with same keys belong to the same clique.

@@ -89,7 +93,7 @@ public:
         return m_megray_ctx;
     }

-    VarNodeArray grad(const VarNodeArray& out_grad) const;
+    VarNode* grad(VarNode* out_grad) const;

 private:
     Barrier m_exec_barrier;

@@ -116,6 +120,7 @@ private:
     size_t m_nr_devices = 0;
     bool m_is_root;
     int m_rank;
+    bool m_local_grad;
     std::string m_key;
     //! XXHash generated from m_key
     size_t m_hash;
src/opr-mm/test/collective_comm.cpp

(Diff collapsed: +641 −208, not shown.)
tools/param_defs/mgb_opr_param_defs.py

@@ -46,7 +46,7 @@ pdef('PersistentOutputStorage').add_fields(
 (pdef('CollectiveComm', 'collective communication between multiple computing '
       'nodes on localhost')
- .add_enum('Mode',
+ .add_enum(Doc('Mode', 'mode of collective communication'),
           Doc('REDUCE_SUM', 'reduce by sum to output computing node'),
           Doc('BROADCAST', 'copy input value to each output computing node'),
           Doc('ALL_GATHER', 'each output comp node gets the concatenated '

@@ -59,7 +59,8 @@ pdef('PersistentOutputStorage').add_fields(
           Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'),
           Doc('GATHER', 'concat inputs to one node'),
           Doc('SCATTER', 'scatter input to each output computing node'),
-          Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node')))
+          Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'),
+          name_field='mode'))

 (pdef('FakeSerializedDType',
       'HACK: The tag of this param def is actually used for another '