Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c20d4cc6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c20d4cc6
编写于
8月 27, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): fix opt pass nchw44 can not dump resnet
GitOrigin-RevId: 28e5c37f53349d482b191751923b5a4b05b0633d
上级
3dbac4f4
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
24 addition
and
5 deletion
+24
-5
src/gopt/impl/tensor_reformat.cpp
src/gopt/impl/tensor_reformat.cpp
+13
-1
src/gopt/test/inference.cpp
src/gopt/test/inference.cpp
+11
-4
未找到文件。
src/gopt/impl/tensor_reformat.cpp
浏览文件 @
c20d4cc6
...
...
@@ -1815,6 +1815,15 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
return
new_var
;
}
static
inline
TensorShape
nchwxx_shape_2_nchw_shape
(
const
TensorShape
&
origin_shape
)
{
mgb_assert
(
origin_shape
.
ndim
==
5
);
TensorShape
result
=
origin_shape
;
result
[
1
]
*=
result
[
4
];
result
.
ndim
=
4
;
return
result
;
}
template
<
typename
OprType
>
static
inline
bool
nchw_nchwxx_valid
(
const
OprType
&
opr
,
const
VarNodeArray
&
new_inp
,
const
size_t
pack_size
,
...
...
@@ -1847,7 +1856,10 @@ static inline bool nchw_nchwxx_valid(
megdnn
::
ConvBiasForward
::
BiasMode
bias_mode
=
megdnn
::
ConvBiasForward
::
BiasMode
::
NO_BIAS
;
if
(
std
::
is_same
<
OprType
,
opr
::
ConvBiasForward
>::
value
)
{
auto
&
bias_shape
=
new_inp
[
2
]
->
shape
();
TensorShape
bias_shape
=
new_inp
[
2
]
->
shape
();
if
(
bias_shape
.
ndim
==
5
)
{
bias_shape
=
nchwxx_shape_2_nchw_shape
(
bias_shape
);
}
if
(
bias_shape
.
ndim
==
0
)
{
bias_mode
=
megdnn
::
ConvBiasForward
::
BiasMode
::
NO_BIAS
;
}
else
if
(
bias_shape
.
eq_shape
(
dst_node
->
shape
()))
{
...
...
src/gopt/test/inference.cpp
浏览文件 @
c20d4cc6
...
...
@@ -3069,12 +3069,18 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
//! Dense
param_conv_bias
.
sparse
=
opr
::
ConvBias
::
Param
::
Sparse
::
DENSE
;
auto
w4
=
mkcvar
(
"w4"
,
{
4
,
32
,
3
,
3
}),
b4
=
mkcvar
(
"b4"
,
{
1
,
4
,
1
,
1
}),
auto
w4
=
mkcvar
(
"w4"
,
{
16
,
32
,
3
,
3
}),
b4
=
mkcvar
(
"b4"
,
{
1
,
16
,
1
,
1
}),
conv4
=
opr
::
ConvBias
::
make
(
conv3_3
,
w4
,
b4
,
param_conv_bias
,
{},
OperatorNodeConfig
(
"conv4"
));
auto
w5
=
mkcvar
(
"w5"
,
{
6
,
4
,
3
,
3
}),
b5
=
mkcvar
(
"b5"
,
{
1
,
6
,
1
,
1
}),
conv5
=
opr
::
ConvBias
::
make
(
conv4
,
w5
,
b5
,
param_conv_bias
,
{},
auto
w4_1
=
mkcvar
(
"w4_1"
,
{
16
,
32
,
1
,
1
}),
b4_1
=
mkcvar
(
"b4_1"
,
{
2
,
16
,
4
,
4
}),
conv4_1
=
opr
::
ConvBias
::
make
(
conv3_3
,
w4_1
,
b4_1
,
param_conv_bias_pad0
,
{},
OperatorNodeConfig
(
"conv4_1"
));
auto
conv4_add
=
conv4
+
conv4_1
;
auto
w5
=
mkcvar
(
"w5"
,
{
6
,
16
,
3
,
3
}),
b5
=
mkcvar
(
"b5"
,
{
1
,
6
,
1
,
1
}),
conv5
=
opr
::
ConvBias
::
make
(
conv4_add
,
w5
,
b5
,
param_conv_bias
,
{},
OperatorNodeConfig
(
"conv5"
));
auto
w6
=
mkcvar
(
"w6"
,
{
4
,
6
,
3
,
3
}),
b6
=
mkcvar
(
"b6"
,
{
1
,
4
,
1
,
1
}),
y
=
opr
::
ConvBias
::
make
(
conv5
,
w6
,
b6
,
param_conv_bias
,
{},
...
...
@@ -3082,6 +3088,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
SymbolVar
y_opt
;
auto
options
=
gopt
::
OptimizeForInferenceOptions
{};
options
.
enable_fuse_conv_bias_nonlinearity
();
options
.
enable_nchw44
();
unpack_vector
(
gopt
::
optimize_for_inference
({
y
},
options
),
y_opt
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录