diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index d00d972fe735b58da2bf72b5b02e48628fe5b8d4..2df4e95d2ec6d42c723dffe73a156d26b876e189 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -32,6 +32,7 @@ MIDOUT_DECL(megbrain_opr_convolution) MIDOUT_END(); #include "../internal/megdnn_opr_wrapper.inl" +#include "../internal/invoke.h" #include #include @@ -109,104 +110,74 @@ struct OprAttributeTrait { } }; +template +constexpr bool opr_supports_preprocess() { + return std::is_same::value || + std::is_same::value; +} + template struct OprArityTrait; -#define cb(x) (x) -#define cb_ref(x) (&(x)) -#define cb_dnn(x) ((x).as_megdnn()) +#define APPLY(statement, ...) \ + mgb::apply([&](const auto&... args) { return statement; }, \ + std::tuple_cat(__VA_ARGS__)) + +template +struct OprArityTraitTmpl { + static constexpr int arity_in = _arity_in; + static constexpr int arity_out = _arity_out; + static constexpr int arity = arity_in + arity_out; + using Algorithm = typename Opr::Algorithm; + using TensorLayoutArray = std::array; + + static size_t get_workspace_in_bytes(Opr* opr, Algorithm* algo, + const TensorLayoutArray& layouts) { + opr->execution_policy() = {algo}; + size_t workspace_size; + if_constexpr()>([&](auto) { + workspace_size = APPLY( + opr->get_workspace_in_bytes(args..., nullptr), layouts); + }, /* else */ [&](auto) { + workspace_size = + APPLY(opr->get_workspace_in_bytes(args...), layouts); + }); + return workspace_size; + } -#define WS_ARG_true ,nullptr -#define WS_ARG_false -#define INST_ARITY(_Opr, _in, _out, _has_preprocessed_filter) \ - template <> \ - struct OprArityTrait<_Opr> { \ - static constexpr int arity_in = _in; \ - static constexpr int arity_out = _out; \ - static constexpr int arity = _in + _out; \ - using TensorLayoutArray = std::array; \ - static size_t get_workspace_in_bytes( \ - _Opr* opr, typename _Opr::Algorithm* algo, \ - const TensorLayoutArray& layouts) { \ - opr->execution_policy() = {algo}; \ - return opr->get_workspace_in_bytes( \ - LAYOUTS(cb) WS_ARG_##_has_preprocessed_filter); \ - } \ - \ - static std::vector get_all_algorithms( \ - _Opr* opr, const TensorLayoutArray& layouts) { \ - return opr->get_all_algorithms(LAYOUTS(cb)); \ - } \ - \ - static typename _Opr::Algorithm* get_algorithm_heuristic( \ - _Opr* opr, const TensorLayoutArray& layouts, \ - size_t workspace_limit, bool reproducible) { \ - return opr->get_algorithm_heuristic(LAYOUTS(cb), workspace_limit, \ - reproducible); \ - } \ - \ - static void exec(_Opr* opr, const DeviceTensorND* inp_val, \ - const DeviceTensorND* out_val, \ - megdnn::Workspace& workspace) { \ - opr->exec(TENSORS(cb_dnn), workspace); \ - } \ + static void exec(Opr* opr, + const std::array& inp_val, + const std::array& out_val, + megdnn::Workspace& workspace) { + if_constexpr()>([&](auto) { + APPLY(opr->exec(args.as_megdnn()..., nullptr, workspace), inp_val, + out_val); + }, /* else */ [&](auto) { + APPLY(opr->exec(args.as_megdnn()..., workspace), inp_val, out_val); + }); } +}; + +#define INST_ARITY(_Opr, _in, _out) \ + template <> \ + struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {}; + +INST_ARITY(megdnn::ConvolutionBackwardData, 2, 1); +INST_ARITY(megdnn::ConvolutionBackwardFilter, 2, 1); +INST_ARITY(megdnn::Convolution3DForward, 2, 1); +INST_ARITY(megdnn::Convolution3DBackwardData, 2, 1); +INST_ARITY(megdnn::Convolution3DBackwardFilter, 2, 1); +INST_ARITY(megdnn::LocalShareForward, 2, 1); +INST_ARITY(megdnn::LocalShareBackwardData, 2, 1); +INST_ARITY(megdnn::LocalShareBackwardFilter, 2, 1); +INST_ARITY(megdnn::Convolution, 2, 1); +INST_ARITY(megdnn::DeformableConvForward, 4, 1); +INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1); +INST_ARITY(megdnn::BatchConvBiasForward, 4, 1); +INST_ARITY(megdnn::ConvBias, 4, 1); +INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3); -#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]) -#define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]) -#define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1, false) -INST_ARITY_2_1(megdnn::ConvolutionBackwardData); -INST_ARITY_2_1(megdnn::ConvolutionBackwardFilter); -INST_ARITY_2_1(megdnn::Convolution3DForward); -INST_ARITY_2_1(megdnn::Convolution3DBackwardData); -INST_ARITY_2_1(megdnn::Convolution3DBackwardFilter); -INST_ARITY_2_1(megdnn::LocalShareForward); -INST_ARITY_2_1(megdnn::LocalShareBackwardData); -INST_ARITY_2_1(megdnn::LocalShareBackwardFilter); -#undef TENSORS -#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]), nullptr -INST_ARITY(megdnn::Convolution, 2, 1, true); -#undef TENSORS -#undef LAYOUTS -#undef INST_ARITY_2_1 - -#define TENSORS(cb) \ - cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \ - cb(out_val[0]) -#define LAYOUTS(cb) \ - cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), cb(layouts[3]), \ - cb(layouts[4]) -#define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1, false) -INST_ARITY_4_1(megdnn::DeformableConvForward); -INST_ARITY_4_1(megdnn::DeformableConvBackwardFilter); -INST_ARITY_4_1(megdnn::BatchConvBiasForward); -#undef TENSORS -#define TENSORS(cb) \ - cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \ - cb(out_val[0]), nullptr -INST_ARITY(megdnn::ConvBias, 4, 1, true); -#undef TENSORS -#undef LAYOUTS -#undef INST_ARITY_4_1 - -#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), \ - cb(inp_val[3]), cb(inp_val[4]), cb(out_val[0]), \ - cb(out_val[1]), cb(out_val[2]) -#define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), \ - cb(layouts[3]), cb(layouts[4]), cb(layouts[5]), \ - cb(layouts[6]), cb(layouts[7]) - -#define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3, false) -INST_ARITY_5_3(megdnn::DeformableConvBackwardData); -#undef TENSORS -#undef LAYOUTS -#undef INST_ARITY_5_3 -#undef cb -#undef cb_ref -#undef cb_dnn #undef INST_ARITY -#undef WS_ARG_true -#undef WS_ARG_false // timeout delta to be added with fastest known algorithm for new algos constexpr double TIMEOUT_TOLERANCE = 2; @@ -343,8 +314,7 @@ typename TimedProfiler::TResult TimedProfiler::prof_impl( megdnn_opr->param() = param.opr_param; { typename Opr::Algorithm* algo = nullptr; - for (auto i : OprArityTrait::get_all_algorithms(megdnn_opr.get(), - layouts)) { + for (auto i : APPLY(megdnn_opr->get_all_algorithms(args...), layouts)) { if (!strcmp(i->name(), param.algo_name)) { algo = i; break; @@ -368,7 +338,9 @@ typename TimedProfiler::TResult TimedProfiler::prof_impl( } // allocate input and output memory - DeviceTensorND inp_val[arity_in], out_val[arity_out], workspace; + std::array inp_val; + std::array out_val; + DeviceTensorND workspace; for (int i = 0; i < arity_in; ++i) { inp_val[i] .comp_node(cn) @@ -484,16 +456,17 @@ class AlgoChooser { auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( opr->owner_graph(), opr->comp_node(), opr->execution_policy().workspace_limit); - return OprArityTrait::get_algorithm_heuristic( - m_megdnn_opr, m_layouts, workspace_limit, reproducible); + return APPLY(m_megdnn_opr->get_algorithm_heuristic( + args..., workspace_limit, reproducible), + m_layouts); } //! get all candidate algos, and the one choose_by_heuristic() is //! put first std::vector get_all_candidates() const { auto heu = choose_by_heuristic(); - auto&& ret = OprArityTrait::get_all_algorithms(m_megdnn_opr, - m_layouts); + auto&& ret = + APPLY(m_megdnn_opr->get_all_algorithms(args...), m_layouts); bool found = false; for (size_t i = 0; i < ret.size(); ++i) { if (ret[i] == heu) {