MegEngine 天元 / MegEngine
Commit 1bce857c

Authored on Aug 15, 2020 by Megvii Engine Team
Committed by Xinran Xu on Aug 25, 2020
fix(mgb/opr-mm): use comp_node of config as default in CollectiveComm
GitOrigin-RevId: 6b43c9fc93a5bdcffa12d81179c1d74d6f96ce56
Parent commit: 27205461
Showing 1 changed file, with 18 additions and 92 deletions:

src/opr-mm/impl/collective_comm.cpp (+18, -92)
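In effect, the constructor now derives the output's comp node from the OperatorNodeConfig when one is given, and only falls back to the comp node of the (single) input otherwise. The following is a minimal standalone sketch of that rule, for illustration only, with stand-in types; it is not code from this diff:

    #include <vector>

    // Stand-ins for MegEngine's types, for illustration only.
    struct CompNode {};
    using CompNodeVec = std::vector<CompNode>;

    // The default rule this commit introduces: a comp node set on the
    // config wins; otherwise use the comp node of the single input.
    CompNode select_output_comp_node(const CompNodeVec& config_cns,
                                     const CompNode& input_cn) {
        if (!config_cns.empty())
            return config_cns[0];  // config.comp_node() takes precedence
        return input_cn;           // fall back to inputs[0]->comp_node()
    }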
@@ -107,27 +107,9 @@ protected:
         }
     }
 
-    static void add_output_var_all2all(CollectiveComm* opr) {
-        mgb_assert(opr->nr_devices() >= 2);
-        auto pname = get_param_name(opr->param());
-        // sublinear would setup opr->config if inputs.size() is 1,
-        // bypass this situation
-        mgb_assert(!opr->config().has_comp_node_set() ||
-                           opr->input().size() == 1,
-                   "comp node should not be set in %s mode", pname);
-        for (auto i : opr->input()) {
-            opr->add_output(ssprintf("%s:%s", pname, i->cname()))
-                    ->comp_node(i->comp_node());
-        }
-    }
-
 public:
     virtual ~ModeTrait() = default;
 
-    //! add output var for the opr
-    virtual void add_output_var(CollectiveComm* opr,
-                                const CompNode::UnorderedSet& inp_cn) = 0;
-
     /*!
      * \brief the vars on whose comp node the computing should be performed
      * if None, output vars would be used
@@ -188,11 +170,6 @@
 };
 
 class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -231,11 +208,6 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
 };
 
 class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -292,11 +264,6 @@ protected:
 class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
                                                  public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm*,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -368,11 +335,6 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {
 class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
                                               public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet& inp_cn) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -413,19 +375,6 @@ class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase {
 };
 
 class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        if (opr->input().size() > 0) {
-            add_output_var_all2all(opr);
-            return;
-        }
-        const auto& cns = opr->config().comp_node();
-        mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu",
-                   cns.size());
-        auto pname = get_param_name(opr->param());
-        opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))
-                ->comp_node(cns[0]);
-    }
-
     void get_output_var_shape(const CollectiveComm*,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -462,11 +411,6 @@ class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
 };
 
 class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -501,19 +445,6 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
 };
 
 class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        if (opr->input().size() > 0) {
-            add_output_var_all2all(opr);
-            return;
-        }
-        const auto& cns = opr->config().comp_node();
-        mgb_assert(cns.size() == 1, "exactly one comp_node expected, got %zu",
-                   cns.size());
-        auto pname = get_param_name(opr->param());
-        opr->add_output(ssprintf("%s:%s", pname, opr->key().c_str()))
-                ->comp_node(cns[0]);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -537,11 +468,6 @@ class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
 };
 
 class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
-    void add_output_var(CollectiveComm* opr,
-                        const CompNode::UnorderedSet&) override {
-        add_output_var_all2all(opr);
-    }
-
     void get_output_var_shape(const CollectiveComm* opr,
                               const TensorShapeArray& ishp,
                               TensorShapeArray& oshp) override {
@@ -617,35 +543,35 @@ CollectiveComm::CollectiveComm(
           m_key(key),
           m_dev_buffers(dev_buffer_arr),
           m_disable{disable} {
-    for (auto i : inputs) {
-        mgb_assert(i->comp_node().device_type() == CompNode::DeviceType::CUDA,
-                   "CollectiveComm currectly only supports CUDA");
-    }
-    for (auto i : config.comp_node()) {
-        mgb_assert(i.device_type() == CompNode::DeviceType::CUDA,
+    // add input
+    mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu",
+               inputs.size());
+    if (inputs.size() > 0) {
+        mgb_assert(inputs[0]->comp_node().device_type() ==
+                           CompNode::DeviceType::CUDA,
                    "CollectiveComm currectly only supports CUDA");
+        add_input({inputs[0]});
     }
-    CompNode::UnorderedSet inp_cn;
-    ThinHashSet<int> inp_dev;
-    for (auto i : inputs) {
-        add_input({i});
-        inp_cn.insert(i->comp_node());
-        inp_dev.insert(CompNodeEnv::from_comp_node(i->comp_node())
-                               .cuda_env()
-                               .device);
-    }
+
+    // add output
+    add_output(ssprintf("%s:%s", get_param_name(param), key.c_str()));
+
+    // set comp node
+    const auto& cns = config.comp_node();
+    mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu",
+               cns.size());
+    if (cns.size() > 0) {
+        mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
+                   "CollectiveComm currectly only supports CUDA");
+        output(0)->comp_node(cns[0]);
+    } else {
+        output(0)->comp_node(inputs[0]->comp_node());
+    }
-    mgb_assert(inp_dev.size() == inputs.size(),
-               "CollectiveComm inputs should not contain duplicated input device");
-    ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn);
 
+    // set debug flag
     const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG");
     if (c_debug != nullptr and strcmp(c_debug, "1") == 0) {
         m_debug_mode = true;
     }
 
+    // deduplication
     add_equivalence_component<PODHash<Param>>(&m_param);
     add_equivalence_component<PODHash<size_t>>(&m_nr_devices);
     m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest();
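As an aside, the debug gate kept at the end of the constructor is the usual getenv-and-compare idiom behind MegEngine's MGB_GETENV wrapper. A self-contained sketch using only the C++ standard library (hypothetical helper name, not part of the diff):

    #include <cstdlib>
    #include <cstring>

    // Enabled only when MGE_MM_OPR_DEBUG is set to exactly "1",
    // mirroring the check in the CollectiveComm constructor above.
    static bool mm_opr_debug_enabled() {
        const char* v = std::getenv("MGE_MM_OPR_DEBUG");
        return v != nullptr && std::strcmp(v, "1") == 0;
    }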