From a3560fa10ef2e5de86c065e48daafef629b81fc0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 10 May 2020 20:43:46 +0800 Subject: [PATCH] feat(gopt): add transform to chwn4 to optimize_for_inference GitOrigin-RevId: 4d1a9c6c8410904ea4da17a1bed2ad06ce369869 --- python_module/megengine/_internal/__init__.py | 6 ++++- python_module/src/swig/misc.i | 1 + sdk/load-and-run/dump_with_testcase_mge.py | 9 ++++++- src/gopt/impl/framework.cpp | 7 ++++++ src/gopt/include/megbrain/gopt/framework.h | 3 +++ src/gopt/test/inference.cpp | 25 ++++++++----------- 6 files changed, 34 insertions(+), 17 deletions(-) diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py index 9201a5da..1f63a675 100644 --- a/python_module/megengine/_internal/__init__.py +++ b/python_module/megengine/_internal/__init__.py @@ -542,7 +542,8 @@ def optimize_for_inference( use_nchw32=False, fuse_conv_bias_with_z=False, use_nchw88=False, - use_nchw44=False + use_nchw44=False, + use_chwn4=False ): """optimize computing graph for inference @@ -566,6 +567,8 @@ times. :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for nvidia tensorcore. + :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for + nvidia tensorcore. 
:return: list of transformed vars corresponding to given output vars @@ -589,6 +592,7 @@ def optimize_for_inference( "use_nchw32": "nchw2nchw32", "use_nchw88": "nchw2nchw88", "use_nchw44": "nchw2nchw44", + "use_chwn4": "nchw42chwn4", }.items(): if settings[k]: assert ( diff --git a/python_module/src/swig/misc.i b/python_module/src/swig/misc.i index 7b765949..58f8a61d 100644 --- a/python_module/src/swig/misc.i +++ b/python_module/src/swig/misc.i @@ -84,6 +84,7 @@ struct _OptimizeForInferenceOptions { SET(nchw2nchw88, NCHW2NCHW88); SET(nchw2nchw44, NCHW2NCHW44); SET(nchw2nchw32, NCHW2NCHW32); + SET(nchw42chwn4, NCHW42CHWN4); #undef SET }; diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py index 3d67486b..cd62283e 100755 --- a/sdk/load-and-run/dump_with_testcase_mge.py +++ b/sdk/load-and-run/dump_with_testcase_mge.py @@ -254,8 +254,9 @@ def optimize_for_inference(args, outputs): 'enable_hwcd4': 'use_nhwcd4', 'enable_nchw88': 'use_nchw88', 'enable_nchw44': 'use_nchw44', - 'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity', 'enable_nchw32': 'use_nchw32', + 'enable_chwn4': 'use_chwn4', + 'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity', 'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z', } kwargs = {} @@ -398,6 +399,12 @@ def main(): help='transform the model format from NCHW4 to NCHW32 ' 'for inference on nvidia TensoCore' ) + parser.add_argument( + '--enable-chwn4', + action='store_true', + help='transform the model format to CHWN4 ' + 'for inference, mainly used for nvidia tensorcore' + ) parser.add_argument( '--enable-fuse-conv-bias-with-z', action='store_true', diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index a981a84d..7a7957b5 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -724,6 +724,13 @@ void GraphOptimizer::apply_optimize_options( add_pass(); add_pass(); } + if (options->transform_nchw42chwn4()) { + add_pass(); + 
add_pass(); + add_pass(EnableCHWN4Pass::make_chwn4_converter()); + add_pass(); + add_pass(); + } if (options->fuse_conv_bias_nonlinearity) { add_pass(); diff --git a/src/gopt/include/megbrain/gopt/framework.h b/src/gopt/include/megbrain/gopt/framework.h index cd0b3015..2f950129 100644 --- a/src/gopt/include/megbrain/gopt/framework.h +++ b/src/gopt/include/megbrain/gopt/framework.h @@ -395,6 +395,8 @@ namespace gopt { NCHW2NCHW44, ///< compute using NCHW44 tensor format NCHW2NCHW32, ///< compute using NCHW32 tensor format, used for ///< tensorcore + NCHW42CHWN4, ///< compute using CHWN4 tensor format, transformed + ///< from NCHW4, mainly used for cuda }; LayoutTransform layout_transform = LayoutTransform::DEFAULT; //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) @@ -424,6 +426,7 @@ namespace gopt { SET(nchw2nchw88, NCHW2NCHW88); SET(nchw2nchw44, NCHW2NCHW44); SET(nchw2nchw32, NCHW2NCHW32); + SET(nchw42chwn4, NCHW42CHWN4); #undef SET }; diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index b271b764..7e03712d 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -2011,14 +2011,11 @@ TEST(TestGoptInference, EnableCHWN4) { y4 = opr::TypeCvt::make(y4, dtype::Float32()); SymbolVar y_opt; SymbolVar y_cudnn; - unpack_vector( - gopt::GraphOptimizer{} - .add_pass() - .add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter()) - .add_pass() - .apply({{y4}}) - .endpoint_vars(), - y_opt); + { + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_nchw42chwn4(); + unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); + } unpack_vector(gopt::GraphOptimizer{} .add_pass() .add_pass() @@ -2100,13 +2097,11 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) { auto y2 = opr::WarpPerspective::make(y1, mat_var, TensorShape{16, 16}, warp_param); SymbolVar y_opt; SymbolVar y_cudnn; - unpack_vector(gopt::GraphOptimizer{} - .add_pass() - .add_pass() - 
.add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter()) - .apply({{y2}}) - .endpoint_vars(), - y_opt); + { + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_nchw42chwn4(); + unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt); + } unpack_vector(gopt::GraphOptimizer{} .add_pass() .add_pass() -- GitLab