diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index 7488a3bd45be00daa724909033d1ad5d0ca9948d..6789e7e70bfc11fef43ec66ce0fd1b6fbf7d1868 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -972,9 +972,28 @@ Args Args::from_argv(int argc, char **argv) { continue; } #endif - if (!strcmp(argv[i], "--enable-chwn4")) { - mgb_log_warn("enable chwn4 optimization"); - graph_opt.graph_opt.enable_chwn4(); + +#define cb(_layout) \ + if (!strcmp(argv[i], "--enable-" #_layout)) { \ + mgb_log_warn("enable " #_layout " optimization"); \ + graph_opt.graph_opt.enable_##_layout(); \ + continue; \ + } + + cb(chwn4); + cb(nchw44); + cb(nchw88); + cb(nchw32); + cb(nhwcd4); +#undef cb + if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { + mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); + graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); + continue; + } + if (!strcmp(argv[i], "--enable-fuse-conv-bias-with-z")) { + mgb_log_warn("enable fuse_conv_bias_with_z optimization"); + graph_opt.graph_opt.enable_fuse_conv_bias_with_z(); continue; } #if MGB_ENABLE_JSON diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp index 53f31e6dc0d2ff2f08138d1f06f08abd5e476a94..db738e1e810d88200fe65f0fec2716ba824cd715 100644 --- a/src/core/impl/graph/cg_impl.cpp +++ b/src/core/impl/graph/cg_impl.cpp @@ -462,13 +462,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( options().graph_opt.winograd_transform = false; gopt::transform_vars_inplace_with_winograd(dest_vars); } - if (options().graph_opt.transform_chwn4()) { - gopt::GraphOptimizer optimizer; - optimizer.apply_optimize_options(options().graph_opt); - options().graph_opt.layout_transform = - cg::GraphCommonOptimizeOptions::LayoutTransform::DEFAULT; - optimizer.apply_inplace(dest_vars); - } #if MGB_JIT if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { @@ -480,6 +473,10 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( optimizer.apply_inplace(dest_vars); } #endif + gopt::GraphOptimizer optimizer; + optimizer.apply_optimize_options(options().graph_opt); + options().graph_opt.reset(); + optimizer.apply_inplace(dest_vars); const OprNodeArray* opr_seq = nullptr; CompSeqExtraInfo extra_info; diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index d84cdc6c2ddde52e73982d84c48be3d0fbca7e81..229d101f1d9ba81d8b79ff462ac1eaef01987d5c 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -92,6 +92,9 @@ struct GraphCommonOptimizeOptions { bool f16_io_comp = false; //! whether to enable conv bias nonlinearity fusion bool fuse_conv_bias_nonlinearity = false; + //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) + //! + z -> conv_bias(x, w, b, z) + bool fuse_conv_bias_with_z = false; enum LayoutTransform : uint32_t { DEFAULT, NHWCD4, ///< compute using NHWCD4 tensor format @@ -103,9 +106,14 @@ struct GraphCommonOptimizeOptions { ///< used for cuda }; LayoutTransform layout_transform = LayoutTransform::DEFAULT; - //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) - //! + z -> conv_bias(x, w, b, z) - bool fuse_conv_bias_with_z = false; + + void reset() { + f16_io_f32_comp = false; + f16_io_comp = false; + fuse_conv_bias_nonlinearity = false; + fuse_conv_bias_with_z = false; + layout_transform = LayoutTransform::DEFAULT; + } #define SET(n) \ GraphCommonOptimizeOptions& enable_##n() { \ @@ -119,6 +127,7 @@ struct GraphCommonOptimizeOptions { #undef SET #define SET(_trans, _trans_capital) \ GraphCommonOptimizeOptions& enable_##_trans() { \ + mgb_assert(layout_transform == LayoutTransform::DEFAULT); \ layout_transform = LayoutTransform::_trans_capital; \ return *this; \ } \ diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index eb1e64a0644ed389debbc508d004f736a5e75209..ce22d07615391478f6b17fc87230dec023fb5cb7 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -706,21 +706,27 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { const GraphOptimizer& GraphOptimizer::apply_optimize_options( const cg::GraphCommonOptimizeOptions& options) { + bool need_param_fuse = false; if (options.f16_io_comp) { add_pass(ConvertF32ToF16Pass::make(false)); + need_param_fuse = true; } if (options.f16_io_f32_comp) { add_pass(ConvertF32ToF16Pass::make(true)); + need_param_fuse = true; } if (options.transform_nhwcd4()) { add_pass(ConvertFormatPass::make_nhwcd4_converter()); add_pass(); + need_param_fuse = true; } if (options.transform_nchw88()) { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); + need_param_fuse = true; } if (options.transform_nchw44()) { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); + need_param_fuse = true; } if (options.transform_nchw32()) { add_pass(); @@ -728,6 +734,7 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( add_pass(EnableTensorCorePass::make_tensorcore_converter()); add_pass(); add_pass(); + need_param_fuse = true; } if (options.transform_chwn4()) { add_pass(); @@ -735,16 +742,21 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( add_pass(EnableCHWN4Pass::make_chwn4_converter()); add_pass(); add_pass(); + need_param_fuse = true; } if (options.fuse_conv_bias_nonlinearity) { add_pass(); + need_param_fuse = true; } if (options.fuse_conv_bias_with_z) { add_pass(); add_pass(); + need_param_fuse = true; + } + if (need_param_fuse) { + add_pass(); } - add_pass(); return *this; }