From 7c09e41f18c8d00dc75cc265bf48a664f03f8903 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 29 Jan 2021 16:04:34 +0800 Subject: [PATCH] refactor(mgb): add circular dependency check GitOrigin-RevId: 01fdb8684be2c594d9b8d9a57d28528cf5412dc6 --- src/opr/impl/search_policy/algo_chooser.cpp | 105 +++++++++++++++--- .../megbrain/opr/search_policy/algo_chooser.h | 15 --- 2 files changed, 89 insertions(+), 31 deletions(-) diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index 45e9cd689..bbca06505 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -12,6 +12,8 @@ #include "megbrain/opr/search_policy/algo_chooser.h" #include +#include +#include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/search_policy/algo_chooser_helper.h" #include "megbrain/opr/search_policy/profiler.h" @@ -22,6 +24,7 @@ //! TODO: here has to be know some megdnn::opr when there is produced midout.h //! fix it if there is another graceful way. +#include "megdnn/opr_param_defs.h" #include "megdnn/oprs.h" #include "megdnn/oprs/base.h" #include "midout.h" @@ -78,6 +81,58 @@ std::string format_fixlayouts( return ret; } +/** + * \brief Check if the sub opr list has circular dependence. + */ +class CircularDepsChecker { + struct SearchItemStorage { + std::string data_hold; + size_t hash = 0; + + SearchItemStorage(const Algorithm::SearchItem& item) { + Algorithm::serialize_write_pod(item.opr_type, data_hold); + for (auto&& layout : item.layouts) { + data_hold += layout.serialize(); + } + data_hold += item.param; + } + + SearchItemStorage& init_hash() { + hash = XXHash64CT::hash(data_hold.data(), data_hold.size(), + 20201225); + return *this; + } + + bool operator==(const SearchItemStorage& rhs) const { + return data_hold == rhs.data_hold; + } + + struct Hash { + size_t operator()(const SearchItemStorage& s) const { + return s.hash; + } + }; + }; + std::unordered_set m_set; + +public: + void put(const megdnn::Algorithm::SearchItem& key) { + SearchItemStorage key_storage(key); + key_storage.init_hash(); + mgb_assert(m_set.find(key_storage) == m_set.end(), + "Circular dependency during flatten search space"); + auto ret = m_set.insert(std::move(key_storage)); + mgb_assert(ret.second); + } + void remove(const megdnn::Algorithm::SearchItem& key) { + SearchItemStorage key_storage(key); + key_storage.init_hash(); + auto&& iter = m_set.find(key_storage); + mgb_assert(iter != m_set.end()); + m_set.erase(iter); + } +}; + ///////////////// OprTypeTrait ///////////////////////////// template struct OprFromOprTypeTrait; @@ -176,14 +231,26 @@ typename opr::AlgoChooser::FixedTensorLayouts to_fixed_layouts( return ret; } -} // namespace - -namespace mgb { -namespace opr { - +/** + * flatten search space in postorder traversal + * The subopr search construct a search tree + * + * A + * / \ + * B1B2 C + * / \ + * D1D2D3 E + * We use postorder traverse the search tree. + * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A + */ template -std::vector -AlgoChooser::flatten_search_space(const ExeContext& ctx) { +std::vector flatten_search_space( + const typename opr::AlgoChooser::ExeContext& ctx, + CircularDepsChecker& checker) { + auto&& search_item = megdnn::Algorithm::SearchItem{ + OprTypeFromOprTrait::opr_type, ctx.param(), + to_layout_array(ctx.layouts())}; + checker.put(search_item); std::vector ret; for (auto algo_info : ctx.get_all_candidates()) { megdnn::Algorithm* algo = ctx.get_algorithm_from_desc(algo_info.desc); @@ -193,23 +260,29 @@ AlgoChooser::flatten_search_space(const ExeContext& ctx) { ctx.megdnn_opr()); FOREACH_OPR_TYPE_DISPATCH(sub_items, { - auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(ctx.comp_node()); + auto&& megdnn_opr = + opr::intl::create_megdnn_opr<_Opr>(ctx.comp_node()); megdnn_opr->param() = Algorithm::deserialize_read_pod( _item.param); - typename AlgoChooser<_Opr>::ExeContext sub_ctx( + typename opr::AlgoChooser<_Opr>::ExeContext sub_ctx( to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), _item.param, ctx.mgb_opr(), ctx.comp_node(), ctx.execution_policy(), ctx.allow_weight_preprocess()); - auto space = AlgoChooser<_Opr>::flatten_search_space(sub_ctx); + auto space = flatten_search_space<_Opr>(sub_ctx, checker); ret.insert(ret.end(), space.begin(), space.end()); }); } - ret.push_back({OprTypeFromOprTrait::opr_type, ctx.param(), - to_layout_array(ctx.layouts())}); + ret.push_back(search_item); + checker.remove(search_item); return ret; } +} // namespace + +namespace mgb { +namespace opr { + template void AlgoChooser::profile(ExeContext& ctx, bool require_reproducible) { if (ctx.get_profile_result_from_cache(require_reproducible).valid()) @@ -289,7 +362,9 @@ AlgoChooser::choose_by_profile(ExeContext& ctx, bool require_reproducible, } if (enable_update) { - auto&& search_items = flatten_search_space(ctx); + CircularDepsChecker circular_deps_checker; + auto&& search_items = + flatten_search_space(ctx, circular_deps_checker); FOREACH_OPR_TYPE_DISPATCH(search_items, { auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(ctx.comp_node()); megdnn_opr->param() = @@ -382,14 +457,12 @@ typename AlgoChooser::ImplExecutionPolicy AlgoChooser::get_policy( AlgoChooser::get_policy(ExeContext& ctx); \ template void AlgoChooser::profile( \ ExeContext& ctx, bool require_reproducible); \ - template std::vector \ - AlgoChooser::flatten_search_space(const ExeContext& ctx); \ template AlgoChooser::ImplExecutionPolicy \ AlgoChooser::choose_by_profile( \ ExeContext& ctx, bool require_reproducible, bool enable_update); \ template size_t AlgoChooser::setup_algo( \ const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ - const MGBOpr* mgb_opr, bool allow_weight_preprocess); + const MGBOpr* mgb_opr, bool allow_weight_preprocess); \ MGB_FOREACH_FASTRUN_OPR(INST) diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h index a619f9bca..c32dc6d2c 100644 --- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h +++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h @@ -159,21 +159,6 @@ private: bool require_reproducible, bool enable_update = true); - /** - * flatten search space in postorder traversal - * The subopr search construct a search tree - * - * A - * / \ - * B1B2 C - * / \ - * D1D2D3 E - * We use postorder traverse the search tree. - * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A - */ - static std::vector flatten_search_space( - const ExeContext& ctx); - public: /*! * \brief setup algorithm and return workspace size -- GitLab