From 5063a2063b2d342fb2d1680c299531ab5577fa29 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Thu, 7 Jan 2021 00:32:11 +0800
Subject: [PATCH] feat(mgb): add fastrun support for matmul/batchedmatrixmul

GitOrigin-RevId: a48ea9bff6d852ffecbbe96982e83a0738c0d366
---
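Usage note: the test added in src/opr/test/blas.cpp below doubles as a usage
example. Here is a minimal sketch of how a caller opts in to fastrun for the
new operators, assuming the graph helpers that test already uses
(HostTensorGenerator, Host2DeviceCopy); the host_a/host_b names are
illustrative only:

    using Policy = opr::MatrixMul::ExecutionPolicy;
    Policy policy;
    policy.strategy = Policy::Strategy::PROFILE;  // fastrun: profile candidate algos

    auto graph = mgb::ComputingGraph::make();
    mgb::HostTensorGenerator<> gen;
    auto host_a = gen({32, 64}), host_b = gen({64, 32});
    auto a = opr::Host2DeviceCopy::make(*graph, host_a);
    auto b = opr::Host2DeviceCopy::make(*graph, host_b);
    // The new optional ExecutionPolicy argument sits between param and config.
    auto c = opr::MatrixMul::make(a, b, {}, policy);

Callers that pass no policy get a default-constructed ExecutionPolicy, so
existing call sites keep today's heuristic behavior.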
 src/gopt/impl/inference.cpp                        | 45 ++++-----
 src/opr/impl/blas.cpp                              | 99 +++++++++++++++++--
 src/opr/impl/blas.oprdecl                          | 19 ++++
 src/opr/impl/blas.sereg.h                          | 67 ++++++++++++-
 src/opr/impl/dnn/convolution.cpp                   |  1 -
 src/opr/impl/dnn/dnn.sereg.h                       |  4 +-
 src/opr/impl/search_policy/algo_chooser.cpp        |  1 +
 src/opr/include/megbrain/opr/blas.h                | 87 ++++++++--------
 .../megbrain/opr/search_policy/algo_chooser.h      |  1 +
 .../megbrain/opr/search_policy/profiler.h          | 37 ++++---
 src/opr/test/blas.cpp                              | 41 +++++++-
 11 files changed, 302 insertions(+), 100 deletions(-)

diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp
index 0acbe615..881a7c54 100644
--- a/src/gopt/impl/inference.cpp
+++ b/src/gopt/impl/inference.cpp
@@ -16,6 +16,7 @@
 #include "megbrain/opr/dnn/batch_norm.h"
 #include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/search_policy/algo_chooser_helper.h"
+#include "megbrain/opr/search_policy/profiler.h"
 #include "megbrain/utils/shared_set.h"
 #include "megbrain/serialization/opr_shallow_copy.h"
 #include "megbrain/opr/basic_arith.h"
@@ -149,15 +150,6 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,
 
 }  // anonymous namespace
 
-#define MGB_FOREACH_FASTRUN_OPR(cb) \
-    cb(ConvolutionForward), cb(ConvBiasForward), cb(ConvolutionBackwardData), \
-            cb(ConvolutionBackwardFilter), cb(Convolution3DForward), \
-            cb(Convolution3DBackwardData), cb(Convolution3DBackwardFilter), \
-            cb(LocalShareForward), cb(LocalShareBackwardData), \
-            cb(LocalShareBackwardFilter), cb(DeformableConvForward), \
-            cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \
-            cb(BatchConvBiasForward),
-
 void gopt::modify_opr_algo_strategy_inplace(
         const VarNodeArrayView& dest_vars,
         opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
@@ -171,7 +163,7 @@ void gopt::modify_opr_algo_strategy_inplace(
             modifiers = {
 #define CONV(t)                                                     \
     {opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier,       \
-                                   std::placeholders::_1, strategy)}
+                                   std::placeholders::_1, strategy)},
             MGB_FOREACH_FASTRUN_OPR(CONV)
 #undef CONV
     };
@@ -209,7 +201,7 @@ void gopt::set_opr_algo_workspace_limit_inplace(
     static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
             modifiers = {
 #define CONV(t) \
-    {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier}
+    {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier},
             MGB_FOREACH_FASTRUN_OPR(CONV)
 #undef CONV
     };
@@ -226,7 +218,6 @@ void gopt::set_opr_algo_workspace_limit_inplace(
         dep_iter.add(i);
     }
 }
-#undef MGB_FOREACH_FASTRUN_OPR
 
 /* ================ ParamRedistributePass ================ */
 const char* ParamRedistributePass::name() const {
@@ -790,8 +781,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                    new_inp[1]->name().c_str(),
                    new_inp[1]->owner_opr()->name().c_str());
         auto new_deconv_opr = opr::ConvolutionBackwardData::make(
-                new_inp[0], new_inp[1], new_param, deconv_opr.execution_policy(),
-                deconv_opr.config());
+                new_inp[0], new_inp[1], new_param,
+                deconv_opr.execution_policy(), deconv_opr.config());
         return new_deconv_opr.node()->owner_opr();
     };
@@ -813,20 +804,20 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                    new_inp[1]->owner_opr()->name().c_str());
         if(opr->input().size() == 2) {
             auto new_conv_opr = opr::ConvBias::make(
-                    new_inp[0], new_inp[1], new_param, convbias_opr.execution_policy(),
-                    convbias_opr.config());
+                    new_inp[0], new_inp[1], new_param,
+                    convbias_opr.execution_policy(), convbias_opr.config());
             return new_conv_opr.node()->owner_opr();
         } else if(opr->input().size() == 3) {
             auto new_conv_opr = opr::ConvBias::make(
-                    new_inp[0], new_inp[1], new_inp[2], new_param, convbias_opr.execution_policy(),
-                    convbias_opr.config());
+                    new_inp[0], new_inp[1], new_inp[2], new_param,
+                    convbias_opr.execution_policy(), convbias_opr.config());
             return new_conv_opr.node()->owner_opr();
         } else {
             mgb_assert(opr->input().size() == 4, "invalid input size %zu",
                        opr->input().size());
             auto new_conv_opr = opr::ConvBias::make(
-                    new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param, convbias_opr.execution_policy(),
-                    convbias_opr.config());
+                    new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param,
+                    convbias_opr.execution_policy(), convbias_opr.config());
             return new_conv_opr.node()->owner_opr();
         }
     };
@@ -841,7 +832,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                 megdnn::param::MatrixMul::ComputeMode::FLOAT32;
         }
         auto new_matmul_opr = opr::MatrixMul::make(
-                new_inp[0], new_inp[1], new_param, matmul_opr.config());
+                new_inp[0], new_inp[1], new_param,
+                matmul_opr.execution_policy(), matmul_opr.config());
         return new_matmul_opr.node()->owner_opr();
     };
@@ -864,7 +856,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                    new_inp[1]->name().c_str(),
                    new_inp[1]->owner_opr()->name().c_str());
         auto new_matmul_opr = opr::BatchedMatrixMul::make(
-                new_inp[0], new_inp[1], new_param, matmul_opr.config());
+                new_inp[0], new_inp[1], new_param,
+                matmul_opr.execution_policy(), matmul_opr.config());
         return new_matmul_opr.node()->owner_opr();
     };
@@ -915,8 +908,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                 new_mat->owner_opr()->input(0)->dtype() == dtype::Float32())
                 new_mat = new_mat->owner_opr()->input(0);
             else
-                new_mat =
-                        opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
+                new_mat = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
+                                  .node();
         }
         SymbolVar new_warp;
         if (new_inp.size() == 3) {
@@ -944,8 +937,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                 new_map->owner_opr()->input(0)->dtype() == dtype::Float32())
                 new_map = new_map->owner_opr()->input(0);
             else
-                new_map =
-                        opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
+                new_map = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
+                                  .node();
         }
         SymbolVar new_remap;
diff --git a/src/opr/impl/blas.cpp b/src/opr/impl/blas.cpp
index 76921669..5a5c1e56 100644
--- a/src/opr/impl/blas.cpp
+++ b/src/opr/impl/blas.cpp
@@ -18,15 +18,41 @@
 #include "megbrain/opr/tensor_gen.h"
 #include "megbrain/opr/tensor_manip.h"
 
+#include "megbrain/opr/search_policy/algo_chooser.h"
+#include "megbrain/opr/search_policy/profiler.h"
+
 #include "./internal/megdnn_opr_wrapper.inl"
+#include "./search_policy/workspace_need_limit_getter.inl"
+#include "megdnn/oprs/linalg.h"
 
 using namespace mgb;
 using namespace opr;
 
+namespace {
+int get_mask_from_matmul(const megdnn::param::MatrixMul& param) {
+    return static_cast<int>(param.transposeA) +
+           (static_cast<int>(param.transposeB) * 2);
+}
+}
+
 /* ================= MatrixMul ================= */
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixMul);
 
-MEGDNN_OPR_INIT2(MatrixMul, "matrix_mul")
+MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param,
+                     const ExecutionPolicy& policy,
+                     const OperatorNodeConfig& config)
+        : Super{a->owner_graph(), config, "matrix_mul", {a, b}} {
+    init_megdnn_opr(*this, param);
+    m_policy = policy;
+    add_input({a, b});
+}
+
+SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
+                          const ExecutionPolicy& policy,
+                          const OperatorNodeConfig& config) {
+    return a.insert_single_output_opr<MatrixMul>(a.node(), b.node(), param,
+                                                 policy, config);
+}
 
 void MatrixMul::init_output_dtype() {
     DType output_dtype = config().output_dtype();
@@ -72,13 +98,32 @@ size_t MatrixMul::get_workspace_size_bytes(
         param ^= 1;
     };
     MGB_TRY {
-        a = mo->get_workspace_in_bytes(i0, i1, out);
+        a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
+                                                       megdnn_opr(), this);
+        //! Here we just want to save the execution policy obtained from
+        //! setup_algo; changing the declaration of get_workspace_in_bytes
+        //! would require changes in too many places.
+        const_cast<MatrixMul*>(this)
+                ->m_candidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i0, tparam.transposeA);
-        b = mo->get_workspace_in_bytes(i0, i1, out);
+        b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
+                                                       megdnn_opr(), this);
+        const_cast<MatrixMul*>(this)
+                ->m_candidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i1, tparam.transposeB);
-        c = mo->get_workspace_in_bytes(i0, i1, out);
+        c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
+                                                       megdnn_opr(), this);
+        const_cast<MatrixMul*>(this)
+                ->m_candidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i0, tparam.transposeA);
-        d = mo->get_workspace_in_bytes(i0, i1, out);
+        d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
+                                                       megdnn_opr(), this);
+        const_cast<MatrixMul*>(this)
+                ->m_candidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
     }
     MGB_FINALLY({ tparam = this->param(); });
     return std::max(std::max(a, b), std::max(c, d));
@@ -100,6 +145,8 @@ void MatrixMul::scn_do_execute() {
     MGB_TRY {
         transpose(inp0.layout, tparam.transposeA);
         transpose(inp1.layout, tparam.transposeB);
+        megdnn_opr()->execution_policy() =
+                m_candidate_execution_policies[get_mask_from_matmul(tparam)];
         megdnn_opr()->exec(inp0, inp1, out,
                            intl::get_megdnn_workspace_from_var(output(1)));
     }
@@ -134,7 +181,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
 
 /* ================= BatchedMatrixMul ================= */
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchedMatrixMul);
-MEGDNN_OPR_INIT2(BatchedMatrixMul, "batched_matrix_mul")
+BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param,
+                                   const ExecutionPolicy& policy,
+                                   const OperatorNodeConfig& config)
+        : Super{a->owner_graph(), config, "batched_matrix_mul", {a, b}} {
+    init_megdnn_opr(*this, param);
+    m_policy = policy;
+    add_input({a, b});
+}
+
+SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
+                                 const ExecutionPolicy& policy,
+                                 const OperatorNodeConfig& config) {
+    return a.insert_single_output_opr<BatchedMatrixMul>(a.node(), b.node(),
+                                                        param, policy, config);
+}
 
 void BatchedMatrixMul::add_input_layout_constraint() {
     auto check = [](const TensorLayout& ly) {
@@ -191,13 +252,29 @@ size_t BatchedMatrixMul::get_workspace_size_bytes(
         param ^= 1;
     };
     MGB_TRY {
-        a = mo->get_workspace_in_bytes(i0, i1, out);
+        a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
+                {i0, i1, out}, megdnn_opr(), this);
+        const_cast<BatchedMatrixMul*>(this)
+                ->m_candidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i0, tparam.transposeA);
-        b = mo->get_workspace_in_bytes(i0, i1, out);
+        b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
+                {i0, i1, out}, megdnn_opr(), this);
+        const_cast<BatchedMatrixMul*>(this)
+                ->m_candidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i1, tparam.transposeB);
-        c = mo->get_workspace_in_bytes(i0, i1, out);
+        c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
+                {i0, i1, out}, megdnn_opr(), this);
+        const_cast<BatchedMatrixMul*>(this)
+                ->m_candidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i0, tparam.transposeA);
-        d = mo->get_workspace_in_bytes(i0, i1, out);
+        d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
+                {i0, i1, out}, megdnn_opr(), this);
+        const_cast<BatchedMatrixMul*>(this)
+                ->m_candidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
     }
     MGB_FINALLY({ tparam = this->param(); });
     return std::max(std::max(a, b), std::max(c, d));
@@ -220,6 +297,8 @@ void BatchedMatrixMul::scn_do_execute() {
     MGB_TRY {
         transpose(inp0.layout, tparam.transposeA);
         transpose(inp1.layout, tparam.transposeB);
+        megdnn_opr()->execution_policy() =
+                m_candidate_execution_policies[get_mask_from_matmul(tparam)];
         megdnn_opr()->exec(inp0, inp1, out,
                            intl::get_megdnn_workspace_from_var(output(1)));
    }
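Design note: get_workspace_size_bytes() above probes all four transpose
combinations, so the execution policy chosen by each setup_algo() call is
cached per combination and looked up again in scn_do_execute(). A standalone
sketch of the indexing scheme, mirroring get_mask_from_matmul() (mask_of is
an illustrative name, not part of the patch):

    #include <cstdio>

    // mask = transposeA + 2 * transposeB, mapping (A, B) onto 0..3
    static int mask_of(bool transpose_a, bool transpose_b) {
        return static_cast<int>(transpose_a) + static_cast<int>(transpose_b) * 2;
    }

    int main() {
        std::printf("NN=%d TN=%d NT=%d TT=%d\n",
                    mask_of(false, false), mask_of(true, false),
                    mask_of(false, true), mask_of(true, true));
        // prints: NN=0 TN=1 NT=2 TT=3
    }

This is why m_candidate_execution_policies has exactly four entries.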
diff --git a/src/opr/impl/blas.oprdecl b/src/opr/impl/blas.oprdecl
index 7e201845..cfe62933 100644
--- a/src/opr/impl/blas.oprdecl
+++ b/src/opr/impl/blas.oprdecl
@@ -14,12 +14,14 @@ decl_opr('BatchedMatrixMul',
               'performed and output shape is (n, a, c)')
 
 decl_opr('MatrixMul',
+         pyname='matrix_mul_v2',
          inputs=['opr0', 'opr1'],
          params='MatrixMul',
          desc='matrix multiplication',
          version=2, has_out_dtype=True)
 
 decl_opr('BatchedMatrixMul',
+         pyname='batched_matrix_mul_v2',
          inputs=['opr0', 'opr1'],
          params='MatrixMul',
          desc='batched matrix multiplication: input shapes should be '
@@ -28,6 +30,23 @@ decl_opr('BatchedMatrixMul',
               'performed and output shape is (n, a, c)',
          version=2, has_out_dtype=True)
 
+decl_opr('MatrixMul',
+         inputs=['opr0', 'opr1'],
+         params=[('param', 'MatrixMul'),
+                 ('execution_policy', 'ExecutionPolicy')],
+         desc='matrix multiplication',
+         version=3, has_out_dtype=True)
+
+decl_opr('BatchedMatrixMul',
+         inputs=['opr0', 'opr1'],
+         params=[('param', 'MatrixMul'),
+                 ('execution_policy', 'ExecutionPolicy')],
+         desc='batched matrix multiplication: input shapes should be '
+              '(n, a, b) and (n, b, c) (assuming transposeA and transposeB are '
+              'False); then :math:`n` independent matrix multiplications would be '
+              'performed and output shape is (n, a, c)',
+         version=3, has_out_dtype=True)
+
 decl_opr('Dot',
          inputs=['opr0', 'opr1'],
          params='Empty',
diff --git a/src/opr/impl/blas.sereg.h b/src/opr/impl/blas.sereg.h
index 19c05426..5ffb29cf 100644
--- a/src/opr/impl/blas.sereg.h
+++ b/src/opr/impl/blas.sereg.h
@@ -10,7 +10,10 @@
  */
 
 #include "megbrain/opr/blas.h"
+#include "megbrain/opr/param_defs.h"
 #include "megbrain/serialization/sereg.h"
+#include "megdnn/opr_param_defs.h"
+#include "megdnn/oprs/linalg.h"
 
 namespace mgb {
 namespace serialization {
@@ -27,14 +30,70 @@ struct OprMaker<opr::SVD, 1> {
     }
 };
 
+template <class Opr>
+struct MakeMatrixMulCaller {
+    template <typename MegDNNMatMul>
+    static VarNode* make(const cg::VarNodeArray& inputs,
+                         const typename MegDNNMatMul::Param& param,
+                         const megdnn::param::ExecutionPolicy& execution_policy,
+                         const OperatorNodeConfig& config) {
+        if (inputs.size() == 2) {
+            return Opr::make(inputs[0], inputs[1], param, execution_policy,
+                             config)
+                    .node();
+        }
+        return nullptr;
+    }
+};
+
+template <class Opr, class Maker, class MegDNNMatMul>
+struct MatrixMulLoadDumpImpl {
+    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
+        auto&& opr = opr_.cast_final_safe<Opr>();
+        ctx.write_param(opr.param());
+        ctx.write_param(opr.execution_policy());
+    }
+
+    static VarNode* make(const cg::VarNodeArray& inputs,
+                         const megdnn::param::MatrixMul& param,
+                         const megdnn::param::ExecutionPolicy& execution_policy,
+                         const OperatorNodeConfig& config) {
+        VarNode* ret = Maker::template make<MegDNNMatMul>(
+                inputs, param, execution_policy, config);
+        mgb_assert(ret);
+        return ret;
+    }
+
+    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
+                                      const cg::VarNodeArray& inputs,
+                                      const OperatorNodeConfig& config) {
+        auto param = ctx.read_param<megdnn::param::MatrixMul>();
+        auto execution_policy =
+                ctx.read_param<megdnn::param::ExecutionPolicy>();
+        return make(inputs, param, execution_policy, config)->owner_opr();
+    }
+};
+
+template <>
+struct OprLoadDumpImpl<opr::MatrixMul, 2>
+        : public MatrixMulLoadDumpImpl<opr::MatrixMul,
+                                       MakeMatrixMulCaller<opr::MatrixMul>,
+                                       megdnn::MatrixMul> {};
+template <>
+struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>
+        : public MatrixMulLoadDumpImpl<
+                  opr::BatchedMatrixMul,
+                  MakeMatrixMulCaller<opr::BatchedMatrixMul>,
+                  megdnn::BatchedMatrixMul> {};
+
 }  // namespace serialization
 
 namespace opr {
-using MatrixMulV2 = MatrixMul;
-using BatchedMatrixMulV2 = BatchedMatrixMul;
-MGB_SEREG_OPR(MatrixMulV2, 2);
-MGB_SEREG_OPR(BatchedMatrixMulV2, 2);
+using MatrixMulV3 = MatrixMul;
+using BatchedMatrixMulV3 = BatchedMatrixMul;
+MGB_SEREG_OPR(MatrixMulV3, 2);
+MGB_SEREG_OPR(BatchedMatrixMulV3, 2);
 MGB_SEREG_OPR(Dot, 2);
 MGB_SEREG_OPR(MatrixInverse, 1);
 MGB_SEREG_OPR(SVD, 1);
diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp
index 1fa9e831..6402335d 100644
--- a/src/opr/impl/dnn/convolution.cpp
+++ b/src/opr/impl/dnn/convolution.cpp
@@ -1636,6 +1636,5 @@ void BatchConvBiasForward::init_output_format() {
 }
 
 #undef IMPL_CONV
-#undef MGB_FOREACH_FASTRUN_OPR
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h
index e0166b8b..65f1e171 100644
--- a/src/opr/impl/dnn/dnn.sereg.h
+++ b/src/opr/impl/dnn/dnn.sereg.h
@@ -98,7 +98,7 @@ namespace serialization {
             return nullptr;
         }
     };
-    
+
     template<class Opr, class Maker0,
             class Maker1=MakeConvCallerEmpty<megdnn::Convolution>,
             class Maker2=MakeConvCallerEmpty<megdnn::Convolution>,
@@ -292,7 +292,7 @@ namespace serialization {
             return nullptr;
         }
     };
-    
+
     template<class Opr, class Maker0,
             class Maker1=MakeLocalShareCallerEmpty<megdnn::LocalShare>,
            class Maker2=MakeLocalShareCallerEmpty<megdnn::LocalShare>,
diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp
index d194f657..74d619c1 100644
--- a/src/opr/impl/search_policy/algo_chooser.cpp
+++ b/src/opr/impl/search_policy/algo_chooser.cpp
@@ -251,6 +251,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
     auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
             opr->owner_graph(), opr->comp_node(),
             opr->execution_policy().workspace_limit);
+    m_megdnn_opr->execution_policy() = {};
     return APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
                          args..., workspace_limit, reproducible),
                  m_layouts);
diff --git a/src/opr/include/megbrain/opr/blas.h b/src/opr/include/megbrain/opr/blas.h
index 88615fa5..f007195c 100644
--- a/src/opr/include/megbrain/opr/blas.h
+++ b/src/opr/include/megbrain/opr/blas.h
@@ -12,6 +12,7 @@
 #pragma once
 
 #include "megbrain/exception.h"
+#include "megbrain/opr/search_policy/algo_chooser_helper.h"
 #include "megbrain/tensor.h"
 #include "megbrain/graph.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
@@ -24,51 +25,58 @@ namespace opr {
 /*!
  * \brief matrix_mul(trans0(opr0), trans1(opr1))
  */
-MGB_DEFINE_OPR_CLASS(MatrixMul,
-        intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>) // {
-
-    public:
-
-        MatrixMul(VarNode *opr0, VarNode *opr1,
-                const Param &param, const OperatorNodeConfig &config);
-
-        static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
-                const Param &param = {},
-                const OperatorNodeConfig &config = {});
-    private:
-        void add_input_layout_constraint() override;
-        void scn_do_execute() override;
-        void init_output_dtype() override;
-        size_t get_workspace_size_bytes(
-                const TensorShapeArray &input_shapes,
-                const TensorShapeArray &output_shapes) const override;
-
-        static bool check_layout(const TensorLayout &layout, int transpose);
+MGB_DEFINE_OPR_CLASS(MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>,
+                     public mixin::AlgoChooserHelper) // {
+public:
+    using AlgorithmInfo = megdnn::detail::Algorithm::Info;
+    MatrixMul(VarNode* opr0, VarNode* opr1, const Param& param,
+              const ExecutionPolicy& policy, const OperatorNodeConfig& config);
+
+    static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
+                          const Param& param = {},
+                          const ExecutionPolicy& policy = {},
+                          const OperatorNodeConfig& config = {});
+
+private:
+    void add_input_layout_constraint() override;
+    void scn_do_execute() override;
+    void init_output_dtype() override;
+    size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
+                                    const TensorShapeArray& output_shapes)
+            const override;
+
+    static bool check_layout(const TensorLayout& layout, int transpose);
+
+    //! store the policy of all transpose situations
+    megdnn::MatrixMul::ExecutionPolicy m_candidate_execution_policies[4];
 };
 
 /*!
  * \brief batched matrix multiplication on 3D inputs
  */
 MGB_DEFINE_OPR_CLASS(BatchedMatrixMul,
-        intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>) // {
-
-    public:
-
-        BatchedMatrixMul(VarNode *opr0, VarNode *opr1,
-                const Param &param, const OperatorNodeConfig &config);
-
-        static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
-                const Param &param = {},
-                const OperatorNodeConfig &config = {});
-    private:
-        void add_input_layout_constraint() override;
-        void init_output_dtype() override;
-        void scn_do_execute() override;
-        size_t get_workspace_size_bytes(
-                const TensorShapeArray &input_shapes,
-                const TensorShapeArray &output_shapes) const override;
-
-        static bool check_layout(const TensorLayout &layout, bool transpose);
+                     intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>,
+                     public mixin::AlgoChooserHelper) // {
+public:
+    using AlgorithmInfo = megdnn::detail::Algorithm::Info;
+    BatchedMatrixMul(VarNode* opr0, VarNode* opr1, const Param& param,
+                     const ExecutionPolicy& policy,
+                     const OperatorNodeConfig& config);
+
+    static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
+                          const Param& param = {},
+                          const ExecutionPolicy& policy = {},
+                          const OperatorNodeConfig& config = {});
+
+private:
+    void add_input_layout_constraint() override;
+    void init_output_dtype() override;
+    void scn_do_execute() override;
+    size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
+                                    const TensorShapeArray& output_shapes)
+            const override;
+
+    static bool check_layout(const TensorLayout& layout, bool transpose);
+    //! store the policy of all transpose situations
+    megdnn::BatchedMatrixMul::ExecutionPolicy m_candidate_execution_policies[4];
 };
 
 /*!
@@ -109,4 +117,3 @@ MGB_DEFINE_OPR_CLASS(SVD, intl::MegDNNOprWrapperFwd<megdnn::SVD>) // {
 }  // mgb
 
 // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
-
diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h
index b7687466..1a193921 100644
--- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h
+++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h
@@ -14,6 +14,7 @@
 
 #include "megbrain/opr/search_policy/profiler.h"
 #include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/blas.h"
 
 template <typename MegDNNOpr>
 struct MegDNNOpr2MGBOpr;
diff --git a/src/opr/include/megbrain/opr/search_policy/profiler.h b/src/opr/include/megbrain/opr/search_policy/profiler.h
index 37272b20..a59ba5eb 100644
--- a/src/opr/include/megbrain/opr/search_policy/profiler.h
+++ b/src/opr/include/megbrain/opr/search_policy/profiler.h
@@ -18,26 +18,31 @@
 
 #include "megbrain/comp_node.h"
 #include "megdnn/basic_types.h"
+#include "megdnn/oprs/linalg.h"
 #include "megdnn/oprs/nn.h"
 
 namespace mgb {
 namespace opr {
 
-#define MGB_FOREACH_FASTRUN_OPR(cb)   \
-    cb(ConvolutionForward);           \
-    cb(ConvBiasForward);              \
-    cb(ConvolutionBackwardData);      \
-    cb(ConvolutionBackwardFilter);    \
-    cb(Convolution3DForward);         \
-    cb(Convolution3DBackwardData);    \
-    cb(Convolution3DBackwardFilter);  \
-    cb(LocalShareForward);            \
-    cb(LocalShareBackwardData);       \
-    cb(LocalShareBackwardFilter);     \
-    cb(DeformableConvForward);        \
-    cb(DeformableConvBackwardFilter); \
-    cb(DeformableConvBackwardData);   \
-    cb(BatchConvBiasForward);
+// clang-format off
+#define MGB_FOREACH_FASTRUN_OPR(cb)  \
+    cb(ConvolutionForward)           \
+    cb(ConvBiasForward)              \
+    cb(ConvolutionBackwardData)      \
+    cb(ConvolutionBackwardFilter)    \
+    cb(Convolution3DForward)         \
+    cb(Convolution3DBackwardData)    \
+    cb(Convolution3DBackwardFilter)  \
+    cb(LocalShareForward)            \
+    cb(LocalShareBackwardData)       \
+    cb(LocalShareBackwardFilter)     \
+    cb(DeformableConvForward)        \
+    cb(DeformableConvBackwardFilter) \
+    cb(DeformableConvBackwardData)   \
+    cb(BatchConvBiasForward)         \
+    cb(MatrixMul)                    \
+    cb(BatchedMatrixMul)
+// clang-format on
 
 template <typename Opr>
 struct OprArityTrait;
@@ -67,6 +72,8 @@ INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1);
 INST_ARITY(megdnn::BatchConvBiasForward, 4, 1);
 INST_ARITY(megdnn::ConvBias, 4, 1);
 INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3);
+INST_ARITY(megdnn::MatrixMul, 2, 1);
+INST_ARITY(megdnn::BatchedMatrixMul, 2, 1);
 
 #undef INST_ARITY
diff --git a/src/opr/test/blas.cpp b/src/opr/test/blas.cpp
index 71c4fb2b..4373e9a7 100644
--- a/src/opr/test/blas.cpp
+++ b/src/opr/test/blas.cpp
@@ -269,7 +269,7 @@ void run_trans_inp_test_case(bool trans_a, bool trans_b) {
     if (DTypeTrait::enumv == DTypeEnum::Int16) {
         config.output_dtype(dtype::Int16());
     }
-    auto z = opr::MatrixMul::make(x, y, {}, config);
+    auto z = opr::MatrixMul::make(x, y, {}, {}, config);
 
     HostTensorND host_z;
     auto func = graph->compile({make_callback_copy(z, host_z)});
@@ -359,7 +359,7 @@ void run_bgemm_trans_inp_test_case(bool trans_a, bool trans_b) {
     trans_a ? (x = opr::Dimshuffle::make(x, {0, 2, 1})) : 0;
     trans_b ? (y = opr::Dimshuffle::make(y, {0, 2, 1})) : 0;
 
-    auto z = opr::BatchedMatrixMul::make(x, y, {}, OperatorNodeConfig{});
+    auto z = opr::BatchedMatrixMul::make(x, y, {}, {}, OperatorNodeConfig{});
     HostTensorND host_z;
     auto func = graph->compile({make_callback_copy(z, host_z)});
     auto run = [&](size_t B, size_t M, size_t K, size_t N) {
@@ -420,6 +420,43 @@ TEST(TestOprBlas, MatrixMul_TT) {
     run_sgemm_test(true, true);
 }
 
+TEST(TestOprDNN, MatrixMulExePolicy) {
+    using Param = opr::MatrixMul::Param;
+    Param param;
+    using Policy = opr::MatrixMul::ExecutionPolicy;
+    using S = Policy::Strategy;
+
+    auto cn = CompNode::load("cpux");
+
+#if MGB_ENABLE_FASTRUN
+    for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE,
+                          S::PROFILE_HEURISTIC}) {
+#else
+    for (auto strategy : {S::HEURISTIC, S::PROFILE_HEURISTIC}) {
+#endif
+
+        auto graph = ComputingGraph::make();
+        HostTensorGenerator<> gen;
+
+        auto mkvar = [&](const char* name, const TensorShape& shp) {
+            return opr::Host2DeviceCopy::make(*graph, gen(shp), cn)
+                    .rename(name);
+        };
+
+        auto A = mkvar("A", {32, 64});
+        auto B = mkvar("B", {64, 32});
+
+        Policy policy;
+        policy.strategy = strategy;
+
+        auto C = opr::MatrixMul::make(A, B, param, policy);
+        HostTensorND host_c;
+        auto func = graph->compile({make_callback_copy(C, host_c)});
+        func->execute();
+    }
+}
+
+
 TEST(TestOprBlas, BatchedMatrixMulFp32_NN) {
     run_batched_sgemm_test(false, false);
 }
-- 
GitLab