From 8585aa61addda75df402cdba3e0eb5eeea9cc009 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 12 Apr 2021 17:30:53 +0800 Subject: [PATCH] fix(mgb): fix fast run crash when profile heuristic strategy GitOrigin-RevId: 6046a2db0c532b33c78f20c4aa1aa7f7df1af0e4 --- src/core/impl/utils/persistent_cache.cpp | 133 ++++++++---------- .../include/megbrain/utils/persistent_cache.h | 31 ++++ src/opr/impl/search_policy/algo_chooser.cpp | 43 ++++-- .../megbrain/opr/search_policy/algo_chooser.h | 6 +- src/opr/test/dnn/convolution.cpp | 24 +++- 5 files changed, 146 insertions(+), 91 deletions(-) diff --git a/src/core/impl/utils/persistent_cache.cpp b/src/core/impl/utils/persistent_cache.cpp index 19cbf4c1f..4645fbd76 100644 --- a/src/core/impl/utils/persistent_cache.cpp +++ b/src/core/impl/utils/persistent_cache.cpp @@ -25,79 +25,9 @@ using namespace mgb; -namespace { - - class InMemoryPersistentCache final: public PersistentCache { - struct BlobStorage: public Blob { - std::unique_ptr data_refhold; - size_t hash = 0; - - BlobStorage& init_data_ref(const Blob &b) { - data_refhold = std::make_unique(b.size + 1); - memcpy(data_refhold.get(), b.ptr, b.size); - data_refhold.get()[b.size] = 0; // for C-string safety - ptr = data_refhold.get(); - size = b.size; - return *this; - } - - BlobStorage& init_hash() { - hash = XXHash{}.update(ptr, size).digest(); - return *this; - } - - bool operator == (const BlobStorage &rhs) const { - return size == rhs.size && !memcmp(ptr, rhs.ptr, size); - } - - struct Hash { - size_t operator() (const BlobStorage &b) const { - return b.hash; - } - }; - }; - std::unordered_map> - m_cache; - std::mutex m_mtx; - - Maybe get(const std::string& category, const Blob& key) override { - decltype(m_cache.begin()) iter0; - { - MGB_LOCK_GUARD(m_mtx); - iter0 = m_cache.find(category); - if (iter0 == m_cache.end()) - return None; - } - - BlobStorage key_storage; - key_storage.Blob::operator=(key); - key_storage.init_hash(); - - MGB_LOCK_GUARD(m_mtx); - - auto iter1 = iter0->second.find(key_storage); - if (iter1 == iter0->second.end()) - return None; - return iter1->second; - } - - void put(const std::string& category, const Blob& key, - const Blob& value) override { - BlobStorage key_storage; - key_storage.init_data_ref(key).init_hash(); - - MGB_LOCK_GUARD(m_mtx); - auto size0 = m_cache.size(); - m_cache[category][std::move(key_storage)].init_data_ref(value); - if (m_cache.size() > size0) { - mgb_log_debug("new cache category: %s", category.c_str()); - } - } - }; -} +// ================= PersistentCache ====================== std::shared_ptr PersistentCache::sm_impl = -std::make_shared(); + std::make_shared(); std::shared_ptr PersistentCache::set_impl( std::shared_ptr impl) { @@ -141,6 +71,65 @@ std::string PersistentCache::make_category_from_comp_node(CompNode comp_node) { } } +// ================= InMemoryPersistentCache ================== +using Blob = PersistentCache::Blob; +InMemoryPersistentCache::BlobStorage& +InMemoryPersistentCache::BlobStorage::init_data_ref(const Blob& b) { + data_refhold = std::make_unique(b.size + 1); + memcpy(data_refhold.get(), b.ptr, b.size); + data_refhold.get()[b.size] = 0; // for C-string safety + ptr = data_refhold.get(); + size = b.size; + return *this; +} + +InMemoryPersistentCache::BlobStorage& +InMemoryPersistentCache::BlobStorage::init_hash() { + hash = XXHash{}.update(ptr, size).digest(); + return *this; +} + +bool InMemoryPersistentCache::BlobStorage::operator==( + const BlobStorage& rhs) const { + return size == rhs.size && !memcmp(ptr, rhs.ptr, size); +} + +Maybe InMemoryPersistentCache::get(const std::string& category, + const Blob& key) { + decltype(m_cache.begin()) iter0; + { + MGB_LOCK_GUARD(m_mtx); + iter0 = m_cache.find(category); + if (iter0 == m_cache.end()) + return None; + } + + BlobStorage key_storage; + key_storage.Blob::operator=(key); + key_storage.init_hash(); + + MGB_LOCK_GUARD(m_mtx); + + auto iter1 = iter0->second.find(key_storage); + if (iter1 == iter0->second.end()) + return None; + return iter1->second; +} + +void InMemoryPersistentCache::put(const std::string& category, const Blob& key, + const Blob& value) { + BlobStorage key_storage; + key_storage.init_data_ref(key).init_hash(); + + MGB_LOCK_GUARD(m_mtx); + auto size0 = m_cache.size(); + m_cache[category][std::move(key_storage)].init_data_ref(value); + if (m_cache.size() > size0) { + mgb_log_debug("new cache category: %s", category.c_str()); + } +} + +// ================= AlgoChooserProfileCache ================== AlgoChooserProfileCache::AlgoChooserProfileCache( CompNode cn, const char *opr_type) { m_category = "profile:"; diff --git a/src/core/include/megbrain/utils/persistent_cache.h b/src/core/include/megbrain/utils/persistent_cache.h index ef84fa702..1e00bf765 100644 --- a/src/core/include/megbrain/utils/persistent_cache.h +++ b/src/core/include/megbrain/utils/persistent_cache.h @@ -55,6 +55,37 @@ namespace mgb { static std::string make_category_from_comp_node(CompNode comp_node); }; + /*! + * \brief persistent cache that keep in memory + * The implementation is thread safe. + */ + class InMemoryPersistentCache final : public PersistentCache { + struct BlobStorage : public PersistentCache::Blob { + std::unique_ptr data_refhold; + size_t hash = 0; + + BlobStorage& init_data_ref(const Blob& b); + + BlobStorage& init_hash(); + + bool operator==(const BlobStorage& rhs) const; + + struct Hash { + size_t operator()(const BlobStorage& b) const { return b.hash; } + }; + }; + + Maybe get(const std::string& category, const Blob& key) override; + void put(const std::string& category, const Blob& key, + const Blob& value) override; + + std::unordered_map< + std::string, + std::unordered_map> + m_cache; + std::mutex m_mtx; + }; + /*! * \brief proxy PersistentCache to be better suited for managing profiling * results of operator impl algorithms diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index f8109091b..b8c3d7be8 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -68,7 +68,6 @@ std::string format_fixlayouts( ret.append(", "); } ret.append(layouts[i].to_string() + " "); - ret.append(layouts[i].dtype.name()); } ret.append(") -> ("); for (size_t i = 0; i < arity_out; ++i) { @@ -76,7 +75,6 @@ std::string format_fixlayouts( ret.append(", "); } ret.append(layouts[i + arity_in].to_string() + " "); - ret.append(layouts[i + arity_in].dtype.name()); } return ret; } @@ -420,6 +418,7 @@ AlgoChooser::choose_by_profile(ExeContext& ctx, AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); }); } + typename AlgoChooser::ImplExecutionPolicy policy; ctx.construct_execution_policy(selected_strategy, policy); return policy; @@ -660,8 +659,28 @@ void AlgoChooser::ExeContext::construct_execution_policy( bool retrive_from_cache) const { if (!policy.algo.valid()) { if (retrive_from_cache) { - policy.algo = - get_profile_result_from_cache(selected_strategy).desc; + policy.algo = get_profile_result_from_cache(selected_strategy).desc; + if (!policy.algo.valid()) { + auto target_attr = + extract_algo_attribute_from_execution_strategy( + selected_strategy); + std::string layouts_str = + format_fixlayouts(m_layouts, arity_in, arity_out); + std::string msg = ssprintf( + "(mbg_opr : %s, layouts %s, with attribute(%s) and " + "without attribute(%s)", + m_base_mgb_opr->dyn_typeinfo()->name, + layouts_str.c_str(), + Algorithm::attribute_str(target_attr.first).c_str(), + Algorithm::attribute_str(target_attr.second).c_str()); + mgb_log_warn( + "No algo get from cache for %s. This may caused by " + "mismatch with model and cache file. ex. profiling " + "with version1, but inferencing on version2 or " + "profiling modelA but inferencing modelB", + msg.c_str()); + return; + } } else { auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( owner_graph(), m_cn, m_execution_policy.workspace_limit); @@ -673,10 +692,12 @@ void AlgoChooser::ExeContext::construct_execution_policy( attr.second), m_layouts) .desc; + mgb_assert(policy.algo.valid(), + "No algo found from heuristic with strategy %u and " + "workspace limit %zu", + static_cast(selected_strategy), + workspace_limit); } - mgb_assert(policy.algo.valid(), - "No algo found from cache or heuristic, maybe some error " - "occured"); } Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); @@ -697,9 +718,13 @@ void AlgoChooser::ExeContext::construct_execution_policy( sub_ctx.construct_execution_policy(selected_strategy, policy.sub_policy.back(), retrive_from_cache); + if (!policy.sub_policy.back().algo.valid()) { + // means sub_ctx.construct_execution_policy fails. clean up + // policy.algo and return + policy = {}; + return; + } }); - - return; } template diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h index bb193e18a..8d20d7c89 100644 --- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h +++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h @@ -140,9 +140,10 @@ public: * \brief construct execution policy from cache or heuristic. * * \param selected_strategy select algo which matched this strategy - * \param policy execution policy + * \param [out] policy execution policy * \param retrive_from_cache retrive algo from cache if set True, get * from heuristic otherwise. + * \note When contruction fail, the policy will be cleaned. */ void construct_execution_policy(ExecutionStrategy selected_strategy, ImplExecutionPolicy& policy, @@ -152,14 +153,13 @@ public: Maybe> construct_fake_preprocess_filter() const; }; - template + template friend class AlgoChooser; private: //! entrance for getting algorithm according to execution strategy static ImplExecutionPolicy get_policy(ExeContext& ctx); - //! profile and save to cache static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); diff --git a/src/opr/test/dnn/convolution.cpp b/src/opr/test/dnn/convolution.cpp index 351a43694..32f89d7d6 100644 --- a/src/opr/test/dnn/convolution.cpp +++ b/src/opr/test/dnn/convolution.cpp @@ -30,7 +30,6 @@ #include using namespace mgb; - namespace { using Param = opr::Convolution::Param; @@ -354,21 +353,26 @@ TEST(TestOprDNN, ConvBiasExePolicy) { auto cn = CompNode::load("cpux"); + auto orig_impl = PersistentCache::set_impl( + std::make_shared()); + #if MGB_ENABLE_FASTRUN for (auto strategy : SmallVector{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE, - S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) { + S::PROFILE | S::HEURISTIC}) { #else for (auto strategy : SmallVector{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) { #endif + auto graph = ComputingGraph::make(); HostTensorGenerator<> gen; auto mkvar = [&](const char* name, const TensorShape& shp, const DType& dtype) { return opr::TypeCvt::make( - opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name), + opr::Host2DeviceCopy::make(*graph, gen(shp), cn) + .rename(name), dtype); }; @@ -388,7 +392,11 @@ TEST(TestOprDNN, ConvBiasExePolicy) { HostTensorND host_y; auto func = graph->compile({make_callback_copy(conv_bias, host_y)}); func->execute(); + + //! set a new cache + PersistentCache::set_impl(std::make_shared()); } + PersistentCache::set_impl(orig_impl); } TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { @@ -401,19 +409,21 @@ TEST(TestOprDNN, ConvBiasExePolicy_Quantized8Asym) { for (auto strategy : SmallVector{S::PROFILE, S::PROFILE | S::REPRODUCIBLE}) { - auto graph = ComputingGraph::make(); HostTensorGenerator<> gen; auto mkvar = [&](const char* name, const TensorShape& shp, const DType& dtype) { return opr::TypeCvt::make( - opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name), + opr::Host2DeviceCopy::make(*graph, gen(shp), cn) + .rename(name), dtype); }; - auto x = mkvar("x", {20, 50, 50, 16}, dtype::Quantized8Asymm(2.5f, static_cast(0))); - auto w = mkvar("w", {24, 3, 3, 16}, dtype::Quantized8Asymm(2.5f, static_cast(0))); + auto x = mkvar("x", {20, 50, 50, 16}, + dtype::Quantized8Asymm(2.5f, static_cast(0))); + auto w = mkvar("w", {24, 3, 3, 16}, + dtype::Quantized8Asymm(2.5f, static_cast(0))); auto bias = mkvar("bias", {1, 1, 1, 24}, dtype::QuantizedS32(6.25f)); param.nonlineMode = Param::NonlineMode::RELU; -- GitLab