Commit 30b3d3aa
Authored May 28, 2020 by Megvii Engine Team
Committed by Xu Xinran, Jun 19, 2020
fix(dnn/gopt): add convolution nchw44-dot format gopt
GitOrigin-RevId: e8e1e9637944ead470ebe4e2b697ddf7d437aaba
Parent: 48d1ac14

Showing 9 changed files with 437 additions and 41 deletions (+437 -41)
  dnn/src/naive/convolution/helper.h              +1   -2
  python_module/megengine/_internal/__init__.py   +4   -0
  python_module/src/swig/misc.i                   +1   -0
  sdk/load-and-run/dump_with_testcase_mge.py      +7   -0
  src/core/include/megbrain/graph/cg.h            +10  -8
  src/gopt/impl/framework.cpp                     +2   -0
  src/gopt/impl/tensor_reformat.cpp               +302 -23
  src/gopt/include/megbrain/gopt/inference.h      +26  -6
  src/gopt/test/inference.cpp                     +84  -2
dnn/src/naive/convolution/helper.h

@@ -287,8 +287,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
             }
         } else if (filter_meta.format == Format::NCHW44 ||
                    filter_meta.format == Format::NCHW44_DOT) {
-            if (filter_meta.format == Format::NCHW44 && !is_output &&
-                src.layout.ndim == 4) {
+            if (!is_output && src.layout.ndim == 4) {
                 return n * layout.stride[0] + c * layout.stride[1] +
                        h * layout.stride[2] + w * layout.stride[3];
             } else {
python_module/megengine/_internal/__init__.py

@@ -554,6 +554,7 @@ def optimize_for_inference(
     use_nchw4=False,
     use_nchw88=False,
     use_nchw44=False,
+    use_nchw44_dot=False,
     use_chwn4=False,
 ):
     """optimize computing graph for inference

@@ -577,6 +578,8 @@ def optimize_for_inference(
         times.
     :param use_nchw44: whether to use NCHW44 tensor format. This maybe faster some
         times.
+    :param use_nchw44_dot: whether to use NCHW44_DOT tensor format. This format is
+        optimized for inference in armv8.2
     :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
         nvidia tensorcore.
     :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for

@@ -605,6 +608,7 @@ def optimize_for_inference(
         "use_nchw32": "nchw32",
         "use_nchw88": "nchw88",
         "use_nchw44": "nchw44",
+        "use_nchw44_dot": "nchw44_dot",
         "use_chwn4": "chwn4",
     }.items():
         if settings[k]:
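A minimal usage sketch of the new Python switch follows. This is hedged: the surrounding loading code and the `outputs` variable are placeholders standing in for a network's output SymbolVars, not part of this commit.

# Hypothetical call site for the new use_nchw44_dot flag.
from megengine import _internal as mgb

optimized = mgb.optimize_for_inference(
    outputs,              # assumed: list of output SymbolVars built elsewhere
    use_nchw44_dot=True,  # repack layouts for armv8.2 dot-product kernels
)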
python_module/src/swig/misc.i

@@ -84,6 +84,7 @@ struct _OptimizeForInferenceOptions {
     SET(nhwcd4, NHWCD4);
     SET(nchw88, NCHW88);
     SET(nchw44, NCHW44);
+    SET(nchw44_dot, NCHW44_DOT);
     SET(nchw32, NCHW32);
     SET(chwn4, CHWN4);
 #undef SET
sdk/load-and-run/dump_with_testcase_mge.py

@@ -255,6 +255,7 @@ def optimize_for_inference(args, outputs):
         'enable_nchw4': 'use_nchw4',
         'enable_nchw88': 'use_nchw88',
         'enable_nchw44': 'use_nchw44',
+        'enable_nchw44_dot': 'use_nchw44_dot',
         'enable_nchw32': 'use_nchw32',
         'enable_chwn4': 'use_chwn4',
         'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',

@@ -400,6 +401,12 @@ def main():
         help='transform the model format from NCHW to NCHW44 '
         'for inference')
+    parser.add_argument(
+        '--enable-nchw44-dot',
+        action='store_true',
+        help='transform the model format from NCHW to NCHW44_DOT '
+        'for optimizing armv8.2 dot in inference')
     parser.add_argument(
         '--enable-nchw32',
         action='store_true',
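The new dict entry wires the `--enable-nchw44-dot` command-line switch to the matching `use_nchw44_dot` keyword of `optimize_for_inference`. A self-contained sketch of that plumbing (the mapping below is an excerpt, not the script's full table):

# Illustration of how an --enable-* flag becomes a use_* kwarg.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--enable-nchw44-dot', action='store_true')
args = parser.parse_args(['--enable-nchw44-dot'])

mapping = {'enable_nchw44_dot': 'use_nchw44_dot'}  # excerpt of the real table
kwargs = {dst: True for src, dst in mapping.items() if getattr(args, src)}
assert kwargs == {'use_nchw44_dot': True}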
src/core/include/megbrain/graph/cg.h

@@ -97,14 +97,15 @@ struct GraphCommonOptimizeOptions {
     bool fuse_conv_bias_with_z = false;
     enum LayoutTransform : uint32_t {
         DEFAULT,
-        NCHW4,      ///< compute using NCHW4 tensor format
-        NHWCD4,     ///< compute using NHWCD4 tensor format
-        NCHW88,     ///< compute using NCHW88 tensor format
-        NCHW44,     ///< compute using NCHW44 tensor format
-        NCHW32,     ///< compute using NCHW32 tensor format, used for
-                    ///< tensorcore
-        CHWN4,      ///< compute using CHWN4 tensor format, transformed mainly
-                    ///< used for cuda
+        NCHW4,       ///< compute using NCHW4 tensor format
+        NHWCD4,      ///< compute using NHWCD4 tensor format
+        NCHW88,      ///< compute using NCHW88 tensor format
+        NCHW44,      ///< compute using NCHW44 tensor format
+        NCHW44_DOT,  ///< compute using NCHW44_DOT tensor format
+        NCHW32,      ///< compute using NCHW32 tensor format, used for
+                     ///< tensorcore
+        CHWN4,       ///< compute using CHWN4 tensor format, transformed mainly
+                     ///< used for cuda
     };
     LayoutTransform layout_transform = LayoutTransform::DEFAULT;

@@ -142,6 +143,7 @@ struct GraphCommonOptimizeOptions {
     SET(nhwcd4, NHWCD4);
     SET(nchw88, NCHW88);
     SET(nchw44, NCHW44);
+    SET(nchw44_dot, NCHW44_DOT);
     SET(nchw32, NCHW32);
     SET(chwn4, CHWN4);
 #undef SET
src/gopt/impl/framework.cpp

@@ -738,6 +738,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
     });
     cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); });
     cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); });
+    cb(nchw44_dot,
+       { add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); });
     cb(nchw32, {
         add_pass<FuseConvBiasNonlinPass>();
         add_pass<FuseConvBiasZPass>();
src/gopt/impl/tensor_reformat.cpp

@@ -28,6 +28,7 @@
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/nn_int.h"
 #include "megdnn/opr_param_defs.h"
+#include "megdnn/tensor_format.h"
 
 #if MGB_ENABLE_TENSOR_RT

@@ -59,19 +60,19 @@ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
 public:
     //! relayout type of this opr
     enum class LayoutType {
-        NCHW4_TO_NCHW32,            //!< from nchw4 layout to nchw32 layout
-        NCHW32_TO_NCHW4,            //!< from nchw32 layout to nchw4 layout
-        NCHW4_TO_CHWN4,             //!< from nchw4 layout to chwn4 layout
-        CHWN4_TO_NCHW4,             //!< from chwn4 layout to nchw4 layout
-        NCHW_TO_NCHW4,              //!< from nchw layout to nchw4 layout
-        NCHW4_TO_NCHW,              //!< from nchw4 layout to nchw layout
-        NCHW_TO_NCHW88,             //!< from nchw layout to nchw88 layout
-        NCHW88_TO_NCHW,             //!< from nchw88 layout to nchw layout
-        WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4
-                                    //!< layout
-        WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to
-                                    //!< nchw4 layout
+        NCHW4_TO_NCHW32,             //!< from nchw4 layout to nchw32 layout
+        NCHW32_TO_NCHW4,             //!< from nchw32 layout to nchw4 layout
+        NCHW4_TO_CHWN4,              //!< from nchw4 layout to chwn4 layout
+        CHWN4_TO_NCHW4,              //!< from chwn4 layout to nchw4 layout
+        NCHW_TO_NCHW4,               //!< from nchw layout to nchw4 layout
+        NCHW4_TO_NCHW,               //!< from nchw4 layout to nchw layout
+        NCHW_TO_NCHW88,              //!< from nchw layout to nchw88 layout
+        NCHW88_TO_NCHW,              //!< from nchw88 layout to nchw layout
+        WEIGHT_NCHW_TO_NCHW4_DENSE,  //!< weight from nchw layout to nchw4
+                                     //!< layout
+        WEIGHT_NCHW_TO_NCHW4_GROUP,  //!< group weight from nchw layout to
+                                     //!< nchw4 layout
         WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88
                                      //!< layout

@@ -92,6 +93,10 @@ public:
         //!< the weight layout of input is nchw output is nchw44, special for
         //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
         WEIGHT_HYBIRD_NCHW_NCHW44,
+        WEIGHT_NCHW_TO_NCHW44_DOT_DENSE,  //!< weight from NCHW44 layout to
+                                          //!< NCHW44_DOT layout dense
+        WEIGHT_NCHW_TO_NCHW44_DOT_GROUP,  //!< weight from NCHW44 layout to
+                                          //!< NCHW44_DOT layout group
     };
     RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);

@@ -268,7 +273,9 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
         dst[3] = inp_shape[1];
         dst[4] = 8;
-    } else if (layout_type() == RelayoutPlaceholder::LayoutType::
-                       WEIGHT_NCHW_TO_NCHW44_DENSE) {
+    } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                       WEIGHT_NCHW_TO_NCHW44_DENSE ||
+               layout_type() == RelayoutPlaceholder::LayoutType::
+                       WEIGHT_NCHW_TO_NCHW44_DOT_DENSE) {
         mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
                    inp_shape[1] % 4 == 0);
         dst.ndim = 6;

@@ -279,7 +286,9 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
         dst[4] = 4;
         dst[5] = 4;
-    } else if (layout_type() == RelayoutPlaceholder::LayoutType::
-                       WEIGHT_NCHW_TO_NCHW44_GROUP) {
+    } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                       WEIGHT_NCHW_TO_NCHW44_GROUP ||
+               layout_type() == RelayoutPlaceholder::LayoutType::
+                       WEIGHT_NCHW_TO_NCHW44_DOT_GROUP) {
         mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
                    inp_shape[2] % 4 == 0);
         dst.ndim = 7;

@@ -646,6 +655,42 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)},
+                     0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)},
+                     0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 1, 3});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
+                                        cv(4), sub(3), sub(4)},
+                                       0),
+             tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
+                                        sub(4), cv(4), cv(4)},
+                                       0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
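The Reshape/Dimshuffle/Reshape chain above packs dense conv weights from (OC, IC, FH, FW) into the NCHW44_DOT blocked form (OC/4, IC/4, FH, FW, 4, 4), with the four OC lanes ahead of the four IC lanes. A NumPy sketch of the same index math (illustrative shapes, not MegEngine API):

# NumPy restatement of WEIGHT_NCHW_TO_NCHW44_DOT_DENSE.
import numpy as np

oc, ic, fh, fw = 8, 4, 3, 3
w = np.arange(oc * ic * fh * fw).reshape(oc, ic, fh, fw)

w0 = w.reshape(oc // 4, 4, ic // 4, 4, fh, fw)      # tshp0
w1 = w0.transpose(0, 2, 4, 5, 1, 3)                 # Dimshuffle {0, 2, 4, 5, 1, 3}
w_dot = w1.reshape(oc // 4, ic // 4, fh, fw, 4, 4)  # tshp1

# element (o4, i4, h, x, oc_lane, ic_lane) maps back to the NCHW weight
assert w_dot[1, 0, 2, 2, 3, 3] == w[1 * 4 + 3, 0 * 4 + 3, 2, 2]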
@@ -1601,12 +1646,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
     return new_var;
 }
 
-std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
-        size_t pack_c_size) {
-    auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
-    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
-    //! First is whether the conv can trans to nchwxx, second is the filter
-    //! trans mode
+void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
     using TestFilterResult = std::pair<TransType, RelayoutMode>;
     RelayoutMode weight_to_nchwxx_mode_dense =

@@ -1954,8 +1994,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
     };
-    ret->set_name(convter_pass_name);
-    auto&& replace_func = ret->m_opr_replace_func;
+    auto&& replace_func = m_opr_replace_func;
     //! supportted nchwxx
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;

@@ -1978,6 +2017,246 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     replace_func[opr::WarpPerspectiveForward::typeinfo()] =
             relayout_inp_to_nchw;
     replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
 }
+
+std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
+        size_t pack_c_size) {
+    auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
+    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
+    std::string convter_pass_name = "conv_format_nchw88";
+    if (pack_c_size == 4) {
+        convter_pass_name = "conv_format_nchw44";
+    }
+    ret->fill_opr_convert_fun(pack_c_size);
+    ret->set_name(convter_pass_name);
+    return ret;
+}
+
+/* ================ EnableNchw44DotPass =============== */
+VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var,
+                                                    VarNode* orig_var) const {
+    if (!orig_var->shape().eq_shape(new_var->shape())) {
+        return RelayoutPlaceholder::make(
+                       new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
+                .node();
+    }
+    return new_var;
+}
+
+std::unique_ptr<EnableNchw44DotPass>
+EnableNchw44DotPass::make_nchw44_dot_converter() {
+    auto ret = std::make_unique<EnableNchw44DotPass>();
+    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
+    //! First is whether the conv can trans to nchwxx, second is the filter
+    //! trans mode
+    using RelayoutMode = RelayoutPlaceholder::LayoutType;
+    using TestTransResult = std::pair<TransType, RelayoutMode>;
+    megdnn::param::ConvolutionV0::Format conv_dot_format =
+            megdnn::param::ConvBias::Format::NCHW44_DOT;
+    constexpr size_t pack_c_size = 4_z;
+    auto test_trans_nchw44_dot =
+            [](const megdnn::param::Convolution::Sparse conv_mode,
+               const VarNode* filter) -> TestTransResult {
+        TestTransResult ret{TransType::TRANS_NONE, {}};
+        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
+            size_t IC = filter->shape()[1];
+            size_t OC = filter->shape()[0];
+            if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
+                ret.first = TransType::TRANS_PURE_NCHWXX;
+                ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
+            } else if (IC < pack_c_size && OC % pack_c_size == 0) {
+                ret.first = TransType::TRANS_HYBIRD_NCHWXX;
+                ret.second = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
+            }
+        } else {
+            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
+            size_t group = filter->shape()[0];
+            size_t ocpg = filter->shape()[1];
+            size_t icpg = filter->shape()[2];
+            if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
+                ret.first = TransType::TRANS_NONE;
+            } else if ((icpg % pack_c_size == 0) &&
+                       (ocpg % pack_c_size == 0)) {
+                ret.first = TransType::TRANS_PURE_NCHWXX;
+                ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
+            }
+        }
+        return ret;
+    };
+    auto replace_conv_opr = [test_trans_nchw44_dot, conv_dot_format](
+                                    OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
+        mgb_assert(conv_opr.param().format ==
+                           megdnn::param::Convolution::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to "
+                   "NCHW44_DOT");
+        auto is_trans = test_trans_nchw44_dot(conv_opr.param().sparse,
+                                              new_inp[1]);
+        //! can not trans to nchwxx
+        if (is_trans.first == TransType::TRANS_NONE) {
+            mgb_assert(new_inp[1]->shape().ndim == 4 ||
+                               new_inp[1]->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            VarNodeArray temp_inp = new_inp;
+            //! if src is nchwxx, should RelayoutPlaceholder to nchw
+            if (temp_inp[0]->shape().ndim == 5) {
+                auto new_src = RelayoutPlaceholder::make(
+                        new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
+                temp_inp[0] = new_src.node();
+            }
+            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
+                                                           opr->config());
+            return new_opr;
+        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
+            //! filter trans to nchwxx mode
+            mgb_assert(new_inp[1]->shape().ndim == 4 ||
+                               new_inp[1]->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
+            auto new_filter =
+                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            conv_filter = new_filter.node();
+            //! src trans to nchwxx mode
+            if (new_inp[0]->shape().ndim != 5) {
+                mgb_assert(new_inp[0]->shape().ndim == 4);
+                auto new_src = RelayoutPlaceholder::make(
+                        new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
+                conv_src = new_src.node();
+            }
+            auto new_param = conv_opr.param();
+            new_param.format = conv_dot_format;
+            mgb_assert(conv_src->shape().ndim == 5 &&
+                               conv_filter->shape().ndim >= 6,
+                       "The conv src dim is not trans to nchwxx");
+            auto new_conv_opr = opr::Convolution::make(
+                    conv_src, conv_filter, new_param,
+                    conv_opr.execution_policy(), conv_opr.config());
+            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
+            mgb_assert(new_conv_opr.shape().ndim == 5,
+                       "The conv dst dim is not trans to nchwxx");
+            return new_opr;
+        } else {
+            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
+            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
+            auto new_filter =
+                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            conv_filter = new_filter.node();
+            mgb_assert(conv_src->shape().ndim == 4 &&
+                               conv_filter->shape().ndim == 5,
+                       "The src and filter is OK");
+            auto new_param = conv_opr.param();
+            new_param.format = conv_dot_format;
+            auto new_conv_opr = opr::Convolution::make(
+                    conv_src, conv_filter, new_param,
+                    conv_opr.execution_policy(), conv_opr.config());
+            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
+            mgb_assert(new_conv_opr.shape().ndim == 5,
+                       "The conv dst dim is not trans to nchwxx");
+            return new_opr;
+        }
+    };
+    auto replace_conv_bias_opr = [test_trans_nchw44_dot, conv_dot_format](
+                                         OperatorNodeBase* opr,
+                                         const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
+        mgb_assert(conv_bias_opr.param().format ==
+                           megdnn::param::ConvBias::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to NCHWXX");
+        auto is_trans = test_trans_nchw44_dot(conv_bias_opr.param().sparse,
+                                              new_inp[1]);
+        //! can not trans to nchwxx
+        if (is_trans.first == TransType::TRANS_NONE) {
+            mgb_assert(new_inp[1]->shape().ndim == 4 ||
+                               new_inp[1]->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            VarNodeArray temp_inp = new_inp;
+            //! if src is nchwxx, should RelayoutPlaceholder to nchw
+            if (temp_inp[0]->shape().ndim == 5) {
+                auto new_src = RelayoutPlaceholder::make(
+                        new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
+                temp_inp[0] = new_src.node();
+            }
+            //! the bias is nchwxx
+            if (temp_inp[2]->shape().ndim == 5) {
+                auto new_bias = RelayoutPlaceholder::make(
+                        new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
+                temp_inp[2] = new_bias.node();
+            }
+            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
+                                                           opr->config());
+            return new_opr;
+        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
+            VarNode *conv_bias_src = new_inp[0],
+                    *conv_bias_filter = new_inp[1],
+                    *conv_bias_bias = new_inp[2];
+            //! filter trans to nchwxx mode
+            mgb_assert(new_inp[1]->shape().ndim == 4 ||
+                               new_inp[1]->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            auto new_filter =
+                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            conv_bias_filter = new_filter.node();
+            //! src trans to nchwxx mode
+            if (new_inp[0]->shape().ndim != 5) {
+                mgb_assert(new_inp[0]->shape().ndim == 4);
+                auto new_src = RelayoutPlaceholder::make(
+                        new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
+                conv_bias_src = new_src.node();
+            }
+            //! bias trans to nchwxx mode, bias may be scale
+            if (new_inp[2]->shape().ndim == 4) {
+                auto new_bias = RelayoutPlaceholder::make(
+                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                conv_bias_bias = new_bias.node();
+            }
+            auto new_param = conv_bias_opr.param();
+            new_param.format = conv_dot_format;
+            mgb_assert(conv_bias_src->shape().ndim == 5 &&
+                               conv_bias_filter->shape().ndim >= 6,
+                       "The conv_bias src dim is not trans to nchwxx");
+            auto new_conv_bias_opr = opr::ConvBias::make(
+                    conv_bias_src, conv_bias_filter, conv_bias_bias,
+                    new_param, conv_bias_opr.execution_policy(),
+                    conv_bias_opr.config());
+            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchwxx");
+            return new_opr;
+        } else {
+            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
+            VarNode *conv_bias_src = new_inp[0],
+                    *conv_bias_filter = new_inp[1],
+                    *conv_bias_bias = new_inp[2];
+            auto new_filter =
+                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            conv_bias_filter = new_filter.node();
+            //! bias trans to nchwxx mode, bias may be scale
+            if (new_inp[2]->shape().ndim == 4) {
+                auto new_bias = RelayoutPlaceholder::make(
+                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                conv_bias_bias = new_bias.node();
+            }
+            mgb_assert(conv_bias_src->shape().ndim == 4 &&
+                       conv_bias_filter->shape().ndim == 5);
+            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
+                       conv_bias_bias->shape().is_scalar());
+            auto new_param = conv_bias_opr.param();
+            new_param.format = conv_dot_format;
+            auto new_conv_bias_opr = opr::ConvBias::make(
+                    conv_bias_src, conv_bias_filter, conv_bias_bias,
+                    new_param, conv_bias_opr.execution_policy(),
+                    conv_bias_opr.config());
+            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                       "The conv dst dim is not trans to nchwxx");
+            return new_opr;
+        }
+    };
+    ret->fill_opr_convert_fun(4);
+    auto&& replace_func = ret->m_opr_replace_func;
+    //! supportted nchwxx
+    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
+    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
+    return ret;
+}
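Read as a decision table, test_trans_nchw44_dot above picks one transform per convolution from its sparsity and filter shape. A plain-Python restatement for reference (identifier names are ours; the thresholds mirror the lambda with pack_c_size == 4):

# Plain-Python restatement of the dispatch rule in test_trans_nchw44_dot.
def trans_type(sparse, filter_shape):
    if sparse == 'DENSE':                      # filter is (OC, IC, FH, FW)
        oc, ic = filter_shape[0], filter_shape[1]
        if ic % 4 == 0 and oc % 4 == 0:
            return 'TRANS_PURE_NCHWXX'         # repack src and weights
        if ic < 4 and oc % 4 == 0:
            return 'TRANS_HYBIRD_NCHWXX'       # nchw in, nchw44_dot out
        return 'TRANS_NONE'
    # GROUP: filter is (group, OCPG, ICPG, FH, FW)
    group, ocpg, icpg = filter_shape[0], filter_shape[1], filter_shape[2]
    if icpg == 1 and ocpg == 1 and group % 4 == 0:
        return 'TRANS_NONE'                    # channel-wise conv: leave as-is
    if icpg % 4 == 0 and ocpg % 4 == 0:
        return 'TRANS_PURE_NCHWXX'
    return 'TRANS_NONE'

assert trans_type('DENSE', (8, 3, 3, 3)) == 'TRANS_HYBIRD_NCHWXX'
assert trans_type('GROUP', (8, 1, 1, 3, 3)) == 'TRANS_NONE'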
src/gopt/include/megbrain/gopt/inference.h

@@ -236,8 +236,10 @@ namespace gopt {
     VarNode* on_graph_endpoint_var(VarNode* new_var,
                                    VarNode* orig_var) const override;
 
 public:
-    const char* name() const override {
-        return mgb_cstr_log("tensor_format_nchw4");
-    }
+    const char* name() const override {
+        return mgb_cstr_log("tensor_format_nchw4");
+    }
+
     //! make nchw -> nchw4 converter opt pass
     static std::unique_ptr<EnableNCHW4Pass> make_nchw4_converter();
 };

@@ -246,30 +248,48 @@ namespace gopt {
  * \brief convert tensor format to nchwxx to speed up inference on certain
  * devices
  */
-class EnableNchwxxPass final : public TensorReformatPass {
+class EnableNchwxxPass : public TensorReformatPass {
     std::string m_name = "tensor_format_nchwxx";
     size_t m_pack_c_size;
     VarNode* on_graph_endpoint_var(VarNode* new_var,
                                    VarNode* orig_var) const override;
-
-public:
-    EnableNchwxxPass(size_t pack_c_size) : m_pack_c_size(pack_c_size) {}
     //! the flag for conv to transform to nchwxx
     enum class TransType {
         TRANS_PURE_NCHWXX,    //!< weight and src all trans to nchwxx
         TRANS_HYBIRD_NCHWXX,  //!< input is nchw, output is nchwxx
        TRANS_NONE,           //!< no need trans
     };
+
+public:
+    EnableNchwxxPass(size_t pack_c_size) : m_pack_c_size(pack_c_size) {}
+
     const char* name() const override {
         return mgb_cstr_log(m_name.c_str());
     }
+
+    void set_name(std::string in_name) { m_name = in_name; }
+
+    void fill_opr_convert_fun(size_t pack_c_size);
+
     //! make nchw -> nchwxx converter opt pass, pack_c_size is the x, like
     //! 4,8,16
     static std::unique_ptr<EnableNchwxxPass> make_nchwxx_converter(
             size_t pack_c_size);
 };
+
+/*!
+ * \brief convert tensor format from nchw44 to nchw44_dot to speed up
+ * inference on armv8.2
+ */
+class EnableNchw44DotPass final : public EnableNchwxxPass {
+    std::string m_name = "tensor_format_nchw44_dot";
+    VarNode* on_graph_endpoint_var(VarNode* new_var,
+                                   VarNode* orig_var) const override;
+
+public:
+    EnableNchw44DotPass() : EnableNchwxxPass(4) {}
+    //! make nchw44 -> nchw44_dot converter opt pass
+    static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter();
+};
 
 struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
src/gopt/test/inference.cpp

@@ -2356,7 +2356,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
                         .rename(name),
                 dtype);
     };
     auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f));
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.format = opr::ConvBias::Param::Format::NCHW;

@@ -2376,7 +2376,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
          b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
     auto conv2 = opr::ConvBiasForward::make(
             conv1, w2, b2, param_conv_bias, {},
             OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
     auto y = opr::TypeCvt::make(conv2, dtype::Float32());
     SymbolVar y_opt;

@@ -2617,4 +2617,86 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
 }
 
+TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto host_x = gen({2, 3, 16, 16}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+    //!Hybrid nchw88 mode
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 1;
+    auto w1 = mkcvar("w1", {8, 3, 3, 3}),
+         conv1 = opr::Convolution::make(x, w1, param_conv);
+    //!channel wise
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
+         conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
+    //! group
+    auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
+         conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
+
+    auto shape_of = opr::GetVarShape::make(conv3);
+    auto subtensor = opr::Subtensor::make(
+            shape_of, {opr::Subtensor::AxisIndexer::make_interval(
+                              0, x.make_scalar(2), None, x.make_scalar(1))});
+
+    opr::Resize::Param param_resize;
+    param_resize.format = opr::Resize::Param::Format::NCHW;
+    auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize);
+    auto mat = mkcvar("mat", {2, 3, 3}),
+         warp = opr::WarpPerspectiveForward::make(
+                 resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4}));
+
+    auto b = mkvar("b", {1, 8, 1, 1}),
+         elem = opr::Elemwise::make({warp + b},
+                                    opr::Elemwise::Param::Mode::RELU);
+    //! Dense
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
+         conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias);
+
+    auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
+         conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias);
+    auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
+         y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);
+
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nchw44_dot();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
+              find_opr<opr::Convolution>(y_opt).param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
+              find_opr<opr::ConvBias>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNCHW44.json"));
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    //! meybe go to winograd in x86-32, so set error 1e-1
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
+
+    *host_x = *gen({2, 3, 32, 32}, cn);
+    func->execute();
+    //! meybe go to winograd in x86-32, so set error 1e-1
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}