diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py index 1f63a675b41928349c6d27c8f33428671993465a..58b8f1515fec952bfea2b808537f91a5f8dd73b2 100644 --- a/python_module/megengine/_internal/__init__.py +++ b/python_module/megengine/_internal/__init__.py @@ -588,11 +588,11 @@ def optimize_for_inference( layout_tranform = None for k, v in { - "use_nhwcd4": "nchw2nhwcd4", - "use_nchw32": "nchw2nchw32", - "use_nchw88": "nchw2nchw88", - "use_nchw44": "nchw2nchw44", - "use_chwn4": "nchw42chwn4", + "use_nhwcd4": "nhwcd4", + "use_nchw32": "nchw32", + "use_nchw88": "nchw88", + "use_nchw44": "nchw44", + "use_chwn4": "chwn4", }.items(): if settings[k]: assert ( diff --git a/python_module/src/swig/misc.i b/python_module/src/swig/misc.i index 58f8a61d889d7f1ff3df173a157c5dc8810f41bc..d554bf5f30ea579558158d8292989bd8eb44310e 100644 --- a/python_module/src/swig/misc.i +++ b/python_module/src/swig/misc.i @@ -80,11 +80,11 @@ struct _OptimizeForInferenceOptions { #define SET(_trans, _trans_capital) \ void enable_##_trans(); \ - SET(nchw2nhwcd4, NCHW2NHWCD4); - SET(nchw2nchw88, NCHW2NCHW88); - SET(nchw2nchw44, NCHW2NCHW44); - SET(nchw2nchw32, NCHW2NCHW32); - SET(nchw42chwn4, NCHW42CHWN4); + SET(nhwcd4, NHWCD4); + SET(nchw88, NCHW88); + SET(nchw44, NCHW44); + SET(nchw32, NCHW32); + SET(chwn4, CHWN4); #undef SET }; diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index 7a7957b589ab72f707295f01595c03937e7e1663..0cca3cfc69a7c1573f19e86dbd3190be7918a5ba 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -708,23 +708,23 @@ void GraphOptimizer::apply_optimize_options( if (options->f16_io_f32_comp) { add_pass(ConvertF32ToF16Pass::make(true)); } - if (options->transform_nchw2nhwcd4()) { + if (options->transform_nhwcd4()) { add_pass(ConvertFormatPass::make_nhwcd4_converter()); add_pass(); } - if (options->transform_nchw2nchw88()) { + if (options->transform_nchw88()) { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); } - if (options->transform_nchw2nchw44()) { + if (options->transform_nchw44()) { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); } - if (options->transform_nchw2nchw32()) { + if (options->transform_nchw32()) { add_pass(); add_pass(EnableTensorCorePass::make_tensorcore_converter()); add_pass(); add_pass(); } - if (options->transform_nchw42chwn4()) { + if (options->transform_chwn4()) { add_pass(); add_pass(); add_pass(EnableCHWN4Pass::make_chwn4_converter()); diff --git a/src/gopt/include/megbrain/gopt/framework.h b/src/gopt/include/megbrain/gopt/framework.h index 2f950129454af06edbeacf4733aab9daba58f9a3..77a2f0f8f37084ebca8e46a41902667b08ec22d1 100644 --- a/src/gopt/include/megbrain/gopt/framework.h +++ b/src/gopt/include/megbrain/gopt/framework.h @@ -390,13 +390,13 @@ namespace gopt { bool fuse_conv_bias_nonlinearity = false; enum LayoutTransform : uint32_t { DEFAULT, - NCHW2NHWCD4, ///< compute using NHWCD4 tensor format - NCHW2NCHW88, ///< compute using NCHW88 tensor format - NCHW2NCHW44, ///< compute using NCHW44 tensor format - NCHW2NCHW32, ///< compute using NCHW32 tensor format, used for - ///< tensorcore - NCHW42CHWN4, ///< compute using CHWN4 tensor format, transformed - ///< from NCHW4, mainly used for cuda + NHWCD4, ///< compute using NHWCD4 tensor format + NCHW88, ///< compute using NCHW88 tensor format + NCHW44, ///< compute using NCHW44 tensor format + NCHW32, ///< compute using NCHW32 tensor format, used for + ///< tensorcore + CHWN4, ///< compute using CHWN4 tensor format, transformed mainly + ///< used for cuda }; LayoutTransform layout_transform = LayoutTransform::DEFAULT; //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) @@ -422,11 +422,11 @@ namespace gopt { return layout_transform == LayoutTransform::_trans_capital; \ } - SET(nchw2nhwcd4, NCHW2NHWCD4); - SET(nchw2nchw88, NCHW2NCHW88); - SET(nchw2nchw44, NCHW2NCHW44); - SET(nchw2nchw32, NCHW2NCHW32); - SET(nchw42chwn4, NCHW42CHWN4); + SET(nhwcd4, NHWCD4); + SET(nchw88, NCHW88); + SET(nchw44, NCHW44); + SET(nchw32, NCHW32); + SET(chwn4, CHWN4); #undef SET }; diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 7e03712dbffb71a627f787f7e5cb2b2fd604eebb..40636d434756c9d0a2e43dc9b6604b8318e15dc5 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -992,7 +992,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) { SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nhwcd4(); + options.enable_nhwcd4(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, @@ -1051,7 +1051,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) { SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nhwcd4(); + options.enable_nhwcd4(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, @@ -1102,7 +1102,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Deconv) { SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nhwcd4(); + options.enable_nhwcd4(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); ASSERT_EQ(opr::Convolution::Param::Format::NCHW, @@ -1147,7 +1147,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Qint8) { SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nhwcd4(); + options.enable_nhwcd4(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); ASSERT_EQ(opr::ConvBias::Param::Format::NHWCD4, @@ -1199,7 +1199,7 @@ TEST(TestGoptInference, ConvertFormatPadIC) { auto y = opr::Convolution::make(concat, w1, param); SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nhwcd4(); + options.enable_nhwcd4(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); HostTensorND host_y_opt, host_y; @@ -1285,7 +1285,7 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass) { y_y = opr::Convolution::make(y_expand, w3, param), y = y_y + y_tmp; SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nhwcd4().enable_fuse_conv_bias_nonlinearity(); + options.enable_nhwcd4().enable_fuse_conv_bias_nonlinearity(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); ASSERT_EQ(3u, find_opr(y_opt).input().size()); graph->compile({{y_opt, {}}}) @@ -1516,7 +1516,7 @@ TEST(TestEnableTensorCore, SmallInputShape) { SymbolVar y_no_tc; { auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nchw32().enable_fuse_conv_bias_nonlinearity(); + options.enable_nchw32().enable_fuse_conv_bias_nonlinearity(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); } { @@ -1581,7 +1581,7 @@ TEST(TestEnableTensorCore, ConvBiasWithZ) { SymbolVar y_no_tc; { auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_fuse_conv_bias_nonlinearity().enable_nchw2nchw32(); + options.enable_fuse_conv_bias_nonlinearity().enable_nchw32(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); } { @@ -1649,12 +1649,12 @@ TEST(TestGoptInference, EnableTensorCore) { SymbolVar y_no_tc; { auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_fuse_conv_bias_nonlinearity().enable_nchw2nchw32(); + options.enable_fuse_conv_bias_nonlinearity().enable_nchw32(); unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); } { auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_fuse_conv_bias_nonlinearity().enable_nchw2nchw32(); + options.enable_fuse_conv_bias_nonlinearity().enable_nchw32(); unpack_vector(gopt::optimize_for_inference({y4}, options), y_no_tc); } auto nr_dimshuffle = find_opr_num(y_opt); @@ -1855,7 +1855,7 @@ TEST(TestEnableTensorCore, ShuffleMerge) { SymbolVar y_no_tc; { auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_fuse_conv_bias_nonlinearity().enable_nchw2nchw32(); + options.enable_fuse_conv_bias_nonlinearity().enable_nchw32(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); } { @@ -1923,7 +1923,7 @@ TEST(FuseConvBiasZPass, Basic) { auto options = gopt::OptimizeForInferenceOptions{}; options.enable_fuse_conv_bias_nonlinearity() .enable_fuse_conv_bias_with_z() - .enable_nchw2nchw32(); + .enable_nchw32(); unpack_vector(gopt::optimize_for_inference({y1}, options), y_opt); } auto nr_elemwisemultitype = find_opr_num(y_opt); @@ -1940,7 +1940,7 @@ TEST(FuseConvBiasZPass, Basic) { auto options = gopt::OptimizeForInferenceOptions{}; options.enable_fuse_conv_bias_nonlinearity() .enable_fuse_conv_bias_with_z() - .enable_nchw2nchw32(); + .enable_nchw32(); unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt); } @@ -2013,7 +2013,7 @@ TEST(TestGoptInference, EnableCHWN4) { SymbolVar y_cudnn; { auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw42chwn4(); + options.enable_chwn4(); unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); } unpack_vector(gopt::GraphOptimizer{} @@ -2099,7 +2099,7 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) { SymbolVar y_cudnn; { auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw42chwn4(); + options.enable_chwn4(); unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt); } unpack_vector(gopt::GraphOptimizer{} @@ -2386,7 +2386,7 @@ TEST(TestGoptInference, ConvertFormatNCHW88) { SymbolVar y_opt; { auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nchw88(); + options.enable_nchw88(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); } @@ -2467,7 +2467,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nchw44(); + options.enable_nchw44(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, diff --git a/src/opr/test/io.cpp b/src/opr/test/io.cpp index 7bff8f77c0485ea6bda0703b64614bef10cd9241..7930bffb655e37bdf520160fee6525d43d9fc089 100644 --- a/src/opr/test/io.cpp +++ b/src/opr/test/io.cpp @@ -501,7 +501,7 @@ TEST(TestOprIO, MultipleDeviceTensorWithFormatHolderCpu) { auto y = opr::Elemwise::make({conv2}, opr::Elemwise::Param::Mode::RELU); auto options = gopt::OptimizeForInferenceOptions{}; - options.enable_nchw2nhwcd4(); + options.enable_nhwcd4(); SymbolVar y_opt = gopt::optimize_for_inference({y}, options)[0].rename("out");