diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp
index db738e1e810d88200fe65f0fec2716ba824cd715..88b9f23b7178316a312ac7acb101b7ea9ab43efd 100644
--- a/src/core/impl/graph/cg_impl.cpp
+++ b/src/core/impl/graph/cg_impl.cpp
@@ -474,8 +474,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
     }
 #endif
     gopt::GraphOptimizer optimizer;
-    optimizer.apply_optimize_options(options().graph_opt);
-    options().graph_opt.reset();
+    optimizer.add_passes_for_optimize_options(options().graph_opt, true);
     optimizer.apply_inplace(dest_vars);
 
     const OprNodeArray* opr_seq = nullptr;
diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h
index 229d101f1d9ba81d8b79ff462ac1eaef01987d5c..0f5bddfd584b48771558a112376495bc4e66bd65 100644
--- a/src/core/include/megbrain/graph/cg.h
+++ b/src/core/include/megbrain/graph/cg.h
@@ -107,19 +107,17 @@ struct GraphCommonOptimizeOptions {
     };
     LayoutTransform layout_transform = LayoutTransform::DEFAULT;
 
-    void reset() {
-        f16_io_f32_comp = false;
-        f16_io_comp = false;
-        fuse_conv_bias_nonlinearity = false;
-        fuse_conv_bias_with_z = false;
-        layout_transform = LayoutTransform::DEFAULT;
-    }
+#define SET(n)                                  \
+    GraphCommonOptimizeOptions& enable_##n() {  \
+        n = true;                               \
+        return *this;                           \
+    }                                           \
+    GraphCommonOptimizeOptions& disable_##n() { \
+        n = false;                              \
+        return *this;                           \
+    }                                           \
+    bool has_set_##n() { return n == true; }
 
-#define SET(n)                                 \
-    GraphCommonOptimizeOptions& enable_##n() { \
-        n = true;                              \
-        return *this;                          \
-    }
     SET(f16_io_f32_comp);
     SET(f16_io_comp);
     SET(fuse_conv_bias_nonlinearity);
@@ -131,7 +129,11 @@ struct GraphCommonOptimizeOptions {
     GraphCommonOptimizeOptions& enable_##_trans() {                  \
         layout_transform = LayoutTransform::_trans_capital;          \
         return *this;                                                \
     }                                                                \
-    bool transform_##_trans() const {                                \
+    GraphCommonOptimizeOptions& disable_##_trans() {                 \
+        layout_transform = LayoutTransform::DEFAULT;                 \
+        return *this;                                                \
+    }                                                                \
+    bool has_set_##_trans() const {                                  \
         return layout_transform == LayoutTransform::_trans_capital;  \
     }
diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index aa287d37220e4a4ae036bb147ad5acf09ac6a20a..8552c3be13310294c7ccb2f5f345b12926331ae1 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -675,7 +675,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
 
     if (inference_opt) {
         add_pass<ParamFusePass>();
-        apply_optimize_options(*inference_opt);
+        add_passes_for_optimize_options(*inference_opt);
     }
 
@@ -704,56 +704,56 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) {
     }
 }
 
-const GraphOptimizer& GraphOptimizer::apply_optimize_options(
+
+const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
         const cg::GraphCommonOptimizeOptions& options) {
+    return add_passes_for_optimize_options(
+            const_cast<cg::GraphCommonOptimizeOptions&>(options));
+}
+
+const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
+        cg::GraphCommonOptimizeOptions& options, bool reset) {
     bool need_param_fuse = false;
-    if (options.f16_io_comp) {
-        add_pass(ConvertF32ToF16Pass::make(false));
-        need_param_fuse = true;
-    }
-    if (options.f16_io_f32_comp) {
-        add_pass(ConvertF32ToF16Pass::make(true));
-        need_param_fuse = true;
+
+#define cb(_option, _passes)             \
+    if (options.has_set_##_option()) {   \
+        _passes need_param_fuse = true;  \
+        if (reset) {                     \
+            options.disable_##_option(); \
+        }                                \
     }
-    if (options.transform_nhwcd4()) {
+    cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
+    cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
+
+    cb(nhwcd4, {
         add_pass<FuseConvBiasNonlinPass>();
         add_pass(ConvertFormatPass::make_nhwcd4_converter());
-        need_param_fuse = true;
-    }
-    if (options.transform_nchw88()) {
-        add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
-        need_param_fuse = true;
-    }
-    if (options.transform_nchw44()) {
-        add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
-        need_param_fuse = true;
-    }
-    if (options.transform_nchw32()) {
+    });
+    cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); });
+    cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); });
+    cb(nchw32, {
         add_pass<FuseConvBiasNonlinPass>();
         add_pass<FuseConvBiasZPass>();
         add_pass(EnableTensorCorePass::make_tensorcore_converter());
         add_pass<ShuffleShuffleRemovePass>();
         add_pass<RemoveRedundantTypeCvtPass>();
-        need_param_fuse = true;
-    }
-    if (options.transform_chwn4()) {
+    });
+    cb(chwn4, {
         add_pass<FuseConvBiasNonlinPass>();
         add_pass<FuseConvBiasZPass>();
         add_pass(EnableCHWN4Pass::make_chwn4_converter());
         add_pass<ShuffleShuffleRemovePass>();
         add_pass<RemoveRedundantTypeCvtPass>();
-        need_param_fuse = true;
-    }
+    });
 
-    if (options.fuse_conv_bias_nonlinearity) {
-        add_pass<FuseConvBiasNonlinPass>();
-        need_param_fuse = true;
-    }
-    if (options.fuse_conv_bias_with_z) {
+    cb(fuse_conv_bias_nonlinearity, { add_pass<FuseConvBiasNonlinPass>(); });
+    cb(fuse_conv_bias_with_z, {
         add_pass<FuseConvBiasNonlinPass>();
         add_pass<FuseConvBiasZPass>();
-        need_param_fuse = true;
-    }
+    });
+
+#undef cb
+
     if (need_param_fuse) {
         add_pass<ParamFusePass>();
     }
diff --git a/src/gopt/include/megbrain/gopt/framework.h b/src/gopt/include/megbrain/gopt/framework.h
index f5f41cdb31c2e3f8a2b789037e8e15518ee61045..56d79f4c62ab6e70059d30e6bc3cdad35d4d58d0 100644
--- a/src/gopt/include/megbrain/gopt/framework.h
+++ b/src/gopt/include/megbrain/gopt/framework.h
@@ -468,9 +468,16 @@ namespace gopt {
         static VarNode* var_replace_lookup(VarNode *var);
 
         /**
-         * \brief apply optimize options
+         * \brief add pass indicated by optimize options.
+         *
+         * \param options common options
+         * \param reset if set true, it will reset options when add passes.
          */
-        const GraphOptimizer& apply_optimize_options(
+        const GraphOptimizer& add_passes_for_optimize_options(
+                cg::GraphCommonOptimizeOptions& options,
+                bool reset = false);
+
+        const GraphOptimizer& add_passes_for_optimize_options(
                 const cg::GraphCommonOptimizeOptions& options);
     };