Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b82e8f00
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b82e8f00
编写于
9月 16, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(gopt): refact the padding channel opt pass
GitOrigin-RevId: ee3f55aa66f21fe2d4a042298aafe4a0a02915f7
上级
f444d4fe
变更
4
展开全部
隐藏空白更改
内联
并排
Showing
4 changed file
with
327 addition
and
307 deletion
+327
-307
src/gopt/impl/framework.cpp
src/gopt/impl/framework.cpp
+2
-1
src/gopt/impl/padding_channel.cpp
src/gopt/impl/padding_channel.cpp
+286
-302
src/gopt/include/megbrain/gopt/inference.h
src/gopt/include/megbrain/gopt/inference.h
+30
-0
src/gopt/test/inference.cpp
src/gopt/test/inference.cpp
+9
-4
未找到文件。
src/gopt/impl/framework.cpp
浏览文件 @
b82e8f00
...
...
@@ -783,7 +783,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
});
cb
(
nchw64
,
{
add_pass
<
FuseConvBiasNonlinPass
>
();
add_pass
<
PaddingChannelPass
>
();
add_pass
(
PaddingChannelPass
::
make
(
cg
::
GraphCommonOptimizeOptions
::
LayoutTransform
::
NCHW64
));
add_pass
<
FuseConvBiasZPass
>
();
add_pass
(
EnableNCHW64Pass
::
make_nchw64_converter
());
add_pass
<
ShuffleShuffleRemovePass
>
();
...
...
src/gopt/impl/padding_channel.cpp
浏览文件 @
b82e8f00
此差异已折叠。
点击以展开。
src/gopt/include/megbrain/gopt/inference.h
浏览文件 @
b82e8f00
...
...
@@ -509,8 +509,38 @@ public:
*/
class
PaddingChannelPass
final
:
public
Pass
{
public:
using
ChannelAlignmentMap
=
ThinHashMap
<
DTypeEnum
,
std
::
function
<
size_t
(
size_t
,
bool
)
>>
;
using
LayoutTrans
=
cg
::
GraphCommonOptimizeOptions
::
LayoutTransform
;
const
char
*
name
()
const
override
;
void
apply
(
OptState
&
opt
)
const
override
;
void
fill_opr_convert_fun
(
LayoutTrans
layout_trans
);
using
ReplaceFuncs
=
ThinHashMap
<
Typeinfo
*
,
thin_function
<
OperatorNodeBase
*
(
OperatorNodeBase
*
,
const
VarNodeArray
&
)
>>
;
//! make channel padding opt pass with given tensor format
static
std
::
unique_ptr
<
PaddingChannelPass
>
make
(
LayoutTrans
layout_transform
);
private:
VarNode
*
extract_subtensor
(
VarNode
*
inp
,
const
TensorShape
&
orig_shape
)
const
;
VarNode
*
pad_in_channels
(
VarNode
*
inp
,
size_t
pad_channels
);
VarNode
*
pad_out_channels
(
VarNode
*
inp
,
size_t
pad_channels
);
OperatorNodeBase
*
padding_policy
(
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
new_inp
);
void
add_convbias_replace_func
(
LayoutTrans
layout_transform
);
void
add_conv_backward_data_replace_func
(
LayoutTrans
layout_transform
);
void
add_format_aware_opr_replace_func
(
LayoutTrans
layout_transform
);
void
add_elemwise_like_opr_replace_func
(
LayoutTrans
layout_transform
);
void
add_nonpadding_oprs_replace_func
(
LayoutTrans
layout_transform
);
ChannelAlignmentMap
m_alignment_map
;
ReplaceFuncs
m_opr_replace_funcs
;
mutable
ThinHashSet
<
OperatorNodeBase
*>
m_padding_oprs
;
};
/*!
...
...
src/gopt/test/inference.cpp
浏览文件 @
b82e8f00
#include "megbrain/graph/cg.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/test/helper.h"
...
...
@@ -5037,7 +5038,8 @@ TEST(TestGoptInference, PaddingChannels) {
SymbolVar
y3_pad
;
unpack_vector
(
gopt
::
GraphOptimizer
{}
.
add_pass
<
gopt
::
PaddingChannelPass
>
()
.
add_pass
(
gopt
::
PaddingChannelPass
::
make
(
cg
::
GraphCommonOptimizeOptions
::
LayoutTransform
::
NCHW64
))
.
apply
({{
y3
}})
.
endpoint_vars
(),
y3_pad
);
...
...
@@ -5101,7 +5103,8 @@ TEST(TestGoptInference, ConcatAfterPaddingChannels) {
SymbolVar
y2_pad
;
unpack_vector
(
gopt
::
GraphOptimizer
{}
.
add_pass
<
gopt
::
PaddingChannelPass
>
()
.
add_pass
(
gopt
::
PaddingChannelPass
::
make
(
cg
::
GraphCommonOptimizeOptions
::
LayoutTransform
::
NCHW64
))
.
apply
({{
y2
}})
.
endpoint_vars
(),
y2_pad
);
...
...
@@ -5166,7 +5169,8 @@ TEST(TestGoptInference, PaddingChannelsWithPooling) {
SymbolVar
y1_pad
;
unpack_vector
(
gopt
::
GraphOptimizer
{}
.
add_pass
<
gopt
::
PaddingChannelPass
>
()
.
add_pass
(
gopt
::
PaddingChannelPass
::
make
(
cg
::
GraphCommonOptimizeOptions
::
LayoutTransform
::
NCHW64
))
.
apply
({{
y1
}})
.
endpoint_vars
(),
y1_pad
);
...
...
@@ -5232,7 +5236,8 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) {
SymbolVar
y1_pad
;
unpack_vector
(
gopt
::
GraphOptimizer
{}
.
add_pass
<
gopt
::
PaddingChannelPass
>
()
.
add_pass
(
gopt
::
PaddingChannelPass
::
make
(
cg
::
GraphCommonOptimizeOptions
::
LayoutTransform
::
NCHW64
))
.
apply
({{
y1
}})
.
endpoint_vars
(),
y1_pad
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录