From a09a2b730dc280fd3860032fa673c89b82a87657 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 30 Aug 2021 18:32:30 +0800 Subject: [PATCH] fix(mgb/opr): fix fastrun workspace limit for imperative rt GitOrigin-RevId: bd69a82d4c2d8a36899bd3254dcef161065c2ac8 --- src/opr/impl/internal/megdnn_opr_wrapper.cpp | 54 ++++++++++++++++++- src/opr/impl/search_policy/algo_chooser.cpp | 12 ++++- .../opr/internal/megdnn_opr_wrapper.h | 16 ++++++ 3 files changed, 79 insertions(+), 3 deletions(-) diff --git a/src/opr/impl/internal/megdnn_opr_wrapper.cpp b/src/opr/impl/internal/megdnn_opr_wrapper.cpp index 792701549..ded6f0152 100644 --- a/src/opr/impl/internal/megdnn_opr_wrapper.cpp +++ b/src/opr/impl/internal/megdnn_opr_wrapper.cpp @@ -254,7 +254,8 @@ WorkspaceLimitGetter::Impl* WorkspaceLimitGetter::get_impl(ComputingGraph* graph size_t WorkspaceLimitGetter::get_workspace_limit( ComputingGraph* graph, CompNode cn, size_t old_limit) { if (graph->options().imperative_proxy_graph) { - return old_limit; + auto impl = WorkspaceLimitHook::get_impl(graph); + return impl(cn, old_limit); } if (!graph->options().seq_opt.enable_mem_reuse_alloc) return old_limit; @@ -419,4 +420,55 @@ void MegDNNOprHolderBwdStaticInfer::mixin_update_node_prop( } } +/* ================== WorkspaceLimitHook ================== */ +MGB_TYPEINFO_OBJ_IMPL(WorkspaceLimitHook); + +#if MGB_BUILD_SLIM_SERVING && !MGB_CUDA +void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) { + mgb_assert(false); +} + +const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() + const { + mgb_assert(false); +} + +void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */, + GetWorkspaceLimitImpl /* impl */) { + mgb_assert(false); +} + +const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl( + ComputingGraph* /* graph */) { + mgb_assert(false); +} +#else +void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) { + m_impl = std::move(impl); +} + +const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() + const { + return m_impl; +} + +void WorkspaceLimitHook::set_impl(ComputingGraph* graph, + GetWorkspaceLimitImpl impl) { + mgb_assert(graph->options().imperative_proxy_graph); + auto maker = []() { return std::make_shared(); }; + graph->options() + .user_data.get_user_data_or_create(maker) + ->set_impl(impl); +} + +const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl( + ComputingGraph* graph) { + mgb_assert(graph->options().imperative_proxy_graph); + auto container = + graph->options().user_data.get_user_data(); + mgb_assert(container.second == 1); + return container.first[0]->get_impl(); +} +#endif + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index 03e39fd8f..49b1b71c3 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -621,8 +621,11 @@ AlgoChooser::AlgoChooserHelper::get_profile_result_from_cache( if (prof.empty()) return {{}, rst}; + size_t workspace_limit = WorkspaceLimitGetter::get_workspace_limit( + owner_graph(), m_cn, m_execution_policy.workspace_limit); auto target_attr = extract_algo_attribute(selected_strategy); bool skip_by_negative = false; + bool skip_by_workspace = false; for (auto&& i : prof) { auto attr_of_algo = static_cast(i.attribute); bool contain_attr_all_positive = @@ -631,13 +634,18 @@ AlgoChooser::AlgoChooserHelper::get_profile_result_from_cache( static_cast(attr_of_algo & target_attr.second); if (contain_attr_all_positive) { if (!contain_attr_any_negative) { - Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo); - return {algo_desc, rst}; + if (i.workspace <= workspace_limit) { + Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo); + return {algo_desc, rst}; + } + skip_by_workspace = true; } else { skip_by_negative = true; } } } + if (skip_by_workspace) + return {}; std::string layouts_str = format_fixlayouts(m_fastrun_layouts, arity_in, arity_out); diff --git a/src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h b/src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h index cfb806048..44a2eb11c 100644 --- a/src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h +++ b/src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h @@ -316,6 +316,22 @@ protected: typename Super::NodeProp* do_make_node_prop() const override; }; +class WorkspaceLimitHook final : public UserDataContainer::UserData { + MGB_TYPEINFO_OBJ_DECL; + +public: + using GetWorkspaceLimitImpl = thin_function; + WorkspaceLimitHook() = default; + ~WorkspaceLimitHook() = default; + static void set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl); + static const GetWorkspaceLimitImpl& get_impl(ComputingGraph* graph); + +private: + void set_impl(GetWorkspaceLimitImpl impl); + const GetWorkspaceLimitImpl& get_impl() const; + GetWorkspaceLimitImpl m_impl; +}; + } // namespace intl } // namespace opr } // namespace mgb -- GitLab