提交 88898e63 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mgb): replace if_constexpr with runtime function to avoid potential

bug

GitOrigin-RevId: 27fe093d506ab9536fe281ec45b061de5b0407e7
上级 25932352
......@@ -154,6 +154,64 @@ double TimedProfiler<Opr>::init_timeout_setting() {
mgb::apply([&](const auto&... args) { return statement; }, \
std::tuple_cat(__VA_ARGS__))
/*!
 * Generic fallback of the preprocess hook.
 *
 * This primary template is used for every operator type that has no explicit
 * specialization (i.e. anything other than ConvBias / ConvolutionForward).
 * It intentionally ignores all of its arguments and performs no work.
 */
template <typename Opr>
void TimedProfiler<Opr>::preprocess(
        const TensorLayoutArray& /* preprocessed_layout */,
        const megdnn::SmallVector<DeviceTensorND>& /* flt_val */,
        intl::UniqPtrWithCN<Opr>& /* megdnn_opr */,
        megdnn::Workspace& /* mdn_workspace */,
        std::array<TensorLayout, arity>& /* layouts */,
        std::array<DeviceTensorND, arity_in>& /* inp_val */,
        PreprocessFilter<Opr>& /* prep_flt */) {}
/*!
 * ConvBias specialization of the preprocess hook.
 *
 * When \p preprocessed_layout is non-empty, \p prep_flt is populated from
 * \p flt_val (one megdnn tensor per entry, algorithm_id reset to nullptr) and
 * exec_preprocess() is invoked on \p megdnn_opr. The call receives layouts[0],
 * inp_val[1] and inp_val[2] (as megdnn tensors), then the elements produced by
 * array_skip<arity_in - 1>(layouts), followed by &prep_flt and the workspace.
 * When \p preprocessed_layout is empty the function returns without touching
 * any argument.
 */
template <>
void TimedProfiler<megdnn::ConvBias>::preprocess(
        const TensorLayoutArray& preprocessed_layout,
        const SmallVector<DeviceTensorND>& flt_val,
        intl::UniqPtrWithCN<megdnn::ConvBias>& megdnn_opr,
        megdnn::Workspace& mdn_workspace,
        std::array<TensorLayout, arity>& layouts,
        std::array<DeviceTensorND, arity_in>& inp_val,
        PreprocessFilter<megdnn::ConvBias>& prep_flt) {
    // No preprocessed layouts were supplied: nothing to execute.
    if (preprocessed_layout.empty()) {
        return;
    }

    // Fill the filter description handed to exec_preprocess().
    prep_flt.algorithm_id = nullptr;
    const size_t nr_tensors = flt_val.size();
    prep_flt.tensors.resize(nr_tensors);
    for (size_t idx = 0; idx < nr_tensors; ++idx) {
        prep_flt.tensors[idx] = flt_val[idx].as_megdnn();
    }

    // APPLY expands the concatenated tuples into the leading call arguments.
    APPLY(megdnn_opr->exec_preprocess(args..., &prep_flt, mdn_workspace),
          std::forward_as_tuple(layouts[0], inp_val[1].as_megdnn(),
                                inp_val[2].as_megdnn()),
          array_skip<arity_in - 1>(layouts));
}
/*!
 * ConvolutionForward specialization of the preprocess hook.
 *
 * When \p preprocessed_layout is non-empty, \p prep_flt is populated from
 * \p flt_val (one megdnn tensor per entry, algorithm_id reset to nullptr) and
 * exec_preprocess() is invoked on \p megdnn_opr. The call receives layouts[0]
 * and inp_val[1] (as a megdnn tensor), then the elements produced by
 * array_skip<2>(layouts), followed by &prep_flt and the workspace.
 * When \p preprocessed_layout is empty the function returns without touching
 * any argument.
 */
template <>
void TimedProfiler<megdnn::ConvolutionForward>::preprocess(
        const TensorLayoutArray& preprocessed_layout,
        const megdnn::SmallVector<DeviceTensorND>& flt_val,
        intl::UniqPtrWithCN<megdnn::ConvolutionForward>& megdnn_opr,
        megdnn::Workspace& mdn_workspace,
        std::array<TensorLayout, arity>& layouts,
        std::array<DeviceTensorND, arity_in>& inp_val,
        PreprocessFilter<megdnn::ConvolutionForward>& prep_flt) {
    // No preprocessed layouts were supplied: nothing to execute.
    if (preprocessed_layout.empty()) {
        return;
    }

    // Fill the filter description handed to exec_preprocess().
    prep_flt.algorithm_id = nullptr;
    const size_t nr_tensors = flt_val.size();
    prep_flt.tensors.resize(nr_tensors);
    for (size_t idx = 0; idx < nr_tensors; ++idx) {
        prep_flt.tensors[idx] = flt_val[idx].as_megdnn();
    }

    // APPLY expands the concatenated tuples into the leading call arguments.
    APPLY(megdnn_opr->exec_preprocess(args..., &prep_flt, mdn_workspace),
          std::forward_as_tuple(layouts[0], inp_val[1].as_megdnn()),
          array_skip<2>(layouts));
}
template <typename Opr>
typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
const TParam& raw_param) {
......@@ -258,36 +316,8 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
}
PreprocessFilter<Opr> prep_flt;
if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
if (!preprocessed_layout.empty()) {
auto&& pf = _(prep_flt);
pf.algorithm_id = nullptr;
pf.tensors.resize(flt_val.size());
for (size_t i = 0; i < flt_val.size(); i++) {
pf.tensors[i] = flt_val[i].as_megdnn();
}
if_constexpr<opr_contain_bias<Opr>()>(
//! convbias
[&](auto __) {
APPLY(__(megdnn_opr)
->exec_preprocess(args..., &pf,
mdn_workspace),
std::forward_as_tuple(layouts[0],
inp_val[1].as_megdnn(),
inp_val[2].as_megdnn()),
array_skip<arity_in - 1>(layouts));
},
//! Convolution
[&](auto __) {
APPLY(__(megdnn_opr)
->exec_preprocess(args..., &pf,
mdn_workspace),
std::forward_as_tuple(layouts[0],
inp_val[1].as_megdnn()),
array_skip<arity_in>(layouts));
});
}
});
preprocess(preprocessed_layout, flt_val, megdnn_opr, mdn_workspace, layouts,
inp_val, prep_flt);
RealTimer timer;
auto ev_start = cn.create_event(CompNode::Event::NEED_TIMER),
......
......@@ -16,6 +16,8 @@
#include "megbrain/utils/timer.h"
#include "megbrain/system.h"
#include "megbrain/comp_node.h"
#include "megbrain/tensor.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/basic_types.h"
#include "megdnn/oprs.h"
......@@ -149,6 +151,13 @@ private:
static const double timeout_setting;
static double init_timeout_setting();
/*!
 * \brief per-operator preprocessing hook invoked by prof_impl()
 *
 * The generic definition does nothing. The explicit specializations for
 * megdnn::ConvBias and megdnn::ConvolutionForward fill \p prep_flt from
 * \p flt_val and call exec_preprocess() on \p megdnn_opr, but only when
 * \p preprocessed_layout is not empty.
 *
 * \param preprocessed_layout only checked for emptiness by the
 *      specializations
 * \param flt_val tensors converted with as_megdnn() and stored into
 *      prep_flt.tensors
 * \param megdnn_opr operator on which exec_preprocess() is called
 * \param mdn_workspace workspace forwarded to exec_preprocess()
 * \param layouts layouts forwarded (partially, via array_skip) to
 *      exec_preprocess()
 * \param inp_val input tensors; the specializations forward inp_val[1]
 *      (and inp_val[2] for ConvBias) to exec_preprocess()
 * \param prep_flt [out] filter description written by the specializations
 */
static void preprocess(const megdnn::TensorLayoutArray& preprocessed_layout,
const SmallVector<DeviceTensorND>& flt_val,
intl::UniqPtrWithCN<Opr>& megdnn_opr,
megdnn::Workspace& mdn_workspace,
std::array<TensorLayout, arity>& layouts,
std::array<DeviceTensorND, arity_in>& inp_val,
PreprocessFilter<Opr>& prep_flt);
static TResult prof_impl(const TParam& raw_param);
static void prof_init_device(const TParam& raw_param);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册