diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py
index 9201a5daa1d2c251999f2e1a186edc2c5332bace..1f63a675b41928349c6d27c8f33428671993465a 100644
--- a/python_module/megengine/_internal/__init__.py
+++ b/python_module/megengine/_internal/__init__.py
@@ -542,7 +542,8 @@ def optimize_for_inference(
     use_nchw32=False,
     fuse_conv_bias_with_z=False,
     use_nchw88=False,
-    use_nchw44=False
+    use_nchw44=False,
+    use_chwn4=False
 ):
     """optimize computing graph for inference
 
@@ -566,6 +567,8 @@
         times.
     :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
         nvidia tensorcore.
+    :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
+        nvidia tensorcore.
 
     :return: list of transformed vars corresponding to given output vars
 
@@ -589,6 +592,7 @@
         "use_nchw32": "nchw2nchw32",
         "use_nchw88": "nchw2nchw88",
         "use_nchw44": "nchw2nchw44",
+        "use_chwn4": "nchw42chwn4",
     }.items():
         if settings[k]:
             assert (
diff --git a/python_module/src/swig/misc.i b/python_module/src/swig/misc.i
index 7b7659498f49678ed37f8c09f11a9230257d43a3..58f8a61d889d7f1ff3df173a157c5dc8810f41bc 100644
--- a/python_module/src/swig/misc.i
+++ b/python_module/src/swig/misc.i
@@ -84,6 +84,7 @@ struct _OptimizeForInferenceOptions {
     SET(nchw2nchw88, NCHW2NCHW88);
     SET(nchw2nchw44, NCHW2NCHW44);
     SET(nchw2nchw32, NCHW2NCHW32);
+    SET(nchw42chwn4, NCHW42CHWN4);
 #undef SET
 };
 
diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py
index 3d67486b64ab1e2cacc5ba638c9b2a6b0679e7ef..cd62283e5a5b5cdf72b62ad7947daab11150c5e3 100755
--- a/sdk/load-and-run/dump_with_testcase_mge.py
+++ b/sdk/load-and-run/dump_with_testcase_mge.py
@@ -254,8 +254,9 @@ def optimize_for_inference(args, outputs):
         'enable_hwcd4': 'use_nhwcd4',
         'enable_nchw88': 'use_nchw88',
         'enable_nchw44': 'use_nchw44',
-        'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
         'enable_nchw32': 'use_nchw32',
+        'enable_chwn4': 'use_chwn4',
+        'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
         'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
     }
     kwargs = {}
@@ -398,6 +399,12 @@
         help='transform the model format from NCHW4 to NCHW32 '
         'for inference on nvidia TensoCore'
     )
+    parser.add_argument(
+        '--enable-chwn4',
+        action='store_true',
+        help='transform the model format to CHWN4 '
+        'for inference, mainly used for nvidia tensorcore'
+    )
     parser.add_argument(
         '--enable-fuse-conv-bias-with-z',
         action='store_true',
diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index a981a84d5af68b1349b1be3292c2b4ab7786767b..7a7957b589ab72f707295f01595c03937e7e1663 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -724,6 +724,13 @@ void GraphOptimizer::apply_optimize_options(
         add_pass();
         add_pass();
     }
+    if (options->transform_nchw42chwn4()) {
+        add_pass();
+        add_pass();
+        add_pass(EnableCHWN4Pass::make_chwn4_converter());
+        add_pass();
+        add_pass();
+    }
 
     if (options->fuse_conv_bias_nonlinearity) {
         add_pass();
diff --git a/src/gopt/include/megbrain/gopt/framework.h b/src/gopt/include/megbrain/gopt/framework.h
index cd0b30155bde1cc4257fcaa77d4d91bac87ef52c..2f950129454af06edbeacf4733aab9daba58f9a3 100644
--- a/src/gopt/include/megbrain/gopt/framework.h
+++ b/src/gopt/include/megbrain/gopt/framework.h
@@ -395,6 +395,8 @@ namespace gopt {
                 NCHW2NCHW44,  ///< compute using NCHW44 tensor format
                 NCHW2NCHW32,  ///< compute using NCHW32 tensor format, used for
                               ///< tensorcore
+                NCHW42CHWN4,  ///< compute using CHWN4 tensor format, transformed
+                              ///< from NCHW4, mainly used for cuda
             };
             LayoutTransform layout_transform = LayoutTransform::DEFAULT;
             //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
@@ -424,6 +426,7 @@
             SET(nchw2nchw88, NCHW2NCHW88);
             SET(nchw2nchw44, NCHW2NCHW44);
             SET(nchw2nchw32, NCHW2NCHW32);
+            SET(nchw42chwn4, NCHW42CHWN4);
 #undef SET
         };
 
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index b271b7643b4e9744c809b233ba5429105c574b59..7e03712dbffb71a627f787f7e5cb2b2fd604eebb 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -2011,14 +2011,11 @@ TEST(TestGoptInference, EnableCHWN4) {
     y4 = opr::TypeCvt::make(y4, dtype::Float32());
     SymbolVar y_opt;
     SymbolVar y_cudnn;
-    unpack_vector(
-            gopt::GraphOptimizer{}
-                    .add_pass()
-                    .add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter())
-                    .add_pass()
-                    .apply({{y4}})
-                    .endpoint_vars(),
-            y_opt);
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw42chwn4();
+        unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt);
+    }
     unpack_vector(gopt::GraphOptimizer{}
                           .add_pass()
                           .add_pass()
@@ -2100,13 +2097,11 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) {
     auto y2 = opr::WarpPerspective::make(y1, mat_var, TensorShape{16, 16}, warp_param);
     SymbolVar y_opt;
     SymbolVar y_cudnn;
-    unpack_vector(gopt::GraphOptimizer{}
-                          .add_pass()
-                          .add_pass()
-                          .add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter())
-                          .apply({{y2}})
-                          .endpoint_vars(),
-                  y_opt);
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw42chwn4();
+        unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt);
+    }
     unpack_vector(gopt::GraphOptimizer{}
                           .add_pass()
                           .add_pass()
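
For reference, a minimal usage sketch of the new switch from the Python side, mirroring the use_chwn4 keyword added to optimize_for_inference above. The build_quantized_net helper is hypothetical and only stands in for whatever quantized NCHW4 graph the caller has built; the equivalent load-and-run path is the new --enable-chwn4 flag in dump_with_testcase_mge.py.

    # Sketch only: exercises the use_chwn4 keyword introduced by this patch.
    # build_quantized_net is a hypothetical helper returning the graph's
    # output SymbolVars; any quantized NCHW4 network would do.
    import megengine._internal as mgb

    output_vars = build_quantized_net()

    # Request the NCHW4 -> CHWN4 layout transform; mainly profitable on
    # NVIDIA TensorCore GPUs, per the docstring added in this patch.
    optimized_vars = mgb.optimize_for_inference(
        output_vars,
        use_chwn4=True,
    )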