Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
76b28408
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
76b28408
编写于
7月 22, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/gopt): add subgraph extractor
GitOrigin-RevId: 56fd701c2c86aaa34e08a01fa1faa75a7dc50000
上级
8a3eb05a
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
141 addition
and
0 deletion
+141
-0
src/gopt/impl/subgraph_extractor.cpp
src/gopt/impl/subgraph_extractor.cpp
+101
-0
src/gopt/include/megbrain/gopt/subgraph_extractor.h
src/gopt/include/megbrain/gopt/subgraph_extractor.h
+40
-0
未找到文件。
src/gopt/impl/subgraph_extractor.cpp
0 → 100644
浏览文件 @
76b28408
/**
* \file src/gopt/impl/subgraph_extractor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/gopt/subgraph_extractor.h"
using
namespace
mgb
;
using
namespace
cg
;
using
namespace
gopt
;
/* ================== SubGraphExtractor =================*/
std
::
vector
<
InternalGraph
>
SubGraphExtractor
::
extract
(
const
SymbolVarArray
&
endpoint_vars
)
const
{
ThinHashMap
<
OperatorNodeBase
*
,
std
::
pair
<
OperatorNodeBase
*
,
int
>>
parent
;
thin_function
<
OperatorNodeBase
*
(
OperatorNodeBase
*
)
>
union_find
;
auto
union_find
=
[
&
parent
,
&
union_find
](
OperatorNodeBase
*
o
)
{
if
(
parent
[
o
].
first
==
o
)
return
o
;
else
{
auto
p
=
union_find
(
parent
[
o
].
first
);
parent
[
o
].
first
=
p
;
return
p
;
}
};
auto
union_merge
=
[
&
parent
,
&
union_find
](
OperatorNodeBase
*
x
,
OperatorNodeBase
*
y
)
{
auto
root_x
=
union_find
(
x
),
root_y
=
union_find
(
y
);
if
(
root_x
!=
root_y
)
{
OperatorNodeBase
*
large
,
small
;
if
(
parent
[
root_x
].
second
<
parent
[
root_y
].
second
)
{
small
=
root_x
,
large
=
root_y
;
}
else
{
small
=
root_y
,
large
=
root_x
;
}
parent
[
small
].
first
=
large
;
if
(
parent
[
large
].
second
==
parent
[
small
].
second
)
{
parend
[
large
].
second
+=
1
;
}
}
};
std
::
vector
<
OperatorNodeBase
*>
topo
;
auto
cb
=
[
&
topo
](
OperatorNodeBase
*
opr
)
{
topo
.
push_back
(
opr
);
if
(
opr_list
.
count
(
opr
->
dyn_typeinfo
())
==
0
)
return
;
auto
find
=
parent
.
find
(
opr
);
if
(
find
==
parent
.
end
())
{
auto
insert
=
parent
.
insert
(
std
::
make_pair
(
opr
,
std
::
make_pair
(
opr
,
0
)));
find
=
insert
.
first
;
}
for
(
auto
&&
i
:
opr
->
input
())
{
auto
&&
o
=
i
->
owner_opr
();
if
(
opr_list
.
count
(
o
->
dyn_typeinfo
())
==
0
)
continue
;
union_merge
(
opr
,
o
);
}
};
cg
::
DepOprIter
iter
{
cb
};
for
(
const
auto
&
v
:
endpoint_vars
)
iter
.
add
(
v
.
node
()
->
owner_opr
());
std
::
vector
<
InternalGraph
>
partitions
;
ThinHashMap
<
OperatorNodeBase
*
,
InternalGraph
*>
roots
;
for
(
const
auto
&
opr
:
reverse_adaptor
(
topo
))
{
auto
root
=
union_find
(
opr
);
auto
find
=
roots
.
find
(
root
);
InternalGraph
*
internal_graph
=
nullptr
;
if
(
find
==
roots
.
end
())
{
partitions
.
emplace_back
(
InternalGraph
{});
auto
insert
=
roots
.
insert
(
std
::
make_pair
(
root
,
&
partitions
.
back
()));
internal_graph
=
insert
.
first
->
second
;
internal_graph
->
m_outputs
.
insert
(
opr
->
output
(
0
));
}
else
{
internal_graph
=
find
->
second
;
auto
erase
=
internal_graph
->
m_inputs
.
erase
(
opr
->
output
(
0
));
if
(
erase
>
0
)
{
internal_graph
->
m_internals
.
insert
(
opr
->
output
(
0
));
}
else
{
internal_graph
->
m_outputs
.
insert
(
opr
->
output
(
0
));
}
}
for
(
const
auto
&
i
:
opr
->
input
())
internal_graph
->
m_inputs
.
insert
(
i
);
}
return
partitions
;
}
/* ============= SubGraphExtractor =================*/
// vim: syntax=cpp.doxygen
src/gopt/include/megbrain/gopt/subgraph_extractor.h
0 → 100644
浏览文件 @
76b28408
/**
* \file src/gopt/include/megbrain/gopt/subgraph_extractor.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/graph.h"
namespace
mgb
{
namespace
gopt
{
struct
InternalGraph
{
ThinHashSet
<
VarNode
*>
m_internals
;
ThinHashSet
<
VarNode
*>
m_inputs
;
ThinHashSet
<
VarNode
*>
m_outputs
;
};
class
SubGraphExtractor
{
public:
using
OprList
=
ThinHashSet
<
Typeinfo
*>
;
SubGraphExtractor
(
OprList
opr_list
)
:
m_opr_list
{
opr_list
}
{};
std
::
vector
<
InternalGraph
>
extract
(
const
SymbolVarArray
&
endpoint_vars
)
const
;
private:
class
Impl
;
OprList
m_opr_list
;
};
}
// namespace gopt
}
// namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录