Commit d9a9d9d4 authored by: Megvii Engine Team

fix(imperative/fastrun): set workspace limit for imperative rt

GitOrigin-RevId: 474dc691a3ec20b09eac3e7b6682f622f6e56774
Parent commit: a09a2b73
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "megbrain/graph/static_infer.h" #include "megbrain/graph/static_infer.h"
#include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h" #include "megbrain/opr/utility.h"
...@@ -509,6 +510,8 @@ SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs( ...@@ -509,6 +510,8 @@ SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
const OpDef& opdef, const SmallVector<Tensor*>& inputs) { const OpDef& opdef, const SmallVector<Tensor*>& inputs) {
SmallVector<LogicalTensorDesc> ret; SmallVector<LogicalTensorDesc> ret;
CUR_OPR_GUARD(get_proxy_opr(opdef, inputs)); CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
::mgb::opr::intl::WorkspaceLimitHook::set_impl(
m_graph.get(), ProxyGraph::get_workspace_limit);
do_shape_infer(true); do_shape_infer(true);
for (auto&& i : m_cur_opr->usable_output()) { for (auto&& i : m_cur_opr->usable_output()) {
mgb_assert(i->dtype().valid() && i->comp_node().valid()); mgb_assert(i->dtype().valid() && i->comp_node().valid());
...@@ -547,6 +550,14 @@ void ProxyGraph::init_output_tensor( ...@@ -547,6 +550,14 @@ void ProxyGraph::init_output_tensor(
// get proxy opr // get proxy opr
auto proxy = m_cur_opr; auto proxy = m_cur_opr;
auto get_workspace_size = [=](CompNode cn, size_t old_limit) {
size_t limit = 0;
for (auto&& var : workspaces) {
limit += var->dtype().size(var->shape().total_nr_elems());
}
return limit;
};
::mgb::opr::intl::WorkspaceLimitHook::set_impl(m_graph.get(), get_workspace_size);
do_shape_infer(true); do_shape_infer(true);
size_t j = 0; size_t j = 0;
...@@ -640,6 +651,8 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph:: ...@@ -640,6 +651,8 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
const SmallVector<MemoryDesc>& inputs_mems) { const SmallVector<MemoryDesc>& inputs_mems) {
auto opr = get_proxy_opr(def, inputs_tensors); auto opr = get_proxy_opr(def, inputs_tensors);
CUR_OPR_GUARD(opr); CUR_OPR_GUARD(opr);
::mgb::opr::intl::WorkspaceLimitHook::set_impl(
m_graph.get(), ProxyGraph::get_workspace_limit);
do_shape_infer(true); do_shape_infer(true);
SmallVector<MemoryDesc> outputs; SmallVector<MemoryDesc> outputs;
SmallVector<MemoryDesc> workspaces; SmallVector<MemoryDesc> workspaces;
......
...@@ -27,6 +27,11 @@ public: ...@@ -27,6 +27,11 @@ public:
static std::unique_ptr<MegBrainError> get_async_error() { static std::unique_ptr<MegBrainError> get_async_error() {
return std::move(tm_async_error); return std::move(tm_async_error);
} }
// Workspace size limit for algo selection under the imperative runtime:
// the larger of the device memory currently free and the biggest single
// block the allocator can still hand out. The graph-level `old_limit`
// is deliberately not used here.
static size_t get_workspace_limit(CompNode cn, size_t old_limit) {
    const size_t largest_block = cn.get_max_block_size_available();
    const size_t free_bytes = cn.get_free_mem();
    return largest_block > free_bytes ? largest_block : free_bytes;
}
/********************** Physical Tensor API **********************/ /********************** Physical Tensor API **********************/
......
...@@ -273,6 +273,13 @@ public: ...@@ -273,6 +273,13 @@ public:
activate(); activate();
return m_mem_alloc->get_max_block_size_available(); return m_mem_alloc->get_max_block_size_available();
} }
// Ask the CUDA driver how much device memory is currently unallocated
// on this comp node (cudaMemGetInfo reports free first, then total).
size_t get_free_mem() override {
    m_env.cuda_env().activate();
    size_t free_bytes = 0, total_bytes = 0;
    MGB_CUDA_CHECK(cudaMemGetInfo(&free_bytes, &total_bytes));
    return free_bytes;
}
#endif #endif
Locator locator() override { return m_locator; } Locator locator() override { return m_locator; }
......
...@@ -336,6 +336,8 @@ public: ...@@ -336,6 +336,8 @@ public:
size_t get_max_block_size_available() const { size_t get_max_block_size_available() const {
return m_impl->get_max_block_size_available(); return m_impl->get_max_block_size_available();
} }
//! currently free device memory in bytes; forwards to the underlying impl
size_t get_free_mem() const { return m_impl->get_free_mem(); }
#endif #endif
//! change to another stream on the same memory node //! change to another stream on the same memory node
...@@ -519,6 +521,7 @@ protected: ...@@ -519,6 +521,7 @@ protected:
} }
virtual size_t get_used_memory() { return 0; } virtual size_t get_used_memory() { return 0; }
virtual size_t get_max_block_size_available() { return 0; } virtual size_t get_max_block_size_available() { return 0; }
//! free device memory in bytes; default returns 0 for backends that do
//! not report it (overridden e.g. by the CUDA comp node)
virtual size_t get_free_mem() { return 0; }
#endif #endif
virtual Locator locator() = 0; virtual Locator locator() = 0;
......
...@@ -428,13 +428,12 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) { ...@@ -428,13 +428,12 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) {
mgb_assert(false); mgb_assert(false);
} }
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const {
const {
mgb_assert(false); mgb_assert(false);
} }
void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */, void WorkspaceLimitHook::set_impl(
GetWorkspaceLimitImpl /* impl */) { ComputingGraph* /* graph */, GetWorkspaceLimitImpl /* impl */) {
mgb_assert(false); mgb_assert(false);
} }
...@@ -447,13 +446,11 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) { ...@@ -447,13 +446,11 @@ void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) {
m_impl = std::move(impl); m_impl = std::move(impl);
} }
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl() const {
const {
return m_impl; return m_impl;
} }
void WorkspaceLimitHook::set_impl(ComputingGraph* graph, void WorkspaceLimitHook::set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl) {
GetWorkspaceLimitImpl impl) {
mgb_assert(graph->options().imperative_proxy_graph); mgb_assert(graph->options().imperative_proxy_graph);
auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); }; auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); };
graph->options() graph->options()
...@@ -464,8 +461,7 @@ void WorkspaceLimitHook::set_impl(ComputingGraph* graph, ...@@ -464,8 +461,7 @@ void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl( const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
ComputingGraph* graph) { ComputingGraph* graph) {
mgb_assert(graph->options().imperative_proxy_graph); mgb_assert(graph->options().imperative_proxy_graph);
auto container = auto container = graph->options().user_data.get_user_data<WorkspaceLimitHook>();
graph->options().user_data.get_user_data<WorkspaceLimitHook>();
mgb_assert(container.second == 1); mgb_assert(container.second == 1);
return container.first[0]->get_impl(); return container.first[0]->get_impl();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register or sign in.