From 93310c0e4b34c371059d401f5e51714368f5cb2c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 27 Dec 2021 14:20:34 +0800 Subject: [PATCH] fix(mgb/gopt): fix cpu global layout transform fastrun error GitOrigin-RevId: ea254297e5e46134490c9aa9f1e98861493ba015 --- dnn/include/megdnn/common.h | 6 +++ .../opr_format_modifier.cpp | 39 +++++++++++++++++++ .../opr_format_modifier.h | 3 ++ .../global_layout_transform/profiler_impl.cpp | 3 +- 4 files changed, 50 insertions(+), 1 deletion(-) diff --git a/dnn/include/megdnn/common.h b/dnn/include/megdnn/common.h index f0073c59b..09cb70a08 100644 --- a/dnn/include/megdnn/common.h +++ b/dnn/include/megdnn/common.h @@ -12,6 +12,7 @@ #pragma once #include "megbrain_build_config.h" +#include "megdnn/oprs/base.h" #if MGB_ENABLE_GETENV #define MGB_GETENV ::std::getenv @@ -36,6 +37,11 @@ bool has_available_algo(Opr* opr, Args&&... args) { return !all_algos.empty(); } +template +bool has_no_naive_heuristic_algo(Opr* opr, Args&&... args) { + auto&& algo = opr->get_algorithm_info_heuristic(std::forward(args)...); + return !static_cast(algo.attribute & detail::Algorithm::Attribute::NAIVE); +} } // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp index d783a1f52..84d82e450 100644 --- a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp +++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp @@ -220,6 +220,28 @@ struct MultiAlgoOprTrait; ::megdnn::has_available_algo(megdnn_opr, args...), array_layouts); \ MIDOUT_E \ } \ + static bool has_no_naive_heuristic_algo( \ + const VarNodeArray& i, const cg::OperatorNodeBase* opr_) { \ + MIDOUT_B( \ + midout_iv(MGB_HASH_STR(#_Opr)), \ + midout_iv(MGB_HASH_STR("has_no_naive_heuristic_algo"))) \ + auto&& opr = opr_->cast_final_safe<_Opr>(); \ + auto&& megdnn_opr = reinterpret_cast(opr.megdnn_opr()); \ + FixedTensorLayouts array_layouts; \ + size_t in = i.size() - 1; \ + for (size_t idx = 0; idx < in; idx++) { \ + const auto& v = i[idx]; \ + array_layouts[idx] = \ + TensorLayout{v->shape(), v->dtype(), v->format()}; \ + } \ + const auto& v = i[in]; \ + array_layouts[arity - 1] = \ + TensorLayout{v->shape(), v->dtype(), v->format()}; \ + return APPLY( \ + ::megdnn::has_no_naive_heuristic_algo(megdnn_opr, args...), \ + array_layouts); \ + MIDOUT_E \ + } \ }; INST(Convolution) INST(ConvBiasForward) @@ -365,6 +387,23 @@ bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr) #undef cb } +bool has_no_naive_heuristic_algo( + const VarNodeArray& i, const cg::OperatorNodeBase* opr) { +#define cb(_Opr) \ + if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \ + MGB_MARK_USED_VAR(MultiAlgoOprTrait<_Opr>::has_algo); \ + VarNodeArray _ = i; \ + _.emplace_back(opr->output(0)); \ + return MultiAlgoOprTrait<_Opr>::has_no_naive_heuristic_algo(_, opr); \ + } else + cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) { + mgb_throw( + InternalError, "invalid multi-algo operator(got:%s)", + opr->dyn_typeinfo()->name); + } +#undef cb +} + bool has_opr_format(const cg::OperatorNodeBase* opr) { bool ret = false; #define cb(_Opr) ret |= opr->dyn_typeinfo() == _Opr::typeinfo(); diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.h b/src/gopt/impl/global_layout_transform/opr_format_modifier.h index 2ab6697ae..77b1d2927 100644 --- a/src/gopt/impl/global_layout_transform/opr_format_modifier.h +++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.h @@ -27,6 +27,9 @@ namespace intl { bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr); +bool has_no_naive_heuristic_algo( + const VarNodeArray& i, const cg::OperatorNodeBase* opr); + struct OprFormatInfo { opr::Convolution::Param::Format opr_format; struct TensorFormatsInfo { diff --git a/src/gopt/impl/global_layout_transform/profiler_impl.cpp b/src/gopt/impl/global_layout_transform/profiler_impl.cpp index 48d7188a3..e303a43f7 100644 --- a/src/gopt/impl/global_layout_transform/profiler_impl.cpp +++ b/src/gopt/impl/global_layout_transform/profiler_impl.cpp @@ -331,7 +331,8 @@ float ProfilerImpl::profile_operator( opr::PoolingForward::typeinfo(), }; if (multi_algo_oprs.count(opr->dyn_typeinfo()) && - !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr())) + (!mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr()) || + !mgb::gopt::intl::has_no_naive_heuristic_algo(new_inps, y->owner_opr()))) return PROFILE_TIME_OUT; if (!m_opr_filter(opr, y->owner_opr())) return PROFILE_TIME_OUT; -- GitLab