Commit a3560fa1, authored by Megvii Engine Team

feat(gopt): add transform to chwn4 to optimize_for_inference

GitOrigin-RevId: 4d1a9c6c8410904ea4da17a1bed2ad06ce369869
Parent: 1fb7d34f
@@ -542,7 +542,8 @@ def optimize_for_inference(
     use_nchw32=False,
     fuse_conv_bias_with_z=False,
     use_nchw88=False,
-    use_nchw44=False
+    use_nchw44=False,
+    use_chwn4=False
 ):
     """optimize computing graph for inference
@@ -566,6 +567,8 @@ def optimize_for_inference(
         times.
     :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
         nvidia tensorcore.
+    :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
+        nvidia tensorcore.
     :return: list of transformed vars corresponding to given output vars
@@ -589,6 +592,7 @@ def optimize_for_inference(
         "use_nchw32": "nchw2nchw32",
         "use_nchw88": "nchw2nchw88",
         "use_nchw44": "nchw2nchw44",
+        "use_chwn4": "nchw42chwn4",
     }.items():
         if settings[k]:
             assert (
...
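For reference, a minimal usage sketch of the new Python flag, assuming `output_vars` is a list of output `SymbolVar`s of an already-built graph and that `optimize_for_inference` is imported from the module patched above (the exact import path is not shown in this diff):

```python
# Minimal sketch (assumed setup): request the NCHW4 -> CHWN4 layout transform.
# `output_vars` is assumed to be the graph's output SymbolVars; the function
# returns the transformed vars corresponding to them, as documented above.
optimized_vars = optimize_for_inference(
    output_vars,
    use_chwn4=True,  # new flag added by this commit; maps to "nchw42chwn4"
)
```

Internally the flag is translated to the `nchw42chwn4` layout-transform setting, so the C++ side applies the conv-bias fusion passes and the CHWN4 converter automatically (see the `GraphOptimizer::apply_optimize_options` hunk below).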
@@ -84,6 +84,7 @@ struct _OptimizeForInferenceOptions {
         SET(nchw2nchw88, NCHW2NCHW88);
         SET(nchw2nchw44, NCHW2NCHW44);
         SET(nchw2nchw32, NCHW2NCHW32);
+        SET(nchw42chwn4, NCHW42CHWN4);
 #undef SET
     };
...
@@ -254,8 +254,9 @@ def optimize_for_inference(args, outputs):
         'enable_hwcd4': 'use_nhwcd4',
         'enable_nchw88': 'use_nchw88',
         'enable_nchw44': 'use_nchw44',
-        'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
         'enable_nchw32': 'use_nchw32',
+        'enable_chwn4': 'use_chwn4',
+        'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
         'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
     }
     kwargs = {}
@@ -398,6 +399,12 @@ def main():
         help='transform the model format from NCHW4 to NCHW32 '
              'for inference on nvidia TensoCore'
     )
+    parser.add_argument(
+        '--enable-chwn4',
+        action='store_true',
+        help='transform the model format to CHWN4 '
+             'for inference, mainly used for nvidia tensorcore'
+    )
     parser.add_argument(
         '--enable-fuse-conv-bias-with-z',
         action='store_true',
...
@@ -724,6 +724,13 @@ void GraphOptimizer::apply_optimize_options(
         add_pass<ShuffleShuffleRemovePass>();
         add_pass<RemoveRedundantTypeCvtPass>();
     }
+    if (options->transform_nchw42chwn4()) {
+        add_pass<FuseConvBiasNonlinPass>();
+        add_pass<FuseConvBiasZPass>();
+        add_pass(EnableCHWN4Pass::make_chwn4_converter());
+        add_pass<ShuffleShuffleRemovePass>();
+        add_pass<RemoveRedundantTypeCvtPass>();
+    }
     if (options->fuse_conv_bias_nonlinearity) {
         add_pass<FuseConvBiasNonlinPass>();
...
@@ -395,6 +395,8 @@ namespace gopt {
             NCHW2NCHW44,  ///< compute using NCHW44 tensor format
             NCHW2NCHW32,  ///< compute using NCHW32 tensor format, used for
                           ///< tensorcore
+            NCHW42CHWN4,  ///< compute using CHWN4 tensor format, transformed
+                          ///< from NCHW4, mainly used for cuda
         };
         LayoutTransform layout_transform = LayoutTransform::DEFAULT;
         //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
@@ -424,6 +426,7 @@ namespace gopt {
         SET(nchw2nchw88, NCHW2NCHW88);
         SET(nchw2nchw44, NCHW2NCHW44);
         SET(nchw2nchw32, NCHW2NCHW32);
+        SET(nchw42chwn4, NCHW42CHWN4);
 #undef SET
     };
...
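To make the new `NCHW42CHWN4` enum value concrete: NCHW4 packs channels in groups of four, storing a tensor as (N, C//4, H, W, 4), while CHWN4 stores it as (C//4, H, W, N, 4), so the relayout is essentially a permutation of the outer four axes. A small NumPy sketch of the data movement (illustration only, not the code the converter pass emits):

```python
import numpy as np

# NCHW4 layout: (N, C//4, H, W, 4); CHWN4 layout: (C//4, H, W, N, 4).
n, c, h, w = 2, 16, 8, 8
nchw4 = np.arange(n * c * h * w, dtype=np.int8).reshape(n, c // 4, h, w, 4)

# The NCHW4 -> CHWN4 relayout is a pure transpose of the outer axes.
chwn4 = nchw4.transpose(1, 2, 3, 0, 4)        # shape (C//4, H, W, N, 4)
assert chwn4.shape == (c // 4, h, w, n, 4)

# Round-trip back to NCHW4 to check the permutation is self-consistent.
back = chwn4.transpose(3, 0, 1, 2, 4)
assert np.array_equal(back, nchw4)
```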
@@ -2011,14 +2011,11 @@ TEST(TestGoptInference, EnableCHWN4) {
     y4 = opr::TypeCvt::make(y4, dtype::Float32());
     SymbolVar y_opt;
     SymbolVar y_cudnn;
-    unpack_vector(
-            gopt::GraphOptimizer{}
-                    .add_pass<gopt::FuseConvBiasNonlinPass>()
-                    .add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter())
-                    .add_pass<gopt::FuseConvBiasZPass>()
-                    .apply({{y4}})
-                    .endpoint_vars(),
-            y_opt);
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw42chwn4();
+        unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt);
+    }
     unpack_vector(gopt::GraphOptimizer{}
                           .add_pass<gopt::FuseConvBiasNonlinPass>()
                           .add_pass<gopt::FuseConvBiasZPass>()
@@ -2100,13 +2097,11 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) {
     auto y2 = opr::WarpPerspective::make(y1, mat_var, TensorShape{16, 16}, warp_param);
     SymbolVar y_opt;
     SymbolVar y_cudnn;
-    unpack_vector(gopt::GraphOptimizer{}
-                          .add_pass<gopt::FuseConvBiasNonlinPass>()
-                          .add_pass<gopt::FuseConvBiasZPass>()
-                          .add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter())
-                          .apply({{y2}})
-                          .endpoint_vars(),
-                  y_opt);
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw42chwn4();
+        unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt);
+    }
     unpack_vector(gopt::GraphOptimizer{}
                           .add_pass<gopt::FuseConvBiasNonlinPass>()
                           .add_pass<gopt::FuseConvBiasZPass>()
...