From 88898e63a534c8710ba57453a001e897ecdd92ea Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 5 Jul 2021 15:54:02 +0800 Subject: [PATCH] fix(mgb): replace if_constexpr with runtime function to avoid potential bug GitOrigin-RevId: 27fe093d506ab9536fe281ec45b061de5b0407e7 --- src/opr/impl/search_policy/profiler.cpp | 90 ++++++++++++------- .../megbrain/opr/search_policy/profiler.h | 9 ++ 2 files changed, 69 insertions(+), 30 deletions(-) diff --git a/src/opr/impl/search_policy/profiler.cpp b/src/opr/impl/search_policy/profiler.cpp index f3befdc2f..d2272abe4 100644 --- a/src/opr/impl/search_policy/profiler.cpp +++ b/src/opr/impl/search_policy/profiler.cpp @@ -154,6 +154,64 @@ double TimedProfiler::init_timeout_setting() { mgb::apply([&](const auto&... args) { return statement; }, \ std::tuple_cat(__VA_ARGS__)) +template +void TimedProfiler::preprocess(const TensorLayoutArray&, + const megdnn::SmallVector&, + intl::UniqPtrWithCN&, + megdnn::Workspace&, + std::array&, + std::array&, + PreprocessFilter&) { + // Opr is neither convbias nor convolution.This function do nothing. +} + +//! convbias +template <> +void TimedProfiler::preprocess( + const TensorLayoutArray& preprocessed_layout, + const SmallVector& flt_val, + intl::UniqPtrWithCN& megdnn_opr, + megdnn::Workspace& mdn_workspace, + std::array& layouts, + std::array& inp_val, + PreprocessFilter& prep_flt) { + if (!preprocessed_layout.empty()) { + auto&& pf = prep_flt; + pf.algorithm_id = nullptr; + pf.tensors.resize(flt_val.size()); + for (size_t i = 0; i < flt_val.size(); i++) { + pf.tensors[i] = flt_val[i].as_megdnn(); + } + APPLY(megdnn_opr->exec_preprocess(args..., &pf, mdn_workspace), + std::forward_as_tuple(layouts[0], inp_val[1].as_megdnn(), + inp_val[2].as_megdnn()), + array_skip(layouts)); + } +} + +//! convolution +template <> +void TimedProfiler::preprocess( + const TensorLayoutArray& preprocessed_layout, + const megdnn::SmallVector& flt_val, + intl::UniqPtrWithCN& megdnn_opr, + megdnn::Workspace& mdn_workspace, + std::array& layouts, + std::array& inp_val, + PreprocessFilter& prep_flt) { + if (!preprocessed_layout.empty()) { + auto&& pf = prep_flt; + pf.algorithm_id = nullptr; + pf.tensors.resize(flt_val.size()); + for (size_t i = 0; i < flt_val.size(); i++) { + pf.tensors[i] = flt_val[i].as_megdnn(); + } + APPLY(megdnn_opr->exec_preprocess(args..., &pf, mdn_workspace), + std::forward_as_tuple(layouts[0], inp_val[1].as_megdnn()), + array_skip<2>(layouts)); + } +} + template typename TimedProfiler::TResult TimedProfiler::prof_impl( const TParam& raw_param) { @@ -258,36 +316,8 @@ typename TimedProfiler::TResult TimedProfiler::prof_impl( } PreprocessFilter prep_flt; - if_constexpr()>([&](auto _) { - if (!preprocessed_layout.empty()) { - auto&& pf = _(prep_flt); - pf.algorithm_id = nullptr; - pf.tensors.resize(flt_val.size()); - for (size_t i = 0; i < flt_val.size(); i++) { - pf.tensors[i] = flt_val[i].as_megdnn(); - } - if_constexpr()>( - //! convbias - [&](auto __) { - APPLY(__(megdnn_opr) - ->exec_preprocess(args..., &pf, - mdn_workspace), - std::forward_as_tuple(layouts[0], - inp_val[1].as_megdnn(), - inp_val[2].as_megdnn()), - array_skip(layouts)); - }, - //! Convolution - [&](auto __) { - APPLY(__(megdnn_opr) - ->exec_preprocess(args..., &pf, - mdn_workspace), - std::forward_as_tuple(layouts[0], - inp_val[1].as_megdnn()), - array_skip(layouts)); - }); - } - }); + preprocess(preprocessed_layout, flt_val, megdnn_opr, mdn_workspace, layouts, + inp_val, prep_flt); RealTimer timer; auto ev_start = cn.create_event(CompNode::Event::NEED_TIMER), diff --git a/src/opr/include/megbrain/opr/search_policy/profiler.h b/src/opr/include/megbrain/opr/search_policy/profiler.h index 062e58671..d1bf34aa3 100644 --- a/src/opr/include/megbrain/opr/search_policy/profiler.h +++ b/src/opr/include/megbrain/opr/search_policy/profiler.h @@ -16,6 +16,8 @@ #include "megbrain/utils/timer.h" #include "megbrain/system.h" #include "megbrain/comp_node.h" +#include "megbrain/tensor.h" +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megdnn/basic_types.h" #include "megdnn/oprs.h" @@ -149,6 +151,13 @@ private: static const double timeout_setting; static double init_timeout_setting(); + static void preprocess(const megdnn::TensorLayoutArray& preprocessed_layout, + const SmallVector& flt_val, + intl::UniqPtrWithCN& megdnn_opr, + megdnn::Workspace& mdn_workspace, + std::array& layouts, + std::array& inp_val, + PreprocessFilter& prep_flt); static TResult prof_impl(const TParam& raw_param); static void prof_init_device(const TParam& raw_param); }; -- GitLab