Commit d9a9d9d4 authored by Megvii Engine Team

fix(imperative/fastrun): set workspace limit for imperative rt

GitOrigin-RevId: 474dc691a3ec20b09eac3e7b6682f622f6e56774
Parent a09a2b73
@@ -15,6 +15,7 @@
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/utility.h"
@@ -509,6 +510,8 @@ SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
         const OpDef& opdef, const SmallVector<Tensor*>& inputs) {
     SmallVector<LogicalTensorDesc> ret;
     CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
+    ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
+            m_graph.get(), ProxyGraph::get_workspace_limit);
     do_shape_infer(true);
     for (auto&& i : m_cur_opr->usable_output()) {
         mgb_assert(i->dtype().valid() && i->comp_node().valid());
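Note on the hunk above: the hook has to be installed before `do_shape_infer(true)`, because inferring the shape of a megdnn operator's workspace output requires picking an algorithm, and fast-run roughly restricts itself to algorithms whose workspace fits under the limit returned by the hook. A minimal, self-contained sketch of that filtering idea (the `Algo`/`pick_algo` names are illustrative, not MegBrain API):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical candidate record; real fast-run keeps much richer metadata.
struct Algo {
    const char* name;
    size_t workspace_bytes;
};

// Return the first candidate whose workspace fits under `limit` bytes. A
// too-small limit silently excludes faster algorithms; a realistic one (as
// installed by this commit) keeps them eligible.
const Algo* pick_algo(const std::vector<Algo>& candidates, size_t limit) {
    for (auto&& a : candidates) {
        if (a.workspace_bytes <= limit) return &a;
    }
    return nullptr;  // nothing fits; the caller must fall back or fail
}
```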
@@ -547,6 +550,14 @@ void ProxyGraph::init_output_tensor(
     // get proxy opr
     auto proxy = m_cur_opr;
+    auto get_workspace_size = [=](CompNode cn, size_t old_limit) {
+        size_t limit = 0;
+        for (auto&& var : workspaces) {
+            limit += var->dtype().size(var->shape().total_nr_elems());
+        }
+        return limit;
+    };
+    ::mgb::opr::intl::WorkspaceLimitHook::set_impl(m_graph.get(), get_workspace_size);
     do_shape_infer(true);
     size_t j = 0;
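The `get_workspace_size` lambda registered above answers later workspace-limit queries with the number of bytes already reserved for this operator's workspace vars, i.e. the sum of `dtype().size(total_nr_elems())` over `workspaces`. A toy sketch of the same computation with the MegBrain calls replaced by plain sizes (all names here are stand-ins):

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Each entry stands in for one workspace var: {number of elements, bytes per element}.
size_t workspace_limit_bytes(const std::vector<std::pair<size_t, size_t>>& workspaces) {
    size_t limit = 0;
    for (auto&& ws : workspaces) {
        // corresponds to var->dtype().size(var->shape().total_nr_elems())
        limit += ws.first * ws.second;
    }
    return limit;  // e.g. a single float32 workspace of 262144 elems -> 1 MiB
}
```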
@@ -640,6 +651,8 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
         const SmallVector<MemoryDesc>& inputs_mems) {
     auto opr = get_proxy_opr(def, inputs_tensors);
     CUR_OPR_GUARD(opr);
+    ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
+            m_graph.get(), ProxyGraph::get_workspace_limit);
     do_shape_infer(true);
     SmallVector<MemoryDesc> outputs;
     SmallVector<MemoryDesc> workspaces;
@@ -27,6 +27,11 @@ public:
     static std::unique_ptr<MegBrainError> get_async_error() {
        return std::move(tm_async_error);
    }
+    static size_t get_workspace_limit(CompNode cn, size_t old_limit) {
+        size_t free = cn.get_free_mem();
+        size_t lmt = cn.get_max_block_size_available();
+        return std::max(lmt, free);
+    }
     /********************** Physical Tensor API **********************/
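`get_workspace_limit` is the callback that the proxy graph installs before shape inference (see the earlier hunks): it bounds the workspace by what the device can plausibly supply, taking the larger of the caching allocator's biggest available block and the free device memory (memory already held by the allocator is presumably not reported as free by the driver, so either value alone could under-estimate). A quick numeric illustration with made-up figures:

```cpp
#include <algorithm>
#include <cstddef>

size_t example_limit() {
    // Suppose the caching allocator still holds a 512 MiB free block while the
    // driver reports only 128 MiB of unreserved memory: the limit is 512 MiB.
    // If instead the allocator were empty but 2 GiB free, the limit would be 2 GiB.
    size_t max_block = size_t(512) << 20;  // stands in for get_max_block_size_available()
    size_t free_mem = size_t(128) << 20;   // stands in for get_free_mem()
    return std::max(max_block, free_mem);  // -> 512 MiB
}
```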
@@ -273,6 +273,13 @@ public:
         activate();
         return m_mem_alloc->get_max_block_size_available();
     }
+    size_t get_free_mem() override {
+        m_env.cuda_env().activate();
+        size_t tot, free;
+        MGB_CUDA_CHECK(cudaMemGetInfo(&free, &tot));
+        return free;
+    }
 #endif
     Locator locator() override { return m_locator; }
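`get_free_mem` simply asks the CUDA runtime how much device memory is still unallocated. A standalone sketch of the same query outside MegBrain (the helper name is made up; `MGB_CUDA_CHECK` is replaced by a plain error check):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Query free device memory in bytes for `device`, returning 0 on failure.
size_t query_free_device_memory(int device) {
    if (cudaSetDevice(device) != cudaSuccess) {  // plays the role of m_env.cuda_env().activate()
        std::fprintf(stderr, "cudaSetDevice(%d) failed\n", device);
        return 0;
    }
    size_t free_bytes = 0, total_bytes = 0;
    if (cudaMemGetInfo(&free_bytes, &total_bytes) != cudaSuccess) {
        std::fprintf(stderr, "cudaMemGetInfo failed\n");
        return 0;
    }
    return free_bytes;
}
```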
@@ -336,6 +336,8 @@ public:
     size_t get_max_block_size_available() const {
         return m_impl->get_max_block_size_available();
     }
+    size_t get_free_mem() const { return m_impl->get_free_mem(); }
 #endif
     //! change to another stream on the same memory node
@@ -519,6 +521,7 @@ protected:
     }
     virtual size_t get_used_memory() { return 0; }
     virtual size_t get_max_block_size_available() { return 0; }
+    virtual size_t get_free_mem() { return 0; }
 #endif
     virtual Locator locator() = 0;
@@ -428,13 +428,12 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) {
     mgb_assert(false);
 }
-const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
-        const {
+const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const {
     mgb_assert(false);
 }
-void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */,
-                                  GetWorkspaceLimitImpl /* impl */) {
+void WorkspaceLimitHook::set_impl(
+        ComputingGraph* /* graph */, GetWorkspaceLimitImpl /* impl */) {
     mgb_assert(false);
 }
@@ -447,13 +446,11 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) {
     m_impl = std::move(impl);
 }
-const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
-        const {
+const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const {
     return m_impl;
 }
-void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
-                                  GetWorkspaceLimitImpl impl) {
+void WorkspaceLimitHook::set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl) {
     mgb_assert(graph->options().imperative_proxy_graph);
     auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); };
     graph->options()
@@ -464,8 +461,7 @@ void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
 const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
         ComputingGraph* graph) {
     mgb_assert(graph->options().imperative_proxy_graph);
-    auto container =
-            graph->options().user_data.get_user_data<WorkspaceLimitHook>();
+    auto container = graph->options().user_data.get_user_data<WorkspaceLimitHook>();
     mgb_assert(container.second == 1);
     return container.first[0]->get_impl();
 }
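The non-stub `set_impl`/`get_impl` overloads above stash the callback in the graph's type-keyed `user_data` container: `set_impl` requires `imperative_proxy_graph` and creates a `WorkspaceLimitHook` on first use via `maker`, while `get_impl` later expects exactly one stored instance. A toy, self-contained illustration of that pattern (this is not the MegBrain `UserDataContainer` API, just the shape of the idea):

```cpp
#include <cassert>
#include <memory>
#include <typeindex>
#include <unordered_map>

// Minimal type-keyed store: at most one object per type, created on demand.
struct ToyUserData {
    std::unordered_map<std::type_index, std::shared_ptr<void>> store;

    template <typename T, typename Maker>
    T* get_user_data_or_create(Maker maker) {
        auto& slot = store[std::type_index(typeid(T))];
        if (!slot) slot = maker();  // first access: run the maker
        return static_cast<T*>(slot.get());
    }

    template <typename T>
    T* get_user_data() {
        auto it = store.find(std::type_index(typeid(T)));
        assert(it != store.end());  // mirrors mgb_assert(container.second == 1)
        return static_cast<T*>(it->second.get());
    }
};

// Usage mirroring set_impl/get_impl:
//   ToyUserData ud;
//   auto maker = [] { return std::make_shared<int>(0); };
//   *ud.get_user_data_or_create<int>(maker) = 42;  // set
//   int value = *ud.get_user_data<int>();          // get -> 42
```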