From d9a9d9d49eb91e33ca7149a6e57a664911f50a38 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 7 Sep 2021 18:06:27 +0800 Subject: [PATCH] fix(imperative/fastrun): set workspace limit for imperative rt GitOrigin-RevId: 474dc691a3ec20b09eac3e7b6682f622f6e56774 --- imperative/src/impl/proxy_graph.cpp | 13 +++++++++++++ imperative/src/impl/proxy_graph.h | 5 +++++ src/core/impl/comp_node/cuda/comp_node.cpp | 7 +++++++ src/core/include/megbrain/comp_node.h | 3 +++ src/opr/impl/internal/megdnn_opr_wrapper.cpp | 16 ++++++---------- 5 files changed, 34 insertions(+), 10 deletions(-) diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index e203524d3..402364e6c 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -15,6 +15,7 @@ #include "megbrain/graph/static_infer.h" #include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/opr_attr.h" +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/io.h" #include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/utility.h" @@ -509,6 +510,8 @@ SmallVector ProxyGraph::infer_output_attrs( const OpDef& opdef, const SmallVector& inputs) { SmallVector ret; CUR_OPR_GUARD(get_proxy_opr(opdef, inputs)); + ::mgb::opr::intl::WorkspaceLimitHook::set_impl( + m_graph.get(), ProxyGraph::get_workspace_limit); do_shape_infer(true); for (auto&& i : m_cur_opr->usable_output()) { mgb_assert(i->dtype().valid() && i->comp_node().valid()); @@ -547,6 +550,14 @@ void ProxyGraph::init_output_tensor( // get proxy opr auto proxy = m_cur_opr; + auto get_workspace_size = [=](CompNode cn, size_t old_limit) { + size_t limit = 0; + for (auto&& var : workspaces) { + limit += var->dtype().size(var->shape().total_nr_elems()); + } + return limit; + }; + ::mgb::opr::intl::WorkspaceLimitHook::set_impl(m_graph.get(), get_workspace_size); do_shape_infer(true); size_t j = 0; @@ -640,6 +651,8 @@ std::tuple, SmallVector> ProxyGraph:: const SmallVector& inputs_mems) { auto opr = get_proxy_opr(def, inputs_tensors); CUR_OPR_GUARD(opr); + ::mgb::opr::intl::WorkspaceLimitHook::set_impl( + m_graph.get(), ProxyGraph::get_workspace_limit); do_shape_infer(true); SmallVector outputs; SmallVector workspaces; diff --git a/imperative/src/impl/proxy_graph.h b/imperative/src/impl/proxy_graph.h index a403660bf..066f97bfd 100644 --- a/imperative/src/impl/proxy_graph.h +++ b/imperative/src/impl/proxy_graph.h @@ -27,6 +27,11 @@ public: static std::unique_ptr get_async_error() { return std::move(tm_async_error); } + static size_t get_workspace_limit(CompNode cn, size_t old_limit) { + size_t free = cn.get_free_mem(); + size_t lmt = cn.get_max_block_size_available(); + return std::max(lmt, free); + } /********************** Physical Tensor API **********************/ diff --git a/src/core/impl/comp_node/cuda/comp_node.cpp b/src/core/impl/comp_node/cuda/comp_node.cpp index e23514a36..c583e4fad 100644 --- a/src/core/impl/comp_node/cuda/comp_node.cpp +++ b/src/core/impl/comp_node/cuda/comp_node.cpp @@ -273,6 +273,13 @@ public: activate(); return m_mem_alloc->get_max_block_size_available(); } + + size_t get_free_mem() override { + m_env.cuda_env().activate(); + size_t tot, free; + MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot)); + return free; + } #endif Locator locator() override { return m_locator; } diff --git a/src/core/include/megbrain/comp_node.h b/src/core/include/megbrain/comp_node.h index 124206d28..3223d173e 100644 --- a/src/core/include/megbrain/comp_node.h +++ b/src/core/include/megbrain/comp_node.h @@ -336,6 +336,8 @@ public: size_t get_max_block_size_available() const { return m_impl->get_max_block_size_available(); } + + size_t get_free_mem() const { return m_impl->get_free_mem(); } #endif //! change to another stream on the same memory node @@ -519,6 +521,7 @@ protected: } virtual size_t get_used_memory() { return 0; } virtual size_t get_max_block_size_available() { return 0; } + virtual size_t get_free_mem() { return 0; } #endif virtual Locator locator() = 0; diff --git a/src/opr/impl/internal/megdnn_opr_wrapper.cpp b/src/opr/impl/internal/megdnn_opr_wrapper.cpp index ded6f0152..a39499a31 100644 --- a/src/opr/impl/internal/megdnn_opr_wrapper.cpp +++ b/src/opr/impl/internal/megdnn_opr_wrapper.cpp @@ -428,13 +428,12 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) { mgb_assert(false); } -const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() - const { +const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const { mgb_assert(false); } -void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */, - GetWorkspaceLimitImpl /* impl */) { +void WorkspaceLimitHook::set_impl( + ComputingGraph* /* graph */, GetWorkspaceLimitImpl /* impl */) { mgb_assert(false); } @@ -447,13 +446,11 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) { m_impl = std::move(impl); } -const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() - const { +const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const { return m_impl; } -void WorkspaceLimitHook::set_impl(ComputingGraph* graph, - GetWorkspaceLimitImpl impl) { +void WorkspaceLimitHook::set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl) { mgb_assert(graph->options().imperative_proxy_graph); auto maker = []() { return std::make_shared(); }; graph->options() @@ -464,8 +461,7 @@ void WorkspaceLimitHook::set_impl(ComputingGraph* graph, const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl( ComputingGraph* graph) { mgb_assert(graph->options().imperative_proxy_graph); - auto container = - graph->options().user_data.get_user_data(); + auto container = graph->options().user_data.get_user_data(); mgb_assert(container.second == 1); return container.first[0]->get_impl(); } -- GitLab