提交 a09a2b73 编写于 作者: M Megvii Engine Team

fix(mgb/opr): fix fastrun workspace limit for imperative rt

GitOrigin-RevId: bd69a82d4c2d8a36899bd3254dcef161065c2ac8
上级 ac86d644
......@@ -254,7 +254,8 @@ WorkspaceLimitGetter::Impl* WorkspaceLimitGetter::get_impl(ComputingGraph* graph
size_t WorkspaceLimitGetter::get_workspace_limit(
ComputingGraph* graph, CompNode cn, size_t old_limit) {
if (graph->options().imperative_proxy_graph) {
return old_limit;
auto impl = WorkspaceLimitHook::get_impl(graph);
return impl(cn, old_limit);
}
if (!graph->options().seq_opt.enable_mem_reuse_alloc)
return old_limit;
......@@ -419,4 +420,55 @@ void MegDNNOprHolderBwdStaticInfer::mixin_update_node_prop(
}
}
/* ================== WorkspaceLimitHook ================== */
MGB_TYPEINFO_OBJ_IMPL(WorkspaceLimitHook);
// In slim-serving builds without CUDA the imperative runtime never installs a
// workspace-limit hook, so every entry point is stubbed out with an
// unconditional assertion failure.
#if MGB_BUILD_SLIM_SERVING && !MGB_CUDA
// Stub: installing a workspace-limit getter is unsupported in this build.
void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl /* impl */) {
    mgb_assert(false);
}
// Stub: always asserts.
// NOTE(review): this non-void function has no return statement after the
// assert -- presumably mgb_assert(false) is noreturn (aborts/throws); confirm,
// otherwise control falls off the end (UB).
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
        const {
    mgb_assert(false);
}
// Stub: the static per-graph setter is likewise unsupported in this build.
void WorkspaceLimitHook::set_impl(ComputingGraph* /* graph */,
                                  GetWorkspaceLimitImpl /* impl */) {
    mgb_assert(false);
}
// Stub: always asserts; same noreturn caveat as the const overload above.
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
        ComputingGraph* /* graph */) {
    mgb_assert(false);
}
#else
// Store the workspace-limit getter on this hook instance (moves the functor).
void WorkspaceLimitHook::set_impl(GetWorkspaceLimitImpl impl) {
    m_impl = std::move(impl);
}
// Return the getter previously installed via set_impl(); callers (the static
// get_impl(ComputingGraph*) below) are expected to have set it first.
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl()
        const {
    return m_impl;
}
// Install the workspace-limit getter for `graph`, creating the hook object in
// the graph's user-data container on first use.
//
// Only valid on graphs with options().imperative_proxy_graph set (asserted):
// static graphs derive their workspace limit elsewhere.
void WorkspaceLimitHook::set_impl(ComputingGraph* graph,
                                  GetWorkspaceLimitImpl impl) {
    mgb_assert(graph->options().imperative_proxy_graph);
    auto maker = []() { return std::make_shared<WorkspaceLimitHook>(); };
    graph->options()
            .user_data.get_user_data_or_create<WorkspaceLimitHook>(maker)
            // last use of `impl`: move instead of copying the functor
            ->set_impl(std::move(impl));
}
// Fetch the workspace-limit getter previously installed on `graph` via
// set_impl(). Only valid on imperative proxy graphs (asserted), and asserts
// that exactly one hook has been registered in the graph's user data.
const WorkspaceLimitHook::GetWorkspaceLimitImpl& WorkspaceLimitHook::get_impl(
        ComputingGraph* graph) {
    mgb_assert(graph->options().imperative_proxy_graph);
    auto container =
            graph->options().user_data.get_user_data<WorkspaceLimitHook>();
    // exactly one hook instance must exist: set_impl() must precede get_impl()
    mgb_assert(container.second == 1);
    return container.first[0]->get_impl();
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -621,8 +621,11 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
if (prof.empty())
return {{}, rst};
size_t workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
auto target_attr = extract_algo_attribute(selected_strategy);
bool skip_by_negative = false;
bool skip_by_workspace = false;
for (auto&& i : prof) {
auto attr_of_algo = static_cast<megdnn::Algorithm::Attribute>(i.attribute);
bool contain_attr_all_positive =
......@@ -631,13 +634,18 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
static_cast<bool>(attr_of_algo & target_attr.second);
if (contain_attr_all_positive) {
if (!contain_attr_any_negative) {
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
return {algo_desc, rst};
if (i.workspace <= workspace_limit) {
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
return {algo_desc, rst};
}
skip_by_workspace = true;
} else {
skip_by_negative = true;
}
}
}
if (skip_by_workspace)
return {};
std::string layouts_str =
format_fixlayouts<Opr>(m_fastrun_layouts, arity_in, arity_out);
......
......@@ -316,6 +316,22 @@ protected:
typename Super::NodeProp* do_make_node_prop() const override;
};
// Per-graph hook letting the imperative runtime override how the fastrun
// workspace limit is computed. Stored in the graph's UserDataContainer; the
// static set_impl/get_impl pair asserts the graph is an imperative proxy
// graph before touching it.
class WorkspaceLimitHook final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;
public:
    // maps (comp node, previous limit) -> effective workspace limit in bytes
    using GetWorkspaceLimitImpl = thin_function<size_t(CompNode, size_t)>;
    WorkspaceLimitHook() = default;
    ~WorkspaceLimitHook() = default;
    // Install `impl` as the limit getter for `graph` (creates the hook lazily).
    static void set_impl(ComputingGraph* graph, GetWorkspaceLimitImpl impl);
    // Retrieve the getter installed on `graph`; asserts one was registered.
    static const GetWorkspaceLimitImpl& get_impl(ComputingGraph* graph);
private:
    // instance-level accessors used by the static entry points above
    void set_impl(GetWorkspaceLimitImpl impl);
    const GetWorkspaceLimitImpl& get_impl() const;
    GetWorkspaceLimitImpl m_impl;
};
} // namespace intl
} // namespace opr
} // namespace mgb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册