MegEngine commit ee2e2b3c
Authored Sep 22, 2020 by Megvii Engine Team

fix(mgb/gopt): fix nchwxx optpass of no handle conv_bias opr which with no bias

GitOrigin-RevId: b2b053add464540c22a61e72f078950c18bf92b0
Parent: 59a9275c
Showing 2 changed files with 106 additions and 48 deletions (+106 / -48):

  src/gopt/impl/tensor_reformat.cpp  +104 -44
  src/gopt/test/inference.cpp          +2  -4
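Context for the fix: before this patch, the NCHWXX conversion passes assumed every conv_bias operator carries a bias input and read new_inp[2] unconditionally, so a ConvBias built from only (src, filter) was not handled. Below is a minimal, self-contained sketch of the guarded-access pattern the patch introduces, using simplified hypothetical types rather than MegEngine's VarNode/VarNodeArray:

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Stand-in for a graph variable; only the rank matters for this sketch.
struct Var {
    int ndim;
};

// Bias is optional: a conv_bias node may have inputs (src, filter) or
// (src, filter, bias).  Only touch new_inp[2] when it actually exists,
// mirroring the `new_inp.size() > 2` checks added throughout the patch.
const Var* pick_bias(const std::vector<const Var*>& new_inp) {
    if (new_inp.size() > 2) {
        assert(new_inp[2]->ndim == 4 || new_inp[2]->ndim == 5);
        return new_inp[2];
    }
    return nullptr;  // NO_BIAS case the pass previously did not handle
}

int main() {
    Var src{4}, filter{4}, bias{4};
    std::vector<const Var*> with_bias{&src, &filter, &bias};
    std::vector<const Var*> without_bias{&src, &filter};
    std::printf("with bias -> %s\n", pick_bias(with_bias) ? "bias var" : "none");
    std::printf("without bias -> %s\n", pick_bias(without_bias) ? "bias var" : "none");
    return 0;
}
```

The same `new_inp.size() > 2` guard appears in every hunk below that touches the bias input.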
src/gopt/impl/tensor_reformat.cpp
@@ -1862,7 +1862,8 @@ static inline bool nchw_nchwxx_valid(
     auto& src_node = new_inp[0];
     auto& filter_node = new_inp[1];
     auto dst_node = opr.output(0);
-    if (filter_node->shape().ndim != 4) {
+    //! already transformed or have fuse Z
+    if (filter_node->shape().ndim != 4 || new_inp.size() == 4) {
         return false;
     }
     megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm;
@@ -1884,7 +1885,8 @@ static inline bool nchw_nchwxx_valid(
     megdnn::ConvBiasForward::BiasMode bias_mode =
             megdnn::ConvBiasForward::BiasMode::NO_BIAS;
-    if (std::is_same<OprType, opr::ConvBiasForward>::value) {
+    if (std::is_same<OprType, opr::ConvBiasForward>::value &&
+        new_inp.size() > 2) {
         TensorShape bias_shape = new_inp[2]->shape();
         if (bias_shape.ndim == 5) {
             bias_shape = nchwxx_shape_2_nchw_shape(bias_shape);
@@ -2067,6 +2069,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
              pack_c_size](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        mgb_assert(opr->input().size() <= 3,
+                   "nchwxx does not support conv_bias fuse Z right now");
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
@@ -2092,7 +2096,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
             temp_inp[0] = new_src.node();
         }
         //! the bias is nchwxx
-        if (temp_inp[2]->shape().ndim == 5) {
+        if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
             auto new_bias = RelayoutPlaceholder::make(new_inp[2],
                                                       src_to_nchw_mode);
             temp_inp[2] = new_bias.node();
@@ -2102,7 +2106,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
             return new_opr;
         } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
@@ -2117,21 +2121,34 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                         src_to_nchwxx_mode);
                 conv_bias_src = new_src.node();
             }
-            //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                          src_to_nchwxx_mode);
-                conv_bias_bias = new_bias.node();
+            //! bias trans to nchwxx mode
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], src_to_nchwxx_mode);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             auto new_param = conv_bias_opr.param();
             new_param.format = conv_bias_format;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                                conv_bias_filter->shape().ndim >= 6,
                        "The conv_bias src dim is not trans to nchwxx");
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchwxx");
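The hunk above replaces the single three-input ConvBias::make call with a declared SymbolVar plus a branch, so a bias-free operator is rebuilt without a bias argument. A simplified sketch of that overload-selection shape follows, with toy types standing in for MegEngine's SymbolVar and opr::ConvBias::make (assumed names, not the real API):

```cpp
#include <cstdio>
#include <string>

// Toy stand-in: a "SymbolVar" is just a label here.
struct SymbolVar {
    std::string name;
};

// Two hypothetical make() overloads, mirroring the biased / bias-free
// ConvBias construction the patch switches between.
SymbolVar make_conv_bias(const SymbolVar& src, const SymbolVar& filter,
                         const SymbolVar& bias) {
    return {"conv_bias(" + src.name + ", " + filter.name + ", " + bias.name + ")"};
}

SymbolVar make_conv_bias(const SymbolVar& src, const SymbolVar& filter) {
    return {"conv_bias(" + src.name + ", " + filter.name + ")"};
}

int main() {
    SymbolVar src{"src"}, filter{"filter"};
    const SymbolVar* bias = nullptr;  // this operator has no bias input

    // Same shape as the patched pass: declare the result first, then pick
    // the overload depending on whether a bias var exists.
    SymbolVar new_conv_bias;
    if (bias) {
        new_conv_bias = make_conv_bias(src, filter, *bias);
    } else {
        new_conv_bias = make_conv_bias(src, filter);
    }
    std::printf("%s\n", new_conv_bias.name.c_str());
    return 0;
}
```

The same construction pattern is repeated in the TRANS_HYBIRD_NCHWXX branch below and in both branches of the nchw44-dot converter.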
@@ -2139,25 +2156,37 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         } else {
             mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             auto new_filter =
                     RelayoutPlaceholder::make(new_inp[1], is_trans.second);
             conv_bias_filter = new_filter.node();
             //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                          src_to_nchwxx_mode);
-                conv_bias_bias = new_bias.node();
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], src_to_nchwxx_mode);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             mgb_assert(conv_bias_src->shape().ndim == 4 &&
                        conv_bias_filter->shape().ndim == 5);
             mgb_assert((conv_bias_bias->shape().ndim == 5) ||
                        conv_bias_bias->shape().is_scalar());
             auto new_param = conv_bias_opr.param();
             new_param.format = conv_bias_format;
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv dst dim is not trans to nchwxx");
@@ -2275,6 +2304,10 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
             relayout_inp_to_nchw;
     replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Argmax::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::ImmutableTensor::typeinfo()] = relayout_inp_to_nchw;
 }
 std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
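The four added registrations route AxisAddRemove, Argmax, Broadcast and ImmutableTensor through the existing relayout_inp_to_nchw handler. The underlying mechanism is a type-keyed dispatch table; here is a small self-contained sketch of that pattern, with a string key standing in for opr::X::typeinfo() (a hypothetical simplification, not MegEngine's Typeinfo machinery):

```cpp
#include <cstdio>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for opr::X::typeinfo(): one string key per operator type.
using OprTypeKey = std::string;
using ReplaceFunc = std::function<void(const OprTypeKey&)>;

int main() {
    std::unordered_map<OprTypeKey, ReplaceFunc> replace_func;

    // One shared handler registered for several operator types, the same shape
    // as `replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_nchw;`.
    ReplaceFunc relayout_inp_to_nchw = [](const OprTypeKey& key) {
        std::printf("relayout inputs of %s back to NCHW\n", key.c_str());
    };
    for (const char* key : {"AxisAddRemove", "Argmax", "Broadcast", "ImmutableTensor"}) {
        replace_func[key] = relayout_inp_to_nchw;
    }

    // Dispatch: look the operator type up and run its replace function.
    auto it = replace_func.find("Broadcast");
    if (it != replace_func.end()) {
        it->second(it->first);
    }
    return 0;
}
```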
@@ -2459,6 +2492,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             OperatorNodeBase* opr, const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        mgb_assert(opr->input().size() <= 3,
+                   "nchwxx-dot does not support conv_bias fuse Z right now");
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
@@ -2489,7 +2524,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         }
         //! the bias is nchwxx
-        if (temp_inp[2]->shape().ndim == 5) {
+        if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
             auto new_bias = RelayoutPlaceholder::make(
                     new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
             temp_inp[2] = new_bias.node();
@@ -2499,7 +2534,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             return new_opr;
         } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
@@ -2514,21 +2549,34 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                         new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
                 conv_bias_src = new_src.node();
             }
-            //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(
-                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
-                conv_bias_bias = new_bias.node();
+            //! bias trans to nchwxx mode
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             auto new_param = conv_bias_opr.param();
             new_param.format = is_trans.conv_format;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                                conv_bias_filter->shape().ndim >= 6,
                        "The conv_bias src dim is not trans to nchwxx");
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchwxx");
@@ -2536,25 +2584,37 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         } else {
             mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             auto new_filter = RelayoutPlaceholder::make(new_inp[1],
                                                         is_trans.relayout_mod);
             conv_bias_filter = new_filter.node();
             //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(
-                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
-                conv_bias_bias = new_bias.node();
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             mgb_assert(conv_bias_src->shape().ndim == 4 &&
                        conv_bias_filter->shape().ndim == 5);
             mgb_assert((conv_bias_bias->shape().ndim == 5) ||
                        conv_bias_bias->shape().is_scalar());
             auto new_param = conv_bias_opr.param();
             new_param.format = is_trans.conv_format;
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv dst dim is not trans to nchwxx");
src/gopt/test/inference.cpp
@@ -3009,9 +3009,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     //! no supported hybrid nchw44
     opr::ConvBias::Param param_conv_bias_pad0;
     param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0;
-    auto b1 = mkcvar("b1", {1, 8, 1, 1});
     auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1});
-    auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {},
+    auto conv1_f1 = opr::ConvBias::make(x, w1_f1, param_conv_bias_pad0, {},
                                         OperatorNodeConfig("conv1_f1"));
     auto conv1_add = conv1_f1 * conv1;
@@ -3263,9 +3262,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
     auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f));
-    auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
     auto conv_1_2 = opr::ConvBias::make(
-            conv_1_q8, w1_2, b1_2, param_conv_bias, {},
+            conv_1_q8, w1_2, param_conv_bias, {},
             OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}});
     auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32());