From 6d6b42bb77653cb92e8bdf07f0641f0e0f6e7353 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 13 May 2020 17:33:43 +0800 Subject: [PATCH] refactor(gopt): refactor interface of add passes for common optimizations GitOrigin-RevId: d0f3819c3a6e969430ebf8a65e26fc2e77ae8aeb --- src/core/impl/graph/cg_impl.cpp | 3 +- src/core/include/megbrain/graph/cg.h | 28 ++++----- src/gopt/impl/framework.cpp | 66 +++++++++++----------- src/gopt/include/megbrain/gopt/framework.h | 11 +++- 4 files changed, 58 insertions(+), 50 deletions(-) diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp index db738e1e..88b9f23b 100644 --- a/src/core/impl/graph/cg_impl.cpp +++ b/src/core/impl/graph/cg_impl.cpp @@ -474,8 +474,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( } #endif gopt::GraphOptimizer optimizer; - optimizer.apply_optimize_options(options().graph_opt); - options().graph_opt.reset(); + optimizer.add_passes_for_optimize_options(options().graph_opt, true); optimizer.apply_inplace(dest_vars); const OprNodeArray* opr_seq = nullptr; diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index 229d101f..0f5bddfd 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -107,19 +107,17 @@ struct GraphCommonOptimizeOptions { }; LayoutTransform layout_transform = LayoutTransform::DEFAULT; - void reset() { - f16_io_f32_comp = false; - f16_io_comp = false; - fuse_conv_bias_nonlinearity = false; - fuse_conv_bias_with_z = false; - layout_transform = LayoutTransform::DEFAULT; - } +#define SET(n) \ + GraphCommonOptimizeOptions& enable_##n() { \ + n = true; \ + return *this; \ + } \ + GraphCommonOptimizeOptions& disable_##n() { \ + n = false; \ + return *this; \ + } \ + bool has_set_##n() { return n == true; } -#define SET(n) \ - GraphCommonOptimizeOptions& enable_##n() { \ - n = true; \ - return *this; \ - } SET(f16_io_f32_comp); SET(f16_io_comp); SET(fuse_conv_bias_nonlinearity); @@ -131,7 +129,11 @@ struct GraphCommonOptimizeOptions { layout_transform = LayoutTransform::_trans_capital; \ return *this; \ } \ - bool transform_##_trans() const { \ + GraphCommonOptimizeOptions& disable_##_trans() { \ + layout_transform = LayoutTransform::DEFAULT; \ + return *this; \ + } \ + bool has_set_##_trans() const { \ return layout_transform == LayoutTransform::_trans_capital; \ } diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index aa287d37..8552c3be 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -675,7 +675,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( if (inference_opt) { add_pass(); - apply_optimize_options(*inference_opt); + add_passes_for_optimize_options(*inference_opt); } @@ -704,56 +704,56 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { } } -const GraphOptimizer& GraphOptimizer::apply_optimize_options( + +const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( const cg::GraphCommonOptimizeOptions& options) { + return add_passes_for_optimize_options( + const_cast(options)); +} + +const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( + cg::GraphCommonOptimizeOptions& options, bool reset) { bool need_param_fuse = false; - if (options.f16_io_comp) { - add_pass(ConvertF32ToF16Pass::make(false)); - need_param_fuse = true; - } - if (options.f16_io_f32_comp) { - add_pass(ConvertF32ToF16Pass::make(true)); - need_param_fuse = true; + +#define cb(_option, _passes) \ + if (options.has_set_##_option()) { \ + _passes need_param_fuse = true; \ + if (reset) { \ + options.disable_##_option(); \ + } \ } - if (options.transform_nhwcd4()) { + cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); + cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); + + cb(nhwcd4, { add_pass(); add_pass(ConvertFormatPass::make_nhwcd4_converter()); - need_param_fuse = true; - } - if (options.transform_nchw88()) { - add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); - need_param_fuse = true; - } - if (options.transform_nchw44()) { - add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); - need_param_fuse = true; - } - if (options.transform_nchw32()) { + }); + cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); }); + cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); }); + cb(nchw32, { add_pass(); add_pass(); add_pass(EnableTensorCorePass::make_tensorcore_converter()); add_pass(); add_pass(); - need_param_fuse = true; - } - if (options.transform_chwn4()) { + }); + cb(chwn4, { add_pass(); add_pass(); add_pass(EnableCHWN4Pass::make_chwn4_converter()); add_pass(); add_pass(); - need_param_fuse = true; - } + }); - if (options.fuse_conv_bias_nonlinearity) { - add_pass(); - need_param_fuse = true; - } - if (options.fuse_conv_bias_with_z) { + cb(fuse_conv_bias_nonlinearity, { add_pass(); }); + cb(fuse_conv_bias_with_z, { add_pass(); add_pass(); - need_param_fuse = true; - } + }); + +#undef cb + if (need_param_fuse) { add_pass(); } diff --git a/src/gopt/include/megbrain/gopt/framework.h b/src/gopt/include/megbrain/gopt/framework.h index f5f41cdb..56d79f4c 100644 --- a/src/gopt/include/megbrain/gopt/framework.h +++ b/src/gopt/include/megbrain/gopt/framework.h @@ -468,9 +468,16 @@ namespace gopt { static VarNode* var_replace_lookup(VarNode *var); /** - * \brief apply optimize options + * \brief add pass indicated by optimize options. + * + * \param options common options + * \param reset if set true, it will reset options when add passes. */ - const GraphOptimizer& apply_optimize_options( + const GraphOptimizer& add_passes_for_optimize_options( + cg::GraphCommonOptimizeOptions& options, + bool reset = false); + + const GraphOptimizer& add_passes_for_optimize_options( const cg::GraphCommonOptimizeOptions& options); }; -- GitLab