From d07cfdcbe90f0bc0d4f89f2daee654d1fcbbea19 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 7 Jan 2021 02:01:07 +0800 Subject: [PATCH] refactor(mgb): move convolution mixin to search_policy GitOrigin-RevId: 81e32da034f7bfd7d2d42cc98e949d506da184df --- sdk/load-and-run/src/mgblar.cpp | 3 +- src/gopt/impl/inference.cpp | 25 ++-- src/gopt/include/megbrain/gopt/inference.h | 3 +- src/opr/impl/dnn/convolution.cpp | 23 +--- src/opr/impl/search_policy/algo_chooser.cpp | 3 +- .../search_policy/algo_chooser_helper.cpp | 33 +++++ src/opr/impl/search_policy/profiler.cpp | 1 + .../include/megbrain/opr/dnn/convolution.h | 125 +++++------------- .../megbrain/opr/search_policy/algo_chooser.h | 1 + .../opr/search_policy/algo_chooser_helper.h | 80 +++++++++++ .../megbrain/opr/search_policy/profiler.h | 9 +- 11 files changed, 176 insertions(+), 130 deletions(-) create mode 100644 src/opr/impl/search_policy/algo_chooser_helper.cpp create mode 100644 src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index b7888bc57..1ccaa921e 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -19,6 +19,7 @@ #include "megbrain/graph/extern_copr_api.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/io.h" +#include "megbrain/opr/search_policy/algo_chooser_helper.h" #include "megbrain/opr/utility.h" #include "megbrain/plugin/cpu_dispatch_checker.h" #include "megbrain/plugin/num_range_checker.h" @@ -691,7 +692,7 @@ void run_test_st(Args &env) { } mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit); - using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; + using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; S strategy = S::HEURISTIC; #if MGB_ENABLE_FASTRUN if (env.use_fast_run) { diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 7d99c8734..0acbe6153 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -15,6 +15,7 @@ #include "megbrain/graph/event.h" #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/local.h" +#include "megbrain/opr/search_policy/algo_chooser_helper.h" #include "megbrain/utils/shared_set.h" #include "megbrain/serialization/opr_shallow_copy.h" #include "megbrain/opr/basic_arith.h" @@ -116,8 +117,8 @@ SymbolVarArray gopt::optimize_for_inference( namespace { void modify_conv_strategy( - opr::mixin::Convolution& conv, - opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { + opr::mixin::AlgoChooserHelper& conv, + opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { auto policy = conv.execution_policy_transient(); policy.strategy = strategy; conv.set_execution_policy(policy); @@ -126,13 +127,13 @@ void modify_conv_strategy( template void inplace_conv_opr_modifier( OperatorNodeBase& opr, - opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { + opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { modify_conv_strategy( opr.cast_final_safe(), strategy); } -void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv, +void modify_conv_policy_workspace_limit(opr::mixin::AlgoChooserHelper& conv, size_t workspace_limit) { auto policy = conv.execution_policy_transient(); policy.workspace_limit = workspace_limit; @@ -159,9 +160,9 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr, void gopt::modify_opr_algo_strategy_inplace( const VarNodeArrayView& dest_vars, - 
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { + opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { #if !MGB_ENABLE_FASTRUN - using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; + using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) { mgb_throw(MegBrainError, "fastrun is disabled at compile time"); } @@ -190,16 +191,16 @@ void gopt::modify_opr_algo_strategy_inplace( void gopt::enable_opr_algo_profiling_inplace( const VarNodeArrayView& dest_vars) { - modify_opr_algo_strategy_inplace(dest_vars, - opr::mixin::Convolution::ExecutionPolicy:: - Strategy::PROFILE); + modify_opr_algo_strategy_inplace( + dest_vars, + opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy::PROFILE); } void gopt::enable_opr_use_profiling_cache_inplace( const VarNodeArrayView& dest_vars) { - modify_opr_algo_strategy_inplace(dest_vars, - opr::mixin::Convolution::ExecutionPolicy:: - Strategy::PROFILE_HEURISTIC); + modify_opr_algo_strategy_inplace( + dest_vars, opr::mixin::AlgoChooserHelper::ExecutionPolicy:: + Strategy::PROFILE_HEURISTIC); } diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index 81cf753be..0ad3617b8 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -14,6 +14,7 @@ #include "megbrain/gopt/framework.h" #include "megbrain/graph/cg.h" #include "megbrain/opr/dnn/convolution.h" +#include "megbrain/opr/search_policy/algo_chooser_helper.h" namespace mgb { namespace gopt { @@ -342,7 +343,7 @@ namespace gopt { */ void modify_opr_algo_strategy_inplace( const VarNodeArrayView& dest_vars, - opr::mixin::Convolution::ExecutionPolicy::Strategy strategy); + opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy); /*! 
* \brief enable PROFILE execution strategy for oprs with multiple diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index a6ed0f936..1fa9e8310 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -13,7 +13,7 @@ #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/io.h" #include "megbrain/opr/search_policy/algo_chooser.h" -#include "megbrain/opr/search_policy/profiler.h" +#include "megbrain/opr/search_policy/algo_chooser_helper.h" #include "megbrain/graph/grad_impl.h" #include "megbrain/system.h" @@ -38,18 +38,9 @@ using intl::WorkspaceLimitGetter; /* ==================== misc impl ==================== */ -mixin::Convolution::~Convolution() = default; - -void mixin::Convolution::set_execution_policy(const ExecutionPolicy& policy) { - mgb_throw_if( - m_policy_accessed, InternalError, - "attempt to modify ExecutionPolicy after it has been accessed"); - m_policy = policy; -} - template -void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( - cg::OperatorNodeBase* self) { +void mixin::ConvolutionBackwardDataMixin:: + init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self) { using namespace cg::static_infer; auto&& mgr = self->owner_graph()->static_infer_manager(); @@ -93,7 +84,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( }; inp_deps.push_back({self->output(0), DepType::SHAPE}); auto workspace_dep_var = - WorkspaceLimitGetter::register_to_graph(self->owner_graph()); + intl::WorkspaceLimitGetter::register_to_graph(self->owner_graph()); if (workspace_dep_var) { inp_deps.push_back({workspace_dep_var, DepType::VALUE}); } @@ -101,11 +92,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( {SourceType::DEP, inp_deps, infer_wk}); } -#define IMPL_CONV(_cls) \ - std::pair _cls::param_blob() const { \ - return {¶m(), sizeof(Param)}; \ - } \ - MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) +#define IMPL_CONV(_cls) MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) class mixin::WeightPreprocessExecutor::PreprocessedFilterExecDep final : public cg::GraphExecutable::ExecDependency { diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index 978dee036..d194f657e 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -11,6 +11,7 @@ */ #include "megbrain/opr/search_policy/algo_chooser.h" +#include "megbrain/opr/search_policy/algo_chooser_helper.h" #include "megbrain/opr/search_policy/profiler.h" #include "../internal/invoke.h" @@ -200,7 +201,7 @@ size_t AlgoChooser::setup_algo(const TensorLayoutArray& layouts, template typename AlgoChooser::ImplAlgo AlgoChooser::get_algo( ExeContext& ctx) { - using S = mixin::Convolution::ExecutionPolicy::Strategy; + using S = mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); switch (ctx.mgb_opr()->execution_policy().strategy) { case S::HEURISTIC: diff --git a/src/opr/impl/search_policy/algo_chooser_helper.cpp b/src/opr/impl/search_policy/algo_chooser_helper.cpp new file mode 100644 index 000000000..5ccf388fc --- /dev/null +++ b/src/opr/impl/search_policy/algo_chooser_helper.cpp @@ -0,0 +1,33 @@ +/** + * \file src/opr/impl/search_policy/algo_chooser_helper.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/opr/search_policy/algo_chooser_helper.h" +#include "megbrain/opr/search_policy/algo_chooser.h" +#include "megbrain/graph/cg.h" + +#include "../internal/megdnn_opr_wrapper.inl" + +using namespace mgb; +using namespace opr; +using namespace mixin; +/* ==================== misc impl ==================== */ + +AlgoChooserHelper::~AlgoChooserHelper() = default; + +void AlgoChooserHelper::set_execution_policy(const ExecutionPolicy& policy) { + mgb_throw_if( + m_policy_accessed, InternalError, + "attempt to modify ExecutionPolicy after it has been accessed"); + m_policy = policy; +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/search_policy/profiler.cpp b/src/opr/impl/search_policy/profiler.cpp index 76712386d..3365bf856 100644 --- a/src/opr/impl/search_policy/profiler.cpp +++ b/src/opr/impl/search_policy/profiler.cpp @@ -13,6 +13,7 @@ #include "megbrain/opr/search_policy/profiler.h" #include "../internal/invoke.h" +#include "../internal/megdnn_opr_wrapper.inl" #if MGB_ROCM #include "hcc_detail/hcc_defs_prologue.h" diff --git a/src/opr/include/megbrain/opr/dnn/convolution.h b/src/opr/include/megbrain/opr/dnn/convolution.h index 5827f59bf..4af1bfc22 100644 --- a/src/opr/include/megbrain/opr/dnn/convolution.h +++ b/src/opr/include/megbrain/opr/dnn/convolution.h @@ -11,6 +11,7 @@ #pragma once #include "megbrain/opr/internal/megdnn_opr_wrapper.h" +#include "megbrain/opr/search_policy/algo_chooser_helper.h" #include "megbrain/utils/persistent_cache.h" #include "megbrain/opr/param_defs.h" #include "megdnn/oprs/nn.h" @@ -19,68 +20,14 @@ namespace mgb { namespace opr { namespace mixin { -/*! - * \brief Convolution base class - */ -class Convolution { - public: - using ExecutionPolicy = megdnn::param::ExecutionPolicy; - using AlgorithmInfo = megdnn::detail::Algorithm::Info; - using AlgoChooserHook = - std::function; - - const ExecutionPolicy& execution_policy() const { - if (!m_policy_accessed) { - m_policy_accessed = true; - } - return m_policy; - } - - /*! - * \brief get current policy without marking it as having been accessed - * - * This is primarily used for getting current policy before calling - * set_execution_policy(). - */ - const ExecutionPolicy& execution_policy_transient() const { - return m_policy; - } - - /*! - * \brief modify execution policy - * - * Exception would be thrown if execution_policy() has been accessed, - * since it would influence cache and many other decisions. - */ - void set_execution_policy(const ExecutionPolicy& policy); - - AlgoChooserProfileCache& profile_cache() const; - - virtual std::pair param_blob() const = 0; - - /*! - * \brief register a hook to implement custom algo chooser - */ - void setup_algo_chooser(AlgoChooserHook&& func) { - m_algo_chooser = func; - } - AlgoChooserHook algo_chooser() const { - return m_algo_chooser; - } - - protected: - ~Convolution(); - - mutable bool m_policy_accessed = false; - ExecutionPolicy m_policy; - - AlgoChooserHook m_algo_chooser; +class ConvolutionBackwardDataMixin : public cg::OperatorNodeMixinBase { +protected: + //! init output desc for conv backward data oprs; it handles both grad + //! usage and deconv usage + template + static void init_output_static_infer_desc_for_bwd_data( + cg::OperatorNodeBase* self); - //! 
init output desc for conv backward data oprs; it handles both grad - //! usage and deconv usage - template - static void init_output_static_infer_desc_for_bwd_data( - cg::OperatorNodeBase* self); }; class WeightPreprocessExecutor : public cg::OperatorNodeMixinBase { @@ -153,7 +100,7 @@ class ConvolutionTestingPeer; } // namespace testing MGB_DEFINE_OPR_CLASS(ConvolutionForward, - intl::ConvolutionForwardBase, public mixin::Convolution) // { + intl::ConvolutionForwardBase, public mixin::AlgoChooserHelper) // { void init_output_dtype() override; size_t get_workspace_size_bytes( @@ -183,12 +130,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionForward, const ExecutionPolicy &policy = {}, const OperatorNodeConfig &config = {}); - std::pair param_blob() const override; }; using Convolution = ConvolutionForward; MGB_DEFINE_OPR_CLASS(ConvBiasForward, intl::ConvBiasForwardBase, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper) // { void init_output_dtype() override; size_t get_workspace_size_bytes( @@ -240,7 +186,6 @@ public: const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); - std::pair param_blob() const override; static void check_winograd_param_valid( const megdnn::ConvBias::WinogradParam& param, @@ -253,10 +198,12 @@ using ConvBias = ConvBiasForward; /*! * \brief Can be used in two ways: compute gradient of conv, or deconv */ -MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData, +MGB_DEFINE_OPR_CLASS( + ConvolutionBackwardData, cg::SingleCNOperatorNodeBaseT< - mixin::MegDNNOprHolderImpl>, - public mixin::Convolution) // { + mixin::MegDNNOprHolderImpl>, + public mixin::AlgoChooserHelper, + public mixin::ConvolutionBackwardDataMixin) // { void init_output_static_infer_desc() override; void init_output_dtype() override; void init_output_format() override; @@ -296,12 +243,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData, return make(filter, data, param, policy, config); } - std::pair param_blob() const override; }; MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter, intl::MegDNNOprWrapperBwd, - public mixin::Convolution ) // { + public mixin::AlgoChooserHelper ) // { size_t get_workspace_size_bytes( @@ -318,7 +264,6 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter, const ExecutionPolicy &policy = {}, const OperatorNodeConfig &config = {}); - std::pair param_blob() const override; }; MGB_DEFINE_OPR_CLASS(MaskConvolution, @@ -350,7 +295,7 @@ public: MGB_DEFINE_OPR_CLASS(Convolution3DForward, intl::MegDNNOprWrapperFwd, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper) // { void init_output_dtype() override; size_t get_workspace_size_bytes( @@ -368,17 +313,18 @@ MGB_DEFINE_OPR_CLASS(Convolution3DForward, const ExecutionPolicy &policy = {}, const OperatorNodeConfig &config = {}); - std::pair param_blob() const override; }; using Convolution3D = Convolution3DForward; /*! 
* \brief Can be used in two ways: compute gradient of conv, or deconv */ -MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData, +MGB_DEFINE_OPR_CLASS( + Convolution3DBackwardData, cg::SingleCNOperatorNodeBaseT< - mixin::MegDNNOprHolderImpl>, - public mixin::Convolution) // { + mixin::MegDNNOprHolderImpl>, + public mixin::AlgoChooserHelper, + public mixin::ConvolutionBackwardDataMixin) // { void init_output_static_infer_desc() override; void add_input_layout_constraint() override; @@ -416,12 +362,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData, return make(filter, data, param, policy, config); } - std::pair param_blob() const override; }; MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter, intl::MegDNNOprWrapperBwd, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper) // { size_t get_workspace_size_bytes( const TensorShapeArray &input_shapes, @@ -437,12 +382,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter, const ExecutionPolicy &policy = {}, const OperatorNodeConfig &config = {}); - std::pair param_blob() const override; }; MGB_DEFINE_OPR_CLASS(LocalShareForward, intl::MegDNNOprWrapperFwd, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper) // { void init_output_dtype() override; void init_output_format() override; @@ -457,7 +401,6 @@ public: static SymbolVar make(SymbolVar src, SymbolVar filter, const Param& param = {}, const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); - std::pair param_blob() const override; }; using LocalShare = LocalShareForward; @@ -465,7 +408,8 @@ MGB_DEFINE_OPR_CLASS( LocalShareBackwardData, cg::SingleCNOperatorNodeBaseT< mixin::MegDNNOprHolderImpl>, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper, + public mixin::ConvolutionBackwardDataMixin) // { void init_output_static_infer_desc() override; void init_output_dtype() override; @@ -485,13 +429,12 @@ public: const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); - std::pair param_blob() const override; }; MGB_DEFINE_OPR_CLASS( LocalShareBackwardFilter, intl::MegDNNOprWrapperBwd, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper) // { size_t get_workspace_size_bytes( const TensorShapeArray& input_shapes, @@ -506,12 +449,11 @@ public: const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); - std::pair param_blob() const override; }; MGB_DEFINE_OPR_CLASS(DeformableConvForward, intl::MegDNNOprWrapperFwd, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper) // { public: DeformableConvForward( VarNode *src, VarNode *filter, VarNode *offset, VarNode *mask, @@ -525,7 +467,6 @@ MGB_DEFINE_OPR_CLASS(DeformableConvForward, const ExecutionPolicy &policy = {}, const OperatorNodeConfig &config = {}); - std::pair param_blob() const override; private: void init_output_dtype() override; void init_output_format() override; @@ -537,7 +478,8 @@ using DeformableConv = DeformableConvForward; MGB_DEFINE_OPR_CLASS(DeformableConvBackwardData, intl::DeformableConvBackwardDataBase, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper, + public mixin::ConvolutionBackwardDataMixin) // { public: DeformableConvBackwardData( VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, @@ -557,7 +499,6 @@ public: const OperatorNodeConfig& config = {}); void scn_do_execute() override; - std::pair param_blob() const override; private: void get_output_var_shape(const TensorShapeArray& inp_shape, @@ -578,7 +519,7 @@ private: MGB_DEFINE_OPR_CLASS( DeformableConvBackwardFilter, 
intl::MegDNNOprWrapperBwd, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper) // { public: DeformableConvBackwardFilter( VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, @@ -592,7 +533,6 @@ public: const OperatorNodeConfig& config = {}); void scn_do_execute() override; - std::pair param_blob() const override; private: size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, @@ -601,7 +541,7 @@ private: }; MGB_DEFINE_OPR_CLASS(BatchConvBiasForward, intl::BatchConvBiasForwardBase, - public mixin::Convolution) // { + public mixin::AlgoChooserHelper) // { void init_output_dtype() override; size_t get_workspace_size_bytes( @@ -650,7 +590,6 @@ public: const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); - std::pair param_blob() const override; }; using BatchConvBias = BatchConvBiasForward; diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h index 0804ab87f..b76874667 100644 --- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h +++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h @@ -13,6 +13,7 @@ #pragma once #include "megbrain/opr/search_policy/profiler.h" +#include "megbrain/opr/dnn/convolution.h" template struct MegDNNOpr2MGBOpr; diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h new file mode 100644 index 000000000..c0b8699a2 --- /dev/null +++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h @@ -0,0 +1,80 @@ +/** + * \file src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megbrain/graph/operator_node.h" +#include "megbrain/opr/param_defs.h" +#include "megdnn/oprs/base.h" +#include "megdnn/oprs/nn.h" + +namespace mgb { +namespace opr { + +namespace mixin { + +/*! + * \brief base class for the opr which can be tuning + */ +class AlgoChooserHelper : cg::OperatorNodeMixinBase { +public: + using ExecutionPolicy = megdnn::param::ExecutionPolicy; + using AlgorithmInfo = megdnn::detail::Algorithm::Info; + using AlgoChooserHook = + std::function; + + const ExecutionPolicy& execution_policy() const { + if (!m_policy_accessed) { + m_policy_accessed = true; + } + return m_policy; + } + + /*! + * \brief get current policy without marking it as having been accessed + * + * This is primarily used for getting current policy before calling + * set_execution_policy(). + */ + const ExecutionPolicy& execution_policy_transient() const { + return m_policy; + } + + /*! + * \brief modify execution policy + * + * Exception would be thrown if execution_policy() has been accessed, + * since it would influence cache and many other decisions. + */ + void set_execution_policy(const ExecutionPolicy& policy); + + /*! 
+ * \brief register a hook to implement custom algo chooser + */ + void setup_algo_chooser(AlgoChooserHook&& func) { m_algo_chooser = func; } + AlgoChooserHook algo_chooser() const { return m_algo_chooser; } + +protected: + ~AlgoChooserHelper(); + + mutable bool m_policy_accessed = false; + ExecutionPolicy m_policy; + + AlgoChooserHook m_algo_chooser; + +}; +} // namespace mixin + +} // namespace opr +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/include/megbrain/opr/search_policy/profiler.h b/src/opr/include/megbrain/opr/search_policy/profiler.h index a7d1aaacf..37272b208 100644 --- a/src/opr/include/megbrain/opr/search_policy/profiler.h +++ b/src/opr/include/megbrain/opr/search_policy/profiler.h @@ -12,9 +12,10 @@ #pragma once -#include "megbrain/opr/dnn/convolution.h" #include "megbrain/utils/hash_ct.h" #include "megbrain/utils/timer.h" +#include "megbrain/system.h" +#include "megbrain/comp_node.h" #include "megdnn/basic_types.h" #include "megdnn/oprs/nn.h" @@ -127,15 +128,15 @@ class TimedProfiler { static constexpr int arity_out = OprArityTrait::arity_out; static constexpr int arity = OprArityTrait::arity; - using ConvTensorShapes = std::array; + using TensorShapeArray = std::array; public: struct Param { char algo_name[128]; size_t workspace; - DTypeEnum dtypes[arity]; + megdnn::DTypeEnum dtypes[arity]; CompNode::Locator comp_node_loc; - ConvTensorShapes shapes; + TensorShapeArray shapes; typename Opr::Param opr_param; bool allow_weight_preprocess; -- GitLab
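
For reference, after this refactor the ExecutionPolicy/Strategy types that used to be nested in opr::mixin::Convolution are reached through opr::mixin::AlgoChooserHelper, declared in the new header megbrain/opr/search_policy/algo_chooser_helper.h. Below is a minimal caller-side sketch, modeled on the mgblar.cpp and gopt hunks above; the apply_algo_strategy wrapper and its template parameter are illustrative only (not part of this patch) and assume a MegEngine build with these headers available:

    #include "megbrain/gopt/inference.h"                         // gopt::modify_opr_algo_strategy_inplace
    #include "megbrain/opr/search_policy/algo_chooser_helper.h"  // new home of the mixin

    // Illustrative helper: pick an execution strategy for all algo-chooser oprs
    // reachable from `dest_vars` (e.g. the outputs of a loaded graph), the way
    // load-and-run does in run_test_st().
    template <typename Vars>
    void apply_algo_strategy(const Vars& dest_vars, bool use_fast_run) {
        using S = mgb::opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
        S strategy = S::HEURISTIC;       // default: choose algorithm by heuristic
    #if MGB_ENABLE_FASTRUN
        if (use_fast_run)
            strategy = S::PROFILE;       // fastrun: profile candidate algorithms
    #endif
        mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, strategy);
    }

As the inference.cpp hunk shows, requesting PROFILE or PROFILE_REPRODUCIBLE in a build without MGB_ENABLE_FASTRUN throws MegBrainError ("fastrun is disabled at compile time"), so guarding the strategy choice with the same macro keeps callers portable across build configurations.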