Commit fe3ee3cd authored by Megvii Engine Team

refactor(opr): refactor OprArityTrait

GitOrigin-RevId: fa065cde4ea223dcc394d2a73e14d978384f85de
Parent 91efd67d
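The core of this refactor is visible in the diff below: the per-operator OprArityTrait specializations, which hard-coded each arity through the TENSORS/LAYOUTS/WS_ARG_* macros, are replaced by a single OprArityTraitTmpl<Opr, arity_in, arity_out> that unpacks a std::array of layouts (or tensors) into an argument pack through the APPLY macro built on mgb::apply. The following is a minimal standalone sketch of that unpacking mechanism only; FakeOpr, FakeLayout and call_get_workspace are hypothetical names used for illustration, and C++17 std::apply stands in for the project's own mgb::apply helper.

#include <array>
#include <cstddef>
#include <tuple>

// stand-in for a megdnn layout type, for illustration only
struct FakeLayout {
    size_t size = 0;
};

// stand-in for a megdnn operator with a fixed-arity interface
struct FakeOpr {
    size_t get_workspace_in_bytes(const FakeLayout& src, const FakeLayout& filter,
                                  const FakeLayout& dst) {
        return src.size + filter.size + dst.size;
    }
};

// same idea as the APPLY macro in the diff: expand an array into individual
// arguments, so one template covers every arity instead of one macro
// instantiation per operator
template <typename Opr, size_t arity>
size_t call_get_workspace(Opr& opr, const std::array<FakeLayout, arity>& layouts) {
    return std::apply(
            [&](const auto&... args) { return opr.get_workspace_in_bytes(args...); },
            layouts);
}

int main() {
    FakeOpr opr;
    std::array<FakeLayout, 3> layouts{{{1}, {2}, {3}}};
    // expands to opr.get_workspace_in_bytes(layouts[0], layouts[1], layouts[2])
    return call_get_workspace(opr, layouts) == 6 ? 0 : 1;
}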
......@@ -32,6 +32,7 @@ MIDOUT_DECL(megbrain_opr_convolution)
MIDOUT_END();
#include "../internal/megdnn_opr_wrapper.inl"
#include "../internal/invoke.h"
#include <array>
#include <chrono>
......@@ -109,104 +110,74 @@ struct OprAttributeTrait<opr::ConvBias> {
}
};
template <typename Opr>
constexpr bool opr_supports_preprocess() {
return std::is_same<Opr, megdnn::ConvolutionForward>::value ||
std::is_same<Opr, megdnn::ConvBias>::value;
}
template <typename Opr>
struct OprArityTrait;
#define cb(x) (x)
#define cb_ref(x) (&(x))
#define cb_dnn(x) ((x).as_megdnn())
#define APPLY(statement, ...) \
mgb::apply([&](const auto&... args) { return statement; }, \
std::tuple_cat(__VA_ARGS__))
template <typename Opr, int _arity_in, int _arity_out>
struct OprArityTraitTmpl {
static constexpr int arity_in = _arity_in;
static constexpr int arity_out = _arity_out;
static constexpr int arity = arity_in + arity_out;
using Algorithm = typename Opr::Algorithm;
using TensorLayoutArray = std::array<TensorLayout, arity>;
static size_t get_workspace_in_bytes(Opr* opr, Algorithm* algo,
const TensorLayoutArray& layouts) {
opr->execution_policy() = {algo};
size_t workspace_size;
if_constexpr<opr_supports_preprocess<Opr>()>([&](auto) {
workspace_size = APPLY(
opr->get_workspace_in_bytes(args..., nullptr), layouts);
}, /* else */ [&](auto) {
workspace_size =
APPLY(opr->get_workspace_in_bytes(args...), layouts);
});
return workspace_size;
}
#define WS_ARG_true ,nullptr
#define WS_ARG_false
#define INST_ARITY(_Opr, _in, _out, _has_preprocessed_filter) \
template <> \
struct OprArityTrait<_Opr> { \
static constexpr int arity_in = _in; \
static constexpr int arity_out = _out; \
static constexpr int arity = _in + _out; \
using TensorLayoutArray = std::array<TensorLayout, arity>; \
static size_t get_workspace_in_bytes( \
_Opr* opr, typename _Opr::Algorithm* algo, \
const TensorLayoutArray& layouts) { \
opr->execution_policy() = {algo}; \
return opr->get_workspace_in_bytes( \
LAYOUTS(cb) WS_ARG_##_has_preprocessed_filter); \
} \
\
static std::vector<typename _Opr::Algorithm*> get_all_algorithms( \
_Opr* opr, const TensorLayoutArray& layouts) { \
return opr->get_all_algorithms(LAYOUTS(cb)); \
} \
\
static typename _Opr::Algorithm* get_algorithm_heuristic( \
_Opr* opr, const TensorLayoutArray& layouts, \
size_t workspace_limit, bool reproducible) { \
return opr->get_algorithm_heuristic(LAYOUTS(cb), workspace_limit, \
reproducible); \
} \
\
static void exec(_Opr* opr, const DeviceTensorND* inp_val, \
const DeviceTensorND* out_val, \
megdnn::Workspace& workspace) { \
opr->exec(TENSORS(cb_dnn), workspace); \
} \
static void exec(Opr* opr,
const std::array<DeviceTensorND, arity_in>& inp_val,
const std::array<DeviceTensorND, arity_out>& out_val,
megdnn::Workspace& workspace) {
if_constexpr<opr_supports_preprocess<Opr>()>([&](auto) {
APPLY(opr->exec(args.as_megdnn()..., nullptr, workspace), inp_val,
out_val);
}, /* else */ [&](auto) {
APPLY(opr->exec(args.as_megdnn()..., workspace), inp_val, out_val);
});
}
};
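// Sketch (for illustration only) of the if_constexpr dispatch used in
// get_workspace_in_bytes() and exec() above: both branches are generic
// lambdas, so the branch not selected for a given Opr is never instantiated,
// which lets one branch pass the extra preprocessed-filter argument and the
// other omit it. The helper below is a hypothetical standalone equivalent,
// not MegBrain's actual implementation; its exact signature may differ. Here
// an Identity token is passed so the branch bodies stay dependent; the code
// above achieves the same effect through the args... pack of the APPLY lambda.
#include <cstddef>
#include <type_traits>
#include <utility>

struct Identity {
    template <typename T>
    T&& operator()(T&& v) const { return std::forward<T>(v); }
};

template <typename Then, typename Else>
void select_(std::true_type, Then&& then_branch, Else&&) {
    then_branch(Identity{});  // only the taken branch is ever instantiated
}
template <typename Then, typename Else>
void select_(std::false_type, Then&&, Else&& else_branch) {
    else_branch(Identity{});
}

template <bool cond, typename Then, typename Else>
void if_constexpr_(Then&& then_branch, Else&& else_branch) {
    select_(std::integral_constant<bool, cond>{},
            std::forward<Then>(then_branch), std::forward<Else>(else_branch));
}

// usage: two fake operators, one taking an extra nullptr argument, mirroring
// the role of opr_supports_preprocess() above
struct PlainOpr { int exec(int x) { return x; } };
struct PreprocOpr { int exec(int x, std::nullptr_t) { return x + 1; } };

template <typename Opr>
int run(Opr& opr) {
    constexpr bool with_preproc = std::is_same<Opr, PreprocOpr>::value;
    int ret = 0;
    if_constexpr_<with_preproc>(
            [&](auto wrap) { ret = wrap(opr).exec(1, nullptr); },
            /* else */ [&](auto wrap) { ret = wrap(opr).exec(1); });
    return ret;
}

int main() {
    PlainOpr a;
    PreprocOpr b;
    return (run(a) == 1 && run(b) == 2) ? 0 : 1;
}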
#define INST_ARITY(_Opr, _in, _out) \
template <> \
struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {};
INST_ARITY(megdnn::ConvolutionBackwardData, 2, 1);
INST_ARITY(megdnn::ConvolutionBackwardFilter, 2, 1);
INST_ARITY(megdnn::Convolution3DForward, 2, 1);
INST_ARITY(megdnn::Convolution3DBackwardData, 2, 1);
INST_ARITY(megdnn::Convolution3DBackwardFilter, 2, 1);
INST_ARITY(megdnn::LocalShareForward, 2, 1);
INST_ARITY(megdnn::LocalShareBackwardData, 2, 1);
INST_ARITY(megdnn::LocalShareBackwardFilter, 2, 1);
INST_ARITY(megdnn::Convolution, 2, 1);
INST_ARITY(megdnn::DeformableConvForward, 4, 1);
INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1);
INST_ARITY(megdnn::BatchConvBiasForward, 4, 1);
INST_ARITY(megdnn::ConvBias, 4, 1);
INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3);
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0])
#define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2])
#define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1, false)
INST_ARITY_2_1(megdnn::ConvolutionBackwardData);
INST_ARITY_2_1(megdnn::ConvolutionBackwardFilter);
INST_ARITY_2_1(megdnn::Convolution3DForward);
INST_ARITY_2_1(megdnn::Convolution3DBackwardData);
INST_ARITY_2_1(megdnn::Convolution3DBackwardFilter);
INST_ARITY_2_1(megdnn::LocalShareForward);
INST_ARITY_2_1(megdnn::LocalShareBackwardData);
INST_ARITY_2_1(megdnn::LocalShareBackwardFilter);
#undef TENSORS
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]), nullptr
INST_ARITY(megdnn::Convolution, 2, 1, true);
#undef TENSORS
#undef LAYOUTS
#undef INST_ARITY_2_1
#define TENSORS(cb) \
cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \
cb(out_val[0])
#define LAYOUTS(cb) \
cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), cb(layouts[3]), \
cb(layouts[4])
#define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1, false)
INST_ARITY_4_1(megdnn::DeformableConvForward);
INST_ARITY_4_1(megdnn::DeformableConvBackwardFilter);
INST_ARITY_4_1(megdnn::BatchConvBiasForward);
#undef TENSORS
#define TENSORS(cb) \
cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \
cb(out_val[0]), nullptr
INST_ARITY(megdnn::ConvBias, 4, 1, true);
#undef TENSORS
#undef LAYOUTS
#undef INST_ARITY_4_1
#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), \
cb(inp_val[3]), cb(inp_val[4]), cb(out_val[0]), \
cb(out_val[1]), cb(out_val[2])
#define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), \
cb(layouts[3]), cb(layouts[4]), cb(layouts[5]), \
cb(layouts[6]), cb(layouts[7])
#define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3, false)
INST_ARITY_5_3(megdnn::DeformableConvBackwardData);
#undef TENSORS
#undef LAYOUTS
#undef INST_ARITY_5_3
#undef cb
#undef cb_ref
#undef cb_dnn
#undef INST_ARITY
#undef WS_ARG_true
#undef WS_ARG_false
// timeout delta to be added with fastest known algorithm for new algos
constexpr double TIMEOUT_TOLERANCE = 2;
......@@ -343,8 +314,7 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
megdnn_opr->param() = param.opr_param;
{
typename Opr::Algorithm* algo = nullptr;
for (auto i : OprArityTrait<Opr>::get_all_algorithms(megdnn_opr.get(),
layouts)) {
for (auto i : APPLY(megdnn_opr->get_all_algorithms(args...), layouts)) {
if (!strcmp(i->name(), param.algo_name)) {
algo = i;
break;
......@@ -368,7 +338,9 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
}
// allocate input and output memory
DeviceTensorND inp_val[arity_in], out_val[arity_out], workspace;
std::array<DeviceTensorND, arity_in> inp_val;
std::array<DeviceTensorND, arity_out> out_val;
DeviceTensorND workspace;
for (int i = 0; i < arity_in; ++i) {
inp_val[i]
.comp_node(cn)
......@@ -484,16 +456,17 @@ class AlgoChooser {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
opr->owner_graph(), opr->comp_node(),
opr->execution_policy().workspace_limit);
return OprArityTrait<Opr>::get_algorithm_heuristic(
m_megdnn_opr, m_layouts, workspace_limit, reproducible);
return APPLY(m_megdnn_opr->get_algorithm_heuristic(
args..., workspace_limit, reproducible),
m_layouts);
}
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
std::vector<ImplAlgo> get_all_candidates() const {
auto heu = choose_by_heuristic();
auto&& ret = OprArityTrait<Opr>::get_all_algorithms(m_megdnn_opr,
m_layouts);
auto&& ret =
APPLY(m_megdnn_opr->get_all_algorithms(args...), m_layouts);
bool found = false;
for (size_t i = 0; i < ret.size(); ++i) {
if (ret[i] == heu) {
......