# Review notes (this preamble precedes the first "diff --git" header; patch tools skip it).
# - Purpose of the patch: for graphs with imperative_proxy_graph set, the workspace limit is
#   obtained from a per-graph WorkspaceLimitHook callback instead of returning old_limit, and
#   the profile-cache lookup in algo_chooser.cpp only accepts cached algorithms whose recorded
#   workspace fits that limit.
# - NOTE(review): the patch text was flattened; line breaks, indentation, blank context lines
#   and the template arguments inside <...> are reconstructed here -- assumption, confirm the
#   context lines against the target tree before applying.
# - Changes relative to the flattened patch: get_workspace_limit() binds the hook callback by
#   const reference instead of copying it; set_impl(graph, impl) moves its by-value argument
#   into the hook; explanatory comments were added and the hunk line counts updated.
# - The "index" blob ids are carried over unchanged and no longer describe the post-image.
diff --git a/src/opr/impl/internal/megdnn_opr_wrapper.cpp b/src/opr/impl/internal/megdnn_opr_wrapper.cpp
index 792701549d1195c8f336b40f2632193729ed11d6..ded6f015201cb74ed78f908988e993a6d7ac590f 100644
--- a/src/opr/impl/internal/megdnn_opr_wrapper.cpp
+++ b/src/opr/impl/internal/megdnn_opr_wrapper.cpp
@@ -254,7 +254,9 @@ WorkspaceLimitGetter::Impl* WorkspaceLimitGetter::get_impl(ComputingGraph* graph
 size_t WorkspaceLimitGetter::get_workspace_limit(
         ComputingGraph* graph, CompNode cn, size_t old_limit) {
     if (graph->options().imperative_proxy_graph) {
-        return old_limit;
+        // get_impl() returns a const reference; bind it instead of copying
+        const auto& impl = WorkspaceLimitHook::get_impl(graph);
+        return impl(cn, old_limit);
     }
     if (!graph->options().seq_opt.enable_mem_reuse_alloc)
         return old_limit;
@@ -419,4 +421,58 @@ void MegDNNOprHolderBwdStaticInfer::mixin_update_node_prop(
     }
 }
 
+/* ================== WorkspaceLimitHook ================== */
+MGB_TYPEINFO_OBJ_IMPL(WorkspaceLimitHook);
+
+#if MGB_BUILD_SLIM_SERVING && !MGB_CUDA
+// slim-serving build without CUDA: every entry point just asserts false
+void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) {
+    mgb_assert(false);
+}
+
+const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
+        const {
+    mgb_assert(false);
+}
+
+void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */,
+                                  GetWorkspaceLimitImpl /* impl */) {
+    mgb_assert(false);
+}
+
+const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
+        ComputingGraph* /* graph */) {
+    mgb_assert(false);
+}
+#else
+void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) {
+    m_impl = std::move(impl);
+}
+
+const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
+        const {
+    return m_impl;
+}
+
+void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
+                                  GetWorkspaceLimitImpl impl) {
+    mgb_assert(graph->options().imperative_proxy_graph);
+    // maker builds the hook object handed to get_user_data_or_create()
+    auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); };
+    graph->options()
+            .user_data.get_user_data_or_create<WorkspaceLimitHook>(maker)
+            ->set_impl(std::move(impl));
+}
+
+const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
+        ComputingGraph* graph) {
+    mgb_assert(graph->options().imperative_proxy_graph);
+    auto container =
+            graph->options().user_data.get_user_data<WorkspaceLimitHook>();
+    // expect exactly one hook on the graph (second is presumably the count)
+    mgb_assert(container.second == 1);
+    return container.first[0]->get_impl();
+}
+#endif
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp
index 03e39fd8fe695bacd73641cb842e811441de3e95..49b1b71c315f0f0acbda85539fe03cf079d2d5b1 100644
--- a/src/opr/impl/search_policy/algo_chooser.cpp
+++ b/src/opr/impl/search_policy/algo_chooser.cpp
@@ -621,8 +621,11 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
     if (prof.empty())
         return {{}, rst};
 
+    size_t workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
+            owner_graph(), m_cn, m_execution_policy.workspace_limit);
     auto target_attr = extract_algo_attribute(selected_strategy);
     bool skip_by_negative = false;
+    bool skip_by_workspace = false;
     for (auto&& i : prof) {
         auto attr_of_algo = static_cast<megdnn::Algorithm::Attribute>(i.attribute);
         bool contain_attr_all_positive =
@@ -631,13 +634,20 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
                 static_cast<bool>(attr_of_algo & target_attr.second);
         if (contain_attr_all_positive) {
             if (!contain_attr_any_negative) {
-                Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
-                return {algo_desc, rst};
+                // accept a cached algo only if its workspace fits the limit
+                if (i.workspace <= workspace_limit) {
+                    Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
+                    return {algo_desc, rst};
+                }
+                skip_by_workspace = true;
             } else {
                 skip_by_negative = true;
             }
         }
     }
+    // every otherwise-acceptable cached algo exceeded the limit: empty result
+    if (skip_by_workspace)
+        return {};
 
     std::string layouts_str =
             format_fixlayouts(m_fastrun_layouts, arity_in, arity_out);
diff --git a/src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h b/src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h
index cfb8060480070d48f33b163579dc77634960a724..44a2eb11c56787b7fe7a1bfd5369f6ceee1e77e9 100644
--- a/src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h
+++ b/src/opr/include/megbrain/opr/internal/megdnn_opr_wrapper.h
@@ -316,6 +316,32 @@ protected:
     typename Super::NodeProp* do_make_node_prop() const override;
 };
 
+/*!
+ * \brief per-graph hook that supplies the workspace limit for imperative
+ *      proxy graphs
+ *
+ * The hook object is kept in the graph's user data; the static accessors
+ * assert that the graph has imperative_proxy_graph set.
+ */
+class WorkspaceLimitHook final : public UserDataContainer::UserData {
+    MGB_TYPEINFO_OBJ_DECL;
+
+public:
+    //! callback invoked with (comp node, old limit); returns the limit to use
+    using GetWorkspaceLimitImpl = thin_function<size_t(CompNode, size_t)>;
+    WorkspaceLimitHook() = default;
+    ~WorkspaceLimitHook() = default;
+    //! store \p impl on the hook object attached to \p graph
+    static void set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl);
+    //! callback held by the hook object attached to \p graph
+    static const GetWorkspaceLimitImpl& get_impl(ComputingGraph* graph);
+
+private:
+    void set_impl(GetWorkspaceLimitImpl impl);
+    const GetWorkspaceLimitImpl& get_impl() const;
+    GetWorkspaceLimitImpl m_impl;
+};
+
 }  // namespace intl
 }  // namespace opr
 }  // namespace mgb