// Patch summary (leading text before the first `diff --git` header).
//
// This is a git diff touching four files:
//   * dnn/include/megdnn/common.h
//       - adds an include of "megdnn/oprs/base.h";
//       - adds the template helper has_no_naive_heuristic_algo(Opr*, Args&&...),
//         which calls opr->get_algorithm_info_heuristic(args...) and returns
//         true when the NAIVE attribute bit is NOT set on the returned algo.
//   * src/gopt/impl/global_layout_transform/opr_format_modifier.cpp
//       - adds a static member has_no_naive_heuristic_algo inside the macro
//         body that precedes INST(Convolution). It fills array_layouts from
//         the VarNodeArray: entries [0, i.size() - 1) go to the same index,
//         the last entry goes to slot arity - 1, and the result is forwarded
//         to ::megdnn::has_no_naive_heuristic_algo through APPLY;
//       - adds a free function has_no_naive_heuristic_algo that copies the
//         input array, appends opr->output(0), and dispatches on
//         opr->dyn_typeinfo() for Convolution, ConvBiasForward,
//         ConvolutionBackwardData and PoolingForward; any other type raises
//         InternalError via mgb_throw.
//   * src/gopt/impl/global_layout_transform/opr_format_modifier.h
//       - declares the free function above.
//   * src/gopt/impl/global_layout_transform/profiler_impl.cpp
//       - ProfilerImpl::profile_operator now returns PROFILE_TIME_OUT for a
//         multi-algo operator when either has_available_algo(...) is false or
//         has_no_naive_heuristic_algo(...) is false.
//
// NOTE(review): the text below looks like it lost content during extraction:
// line breaks are collapsed, and `template`, `std::forward(args)`,
// `static_cast(...)` and `reinterpret_cast(...)` appear without their
// angle-bracket argument lists. Presumably the original commit has them;
// restore from the upstream commit before trying to apply this patch.
// NOTE(review): the new dispatcher marks `MultiAlgoOprTrait<_Opr>::has_algo`
// as used rather than the newly added member; this looks copied from
// has_available_algo. Confirm whether that is intended.
diff --git a/dnn/include/megdnn/common.h b/dnn/include/megdnn/common.h index f0073c59b84a9e6fcc6b450c56a2053596ab39c0..09cb70a0886297bf3d66a68388c5a035048e2544 100644 --- a/dnn/include/megdnn/common.h +++ b/dnn/include/megdnn/common.h @@ -12,6 +12,7 @@ #pragma once #include "megbrain_build_config.h" +#include "megdnn/oprs/base.h" #if MGB_ENABLE_GETENV #define MGB_GETENV ::std::getenv @@ -36,6 +37,11 @@ bool has_available_algo(Opr* opr, Args&&... args) { return !all_algos.empty(); } +template +bool has_no_naive_heuristic_algo(Opr* opr, Args&&... args) { + auto&& algo = opr->get_algorithm_info_heuristic(std::forward(args)...); + return !static_cast(algo.attribute & detail::Algorithm::Attribute::NAIVE); +} } // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp index d783a1f52a4f2da490d3a003a6b85fbd538e4a79..84d82e4505c5a496b57e5d172e8fbf878e32deef 100644 --- a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp +++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp @@ -220,6 +220,28 @@ struct MultiAlgoOprTrait; ::megdnn::has_available_algo(megdnn_opr, args...), array_layouts); \ MIDOUT_E \ } \ + static bool has_no_naive_heuristic_algo( \ + const VarNodeArray& i, const cg::OperatorNodeBase* opr_) { \ + MIDOUT_B( \ + midout_iv(MGB_HASH_STR(#_Opr)), \ + midout_iv(MGB_HASH_STR("has_no_naive_heuristic_algo"))) \ + auto&& opr = opr_->cast_final_safe<_Opr>(); \ + auto&& megdnn_opr = reinterpret_cast(opr.megdnn_opr()); \ + FixedTensorLayouts array_layouts; \ + size_t in = i.size() - 1; \ + for (size_t idx = 0; idx < in; idx++) { \ + const auto& v = i[idx]; \ + array_layouts[idx] = \ + TensorLayout{v->shape(), v->dtype(), v->format()}; \ + } \ + const auto& v = i[in]; \ + array_layouts[arity - 1] = \ + TensorLayout{v->shape(), v->dtype(), v->format()}; \ + return APPLY( \ 
::megdnn::has_no_naive_heuristic_algo(megdnn_opr, args...), \ + array_layouts); \ + MIDOUT_E \ + } \ }; INST(Convolution) INST(ConvBiasForward) @@ -365,6 +387,23 @@ bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr) #undef cb } +bool has_no_naive_heuristic_algo( + const VarNodeArray& i, const cg::OperatorNodeBase* opr) { +#define cb(_Opr) \ + if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \ + MGB_MARK_USED_VAR(MultiAlgoOprTrait<_Opr>::has_algo); \ + VarNodeArray _ = i; \ + _.emplace_back(opr->output(0)); \ + return MultiAlgoOprTrait<_Opr>::has_no_naive_heuristic_algo(_, opr); \ + } else + cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) { + mgb_throw( + InternalError, "invalid multi-algo operator(got:%s)", + opr->dyn_typeinfo()->name); + } +#undef cb +} + bool has_opr_format(const cg::OperatorNodeBase* opr) { bool ret = false; #define cb(_Opr) ret |= opr->dyn_typeinfo() == _Opr::typeinfo(); diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.h b/src/gopt/impl/global_layout_transform/opr_format_modifier.h index 2ab6697aebbe0a93453554c8e22bb6f38206efb6..77b1d2927c4df5337e404740cad5d0227dd3e1b9 100644 --- a/src/gopt/impl/global_layout_transform/opr_format_modifier.h +++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.h @@ -27,6 +27,9 @@ namespace intl { bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr); +bool has_no_naive_heuristic_algo( + const VarNodeArray& i, const cg::OperatorNodeBase* opr); + struct OprFormatInfo { opr::Convolution::Param::Format opr_format; struct TensorFormatsInfo { diff --git a/src/gopt/impl/global_layout_transform/profiler_impl.cpp b/src/gopt/impl/global_layout_transform/profiler_impl.cpp index 48d7188a31b4fa1b20b2cf68e8fee6e90a8458a2..e303a43f72d720139eaa7a188759186bc2ee4ad5 100644 --- a/src/gopt/impl/global_layout_transform/profiler_impl.cpp +++ b/src/gopt/impl/global_layout_transform/profiler_impl.cpp @@ -331,7 +331,8 
@@ float ProfilerImpl::profile_operator( opr::PoolingForward::typeinfo(), }; if (multi_algo_oprs.count(opr->dyn_typeinfo()) && - !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr())) + (!mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr()) || + !mgb::gopt::intl::has_no_naive_heuristic_algo(new_inps, y->owner_opr()))) return PROFILE_TIME_OUT; if (!m_opr_filter(opr, y->owner_opr())) return PROFILE_TIME_OUT;