diff --git a/src/opr/impl/search_policy/profiler.cpp b/src/opr/impl/search_policy/profiler.cpp index f3befdc2f9391c3e6db89d6e285c61f4d135310d..d2272abe40306db02cf69ae3171846f331d2dd64 100644 --- a/src/opr/impl/search_policy/profiler.cpp +++ b/src/opr/impl/search_policy/profiler.cpp @@ -154,6 +154,64 @@ double TimedProfiler::init_timeout_setting() { mgb::apply([&](const auto&... args) { return statement; }, \ std::tuple_cat(__VA_ARGS__)) +template +void TimedProfiler::preprocess(const TensorLayoutArray&, + const megdnn::SmallVector&, + intl::UniqPtrWithCN&, + megdnn::Workspace&, + std::array&, + std::array&, + PreprocessFilter&) { + // Opr is neither convbias nor convolution.This function do nothing. +} + +//! convbias +template <> +void TimedProfiler::preprocess( + const TensorLayoutArray& preprocessed_layout, + const SmallVector& flt_val, + intl::UniqPtrWithCN& megdnn_opr, + megdnn::Workspace& mdn_workspace, + std::array& layouts, + std::array& inp_val, + PreprocessFilter& prep_flt) { + if (!preprocessed_layout.empty()) { + auto&& pf = prep_flt; + pf.algorithm_id = nullptr; + pf.tensors.resize(flt_val.size()); + for (size_t i = 0; i < flt_val.size(); i++) { + pf.tensors[i] = flt_val[i].as_megdnn(); + } + APPLY(megdnn_opr->exec_preprocess(args..., &pf, mdn_workspace), + std::forward_as_tuple(layouts[0], inp_val[1].as_megdnn(), + inp_val[2].as_megdnn()), + array_skip(layouts)); + } +} + +//! convolution +template <> +void TimedProfiler::preprocess( + const TensorLayoutArray& preprocessed_layout, + const megdnn::SmallVector& flt_val, + intl::UniqPtrWithCN& megdnn_opr, + megdnn::Workspace& mdn_workspace, + std::array& layouts, + std::array& inp_val, + PreprocessFilter& prep_flt) { + if (!preprocessed_layout.empty()) { + auto&& pf = prep_flt; + pf.algorithm_id = nullptr; + pf.tensors.resize(flt_val.size()); + for (size_t i = 0; i < flt_val.size(); i++) { + pf.tensors[i] = flt_val[i].as_megdnn(); + } + APPLY(megdnn_opr->exec_preprocess(args..., &pf, mdn_workspace), + std::forward_as_tuple(layouts[0], inp_val[1].as_megdnn()), + array_skip<2>(layouts)); + } +} + template typename TimedProfiler::TResult TimedProfiler::prof_impl( const TParam& raw_param) { @@ -258,36 +316,8 @@ typename TimedProfiler::TResult TimedProfiler::prof_impl( } PreprocessFilter prep_flt; - if_constexpr()>([&](auto _) { - if (!preprocessed_layout.empty()) { - auto&& pf = _(prep_flt); - pf.algorithm_id = nullptr; - pf.tensors.resize(flt_val.size()); - for (size_t i = 0; i < flt_val.size(); i++) { - pf.tensors[i] = flt_val[i].as_megdnn(); - } - if_constexpr()>( - //! convbias - [&](auto __) { - APPLY(__(megdnn_opr) - ->exec_preprocess(args..., &pf, - mdn_workspace), - std::forward_as_tuple(layouts[0], - inp_val[1].as_megdnn(), - inp_val[2].as_megdnn()), - array_skip(layouts)); - }, - //! Convolution - [&](auto __) { - APPLY(__(megdnn_opr) - ->exec_preprocess(args..., &pf, - mdn_workspace), - std::forward_as_tuple(layouts[0], - inp_val[1].as_megdnn()), - array_skip(layouts)); - }); - } - }); + preprocess(preprocessed_layout, flt_val, megdnn_opr, mdn_workspace, layouts, + inp_val, prep_flt); RealTimer timer; auto ev_start = cn.create_event(CompNode::Event::NEED_TIMER), diff --git a/src/opr/include/megbrain/opr/search_policy/profiler.h b/src/opr/include/megbrain/opr/search_policy/profiler.h index 062e58671ee3cb1adf674ac7a4ffe18f31245e8a..d1bf34aa347a55f3dc89cc2c00850349e59be536 100644 --- a/src/opr/include/megbrain/opr/search_policy/profiler.h +++ b/src/opr/include/megbrain/opr/search_policy/profiler.h @@ -16,6 +16,8 @@ #include "megbrain/utils/timer.h" #include "megbrain/system.h" #include "megbrain/comp_node.h" +#include "megbrain/tensor.h" +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megdnn/basic_types.h" #include "megdnn/oprs.h" @@ -149,6 +151,13 @@ private: static const double timeout_setting; static double init_timeout_setting(); + static void preprocess(const megdnn::TensorLayoutArray& preprocessed_layout, + const SmallVector& flt_val, + intl::UniqPtrWithCN& megdnn_opr, + megdnn::Workspace& mdn_workspace, + std::array& layouts, + std::array& inp_val, + PreprocessFilter& prep_flt); static TResult prof_impl(const TParam& raw_param); static void prof_init_device(const TParam& raw_param); };