Commit 5063a206 authored by Megvii Engine Team

feat(mgb): add fastrun support for matmul/batchedmatrixmul

GitOrigin-RevId: a48ea9bff6d852ffecbbe96982e83a0738c0d366
Parent 409a8772
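In short, this commit threads a megdnn ExecutionPolicy through opr::MatrixMul and opr::BatchedMatrixMul and registers both operators with the fastrun profiler, so an algorithm can be profiled and cached per transpose combination. A minimal usage sketch, mirroring the TestOprDNN.MatrixMulExePolicy test added at the end of this diff (the graph and input variables are assumed to be built as in that test):

// Sketch only: enabling fastrun-style profiling for MatrixMul, following the
// new TestOprDNN.MatrixMulExePolicy test below. `graph`, `A` and `B` are
// assumed to be a ComputingGraph and two input vars built as in that test.
using Policy = opr::MatrixMul::ExecutionPolicy;
using S = Policy::Strategy;

Policy policy;
policy.strategy = S::PROFILE;  // profile candidate algorithms (fastrun)

// The new make() overload takes the policy between the param and the config.
auto C = opr::MatrixMul::make(A, B, opr::MatrixMul::Param{}, policy);

HostTensorND host_c;
auto func = graph->compile({make_callback_copy(C, host_c)});
func->execute();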
@@ -16,6 +16,7 @@
 #include "megbrain/opr/dnn/batch_norm.h"
 #include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/search_policy/algo_chooser_helper.h"
+#include "megbrain/opr/search_policy/profiler.h"
 #include "megbrain/utils/shared_set.h"
 #include "megbrain/serialization/opr_shallow_copy.h"
 #include "megbrain/opr/basic_arith.h"
@@ -149,15 +150,6 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,
 } // anonymous namespace
 
-#define MGB_FOREACH_FASTRUN_OPR(cb)                                            \
-    cb(ConvolutionForward), cb(ConvBiasForward), cb(ConvolutionBackwardData),  \
-            cb(ConvolutionBackwardFilter), cb(Convolution3DForward),           \
-            cb(Convolution3DBackwardData), cb(Convolution3DBackwardFilter),    \
-            cb(LocalShareForward), cb(LocalShareBackwardData),                 \
-            cb(LocalShareBackwardFilter), cb(DeformableConvForward),           \
-            cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData),  \
-            cb(BatchConvBiasForward),
-
 void gopt::modify_opr_algo_strategy_inplace(
         const VarNodeArrayView& dest_vars,
         opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
@@ -171,7 +163,7 @@ void gopt::modify_opr_algo_strategy_inplace(
             modifiers = {
 #define CONV(t)                                                       \
     {opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \
-                                   std::placeholders::_1, strategy)}
+                                   std::placeholders::_1, strategy)},
                     MGB_FOREACH_FASTRUN_OPR(CONV)
 #undef CONV
             };
@@ -209,7 +201,7 @@ void gopt::set_opr_algo_workspace_limit_inplace(
     static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
             modifiers = {
 #define CONV(t) \
-    {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>}
+    {opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>},
                     MGB_FOREACH_FASTRUN_OPR(CONV)
 #undef CONV
             };
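Why the trailing commas above: the per-file MGB_FOREACH_FASTRUN_OPR macro deleted earlier supplied commas between its cb(...) entries, whereas the shared definition in profiler.h (rewritten later in this diff) emits the entries with no separator at all, so each call site now adds its own punctuation inside CONV(t). A rough, self-contained illustration of that pattern, with hypothetical names:

// Illustration only (hypothetical names, not MegBrain code): a separator-free
// X-macro leaves the punctuation to the callback, here a trailing comma.
#include <utility>

#define FOREACH_DEMO_OPR(cb) \
    cb(ConvolutionForward)   \
    cb(MatrixMul)

#define ENTRY(t) {#t, nullptr},  // the comma lives in the callback expansion
static const std::pair<const char*, void (*)()> demo_table[] = {
        FOREACH_DEMO_OPR(ENTRY)
};
#undef ENTRY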
@@ -226,7 +218,6 @@ void gopt::set_opr_algo_workspace_limit_inplace(
         dep_iter.add(i);
     }
 }
-#undef MGB_FOREACH_FASTRUN_OPR
 
 /* ================ ParamRedistributePass ================ */
 const char* ParamRedistributePass::name() const {
@@ -790,8 +781,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                    new_inp[1]->name().c_str(),
                    new_inp[1]->owner_opr()->name().c_str());
         auto new_deconv_opr = opr::ConvolutionBackwardData::make(
-                new_inp[0], new_inp[1], new_param, deconv_opr.execution_policy(),
-                deconv_opr.config());
+                new_inp[0], new_inp[1], new_param,
+                deconv_opr.execution_policy(), deconv_opr.config());
         return new_deconv_opr.node()->owner_opr();
     };
@@ -813,20 +804,20 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                    new_inp[1]->owner_opr()->name().c_str());
         if (opr->input().size() == 2) {
             auto new_conv_opr = opr::ConvBias::make(
-                    new_inp[0], new_inp[1], new_param, convbias_opr.execution_policy(),
-                    convbias_opr.config());
+                    new_inp[0], new_inp[1], new_param,
+                    convbias_opr.execution_policy(), convbias_opr.config());
             return new_conv_opr.node()->owner_opr();
         } else if (opr->input().size() == 3) {
             auto new_conv_opr = opr::ConvBias::make(
-                    new_inp[0], new_inp[1], new_inp[2], new_param, convbias_opr.execution_policy(),
-                    convbias_opr.config());
+                    new_inp[0], new_inp[1], new_inp[2], new_param,
+                    convbias_opr.execution_policy(), convbias_opr.config());
             return new_conv_opr.node()->owner_opr();
         } else {
             mgb_assert(opr->input().size() == 4, "invalid input size %zu",
                        opr->input().size());
             auto new_conv_opr = opr::ConvBias::make(
-                    new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param, convbias_opr.execution_policy(),
-                    convbias_opr.config());
+                    new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param,
+                    convbias_opr.execution_policy(), convbias_opr.config());
             return new_conv_opr.node()->owner_opr();
         }
     };
@@ -841,7 +832,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                     megdnn::param::MatrixMul::ComputeMode::FLOAT32;
         }
         auto new_matmul_opr = opr::MatrixMul::make(
-                new_inp[0], new_inp[1], new_param, matmul_opr.config());
+                new_inp[0], new_inp[1], new_param,
+                matmul_opr.execution_policy(), matmul_opr.config());
         return new_matmul_opr.node()->owner_opr();
     };
@@ -864,7 +856,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
                    new_inp[1]->name().c_str(),
                    new_inp[1]->owner_opr()->name().c_str());
         auto new_matmul_opr = opr::BatchedMatrixMul::make(
-                new_inp[0], new_inp[1], new_param, matmul_opr.config());
+                new_inp[0], new_inp[1], new_param,
+                matmul_opr.execution_policy(), matmul_opr.config());
         return new_matmul_opr.node()->owner_opr();
     };
@@ -915,8 +908,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
             new_mat->owner_opr()->input(0)->dtype() == dtype::Float32())
             new_mat = new_mat->owner_opr()->input(0);
         else
-            new_mat =
-                    opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
+            new_mat = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
+                              .node();
         }
         SymbolVar new_warp;
         if (new_inp.size() == 3) {
@@ -944,8 +937,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
             new_map->owner_opr()->input(0)->dtype() == dtype::Float32())
             new_map = new_map->owner_opr()->input(0);
         else
-            new_map =
-                    opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
+            new_map = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
+                              .node();
         }
         SymbolVar new_remap;
......
@@ -18,15 +18,41 @@
 #include "megbrain/opr/tensor_gen.h"
 #include "megbrain/opr/tensor_manip.h"
+#include "megbrain/opr/search_policy/algo_chooser.h"
+#include "megbrain/opr/search_policy/profiler.h"
 
 #include "./internal/megdnn_opr_wrapper.inl"
+#include "./search_policy/workspace_need_limit_getter.inl"
+
+#include "megdnn/oprs/linalg.h"
 
 using namespace mgb;
 using namespace opr;
 
+namespace {
+int get_mask_from_matmul(const megdnn::param::MatrixMul& param) {
+    return static_cast<int>(param.transposeA) +
+           (static_cast<int>(param.transposeB) * 2);
+}
+}
+
 /* ================= MatrixMul ================= */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixMul);
-MEGDNN_OPR_INIT2(MatrixMul, "matrix_mul")
+
+MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param,
+                     const ExecutionPolicy& policy,
+                     const OperatorNodeConfig& config)
+        : Super{a->owner_graph(), config, "matrix_mul", {a, b}} {
+    init_megdnn_opr(*this, param);
+    m_policy = policy;
+    add_input({a, b});
+}
+
+SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
+                          const ExecutionPolicy& policy,
+                          const OperatorNodeConfig& config) {
+    return a.insert_single_output_opr<MatrixMul>(a.node(), b.node(), param,
+                                                 policy, config);
+}
 
 void MatrixMul::init_output_dtype() {
     DType output_dtype = config().output_dtype();
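The get_mask_from_matmul() helper added above packs the two transpose flags into a two-bit index, which is used to select one of the four cached policies in the new m_cadidate_execution_policies[4] member (see the blas.h change further down). A standalone sketch of the encoding, using a hypothetical stand-in parameter type:

// Standalone sketch of the mask used by get_mask_from_matmul():
// bit 0 <- transposeA, bit 1 <- transposeB, giving indices 0..3.
#include <cassert>

struct DemoParam {  // hypothetical stand-in for megdnn::param::MatrixMul
    bool transposeA = false;
    bool transposeB = false;
};

int demo_mask(const DemoParam& p) {
    return static_cast<int>(p.transposeA) + static_cast<int>(p.transposeB) * 2;
}

int main() {
    assert(demo_mask({false, false}) == 0);  // neither input transposed
    assert(demo_mask({true, false}) == 1);   // A transposed
    assert(demo_mask({false, true}) == 2);   // B transposed
    assert(demo_mask({true, true}) == 3);    // both transposed
    return 0;
}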
@@ -72,13 +98,32 @@ size_t MatrixMul::get_workspace_size_bytes(
         param ^= 1;
     };
     MGB_TRY {
-        a = mo->get_workspace_in_bytes(i0, i1, out);
+        a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
+                                                       megdnn_opr(), this);
+        //! Here we just want to save the execution policy got from setup_algo,
+        //! while changing the declaration of get_workspace_in_bytes may cause
+        //! many changes.
+        const_cast<MatrixMul*>(this)
+                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i0, tparam.transposeA);
-        b = mo->get_workspace_in_bytes(i0, i1, out);
+        b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
+                                                       megdnn_opr(), this);
+        const_cast<MatrixMul*>(this)
+                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i1, tparam.transposeB);
-        c = mo->get_workspace_in_bytes(i0, i1, out);
+        c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
+                                                       megdnn_opr(), this);
+        const_cast<MatrixMul*>(this)
+                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i0, tparam.transposeA);
-        d = mo->get_workspace_in_bytes(i0, i1, out);
+        d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
+                                                       megdnn_opr(), this);
+        const_cast<MatrixMul*>(this)
+                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
     }
     MGB_FINALLY({ tparam = this->param(); });
     return std::max(std::max(a, b), std::max(c, d));
@@ -100,6 +145,8 @@ void MatrixMul::scn_do_execute() {
     MGB_TRY {
         transpose(inp0.layout, tparam.transposeA);
         transpose(inp1.layout, tparam.transposeB);
+        megdnn_opr()->execution_policy() =
+                m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
         megdnn_opr()->exec(inp0, inp1, out,
                            intl::get_megdnn_workspace_from_var(output(1)));
     }
@@ -134,7 +181,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
 /* ================= BatchedMatrixMul ================= */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchedMatrixMul);
-MEGDNN_OPR_INIT2(BatchedMatrixMul, "batched_matrix_mul")
+
+BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param,
+                                   const ExecutionPolicy& policy,
+                                   const OperatorNodeConfig& config)
+        : Super{a->owner_graph(), config, "batched_matrix_mul", {a, b}} {
+    init_megdnn_opr(*this, param);
+    m_policy = policy;
+    add_input({a, b});
+}
+
+SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
+                                 const ExecutionPolicy& policy,
+                                 const OperatorNodeConfig& config) {
+    return a.insert_single_output_opr<BatchedMatrixMul>(a.node(), b.node(),
+                                                        param, policy, config);
+}
 
 void BatchedMatrixMul::add_input_layout_constraint() {
     auto check = [](const TensorLayout& ly) {
@@ -191,13 +252,29 @@ size_t BatchedMatrixMul::get_workspace_size_bytes(
         param ^= 1;
     };
     MGB_TRY {
-        a = mo->get_workspace_in_bytes(i0, i1, out);
+        a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
+                {i0, i1, out}, megdnn_opr(), this);
+        const_cast<BatchedMatrixMul*>(this)
+                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i0, tparam.transposeA);
-        b = mo->get_workspace_in_bytes(i0, i1, out);
+        b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
+                {i0, i1, out}, megdnn_opr(), this);
+        const_cast<BatchedMatrixMul*>(this)
+                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i1, tparam.transposeB);
-        c = mo->get_workspace_in_bytes(i0, i1, out);
+        c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
+                {i0, i1, out}, megdnn_opr(), this);
+        const_cast<BatchedMatrixMul*>(this)
+                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
         transpose(i0, tparam.transposeA);
-        d = mo->get_workspace_in_bytes(i0, i1, out);
+        d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
+                {i0, i1, out}, megdnn_opr(), this);
+        const_cast<BatchedMatrixMul*>(this)
+                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
+                megdnn_opr()->execution_policy();
     }
     MGB_FINALLY({ tparam = this->param(); });
     return std::max(std::max(a, b), std::max(c, d));
@@ -220,6 +297,8 @@ void BatchedMatrixMul::scn_do_execute() {
     MGB_TRY {
         transpose(inp0.layout, tparam.transposeA);
         transpose(inp1.layout, tparam.transposeB);
+        megdnn_opr()->execution_policy() =
+                m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
         megdnn_opr()->exec(inp0, inp1, out,
                            intl::get_megdnn_workspace_from_var(output(1)));
     }
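Taken together, the blas.cpp changes above follow one pattern for both operators: get_workspace_size_bytes() calls AlgoChooser::setup_algo() once per transpose variant and caches the chosen policy, and scn_do_execute() re-installs the cached policy matching the current transpose flags before exec(). A simplified, self-contained model of that cache-then-reuse flow (names are illustrative, not the real MegBrain types):

// Toy model of the caching pattern introduced above: profile once per
// (transposeA, transposeB) combination, then reuse the cached choice at
// execution time. All names here are illustrative.
#include <array>
#include <cstdio>

struct DemoPolicy { int algo = -1; };

struct DemoMatMul {
    std::array<DemoPolicy, 4> cached;  // one slot per transpose combination

    static int mask(bool ta, bool tb) { return int(ta) + int(tb) * 2; }

    // Analogue of get_workspace_size_bytes(): visit every combination once.
    void setup() {
        for (int m = 0; m < 4; ++m)
            cached[m] = DemoPolicy{m * 10};  // pretend this came from profiling
    }

    // Analogue of scn_do_execute(): install the cached policy for this call.
    void exec(bool ta, bool tb) const {
        std::printf("using algo %d\n", cached[mask(ta, tb)].algo);
    }
};

int main() {
    DemoMatMul mm;
    mm.setup();
    mm.exec(true, false);  // prints "using algo 10"
    return 0;
}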
......
@@ -14,12 +14,14 @@ decl_opr('BatchedMatrixMul',
          'performed and output shape is (n, a, c)')
 
 decl_opr('MatrixMul',
+         pyname='matrix_mul_v2',
          inputs=['opr0', 'opr1'],
          params='MatrixMul',
          desc='matrix multiplication',
          version=2, has_out_dtype=True)
 
 decl_opr('BatchedMatrixMul',
+         pyname='batched_matrix_mul_v2',
          inputs=['opr0', 'opr1'],
          params='MatrixMul',
          desc='batched matrix multiplication: input shapes should be '
@@ -28,6 +30,23 @@ decl_opr('BatchedMatrixMul',
          'performed and output shape is (n, a, c)',
          version=2, has_out_dtype=True)
 
+decl_opr('MatrixMul',
+         inputs=['opr0', 'opr1'],
+         params=[('param', 'MatrixMul'),
+                 ('execution_polity', 'ExecutionPolicy')],
+         desc='matrix multiplication',
+         version=3, has_out_dtype=True)
+
+decl_opr('BatchedMatrixMul',
+         inputs=['opr0', 'opr1'],
+         params=[('param', 'MatrixMul'),
+                 ('execution_polity', 'ExecutionPolicy')],
+         desc='batched matrix multiplication: input shapes should be '
+         '(n, a, b) and (n, b, c) (assuming transposeA and transposeB are '
+         'False); then :math:`n` independent matrix multiplications would be '
+         'performed and output shape is (n, a, c)',
+         version=3, has_out_dtype=True)
+
 decl_opr('Dot',
          inputs=['opr0', 'opr1'],
          params='Empty',
......
@@ -10,7 +10,10 @@
  */
 
 #include "megbrain/opr/blas.h"
+#include "megbrain/opr/param_defs.h"
 #include "megbrain/serialization/sereg.h"
+#include "megdnn/opr_param_defs.h"
+#include "megdnn/oprs/linalg.h"
 
 namespace mgb {
 namespace serialization {
@@ -27,14 +30,70 @@ struct OprMaker<opr::SVD, 1> {
     }
 };
 
+template <class MegDNNConv = megdnn::MatrixMul>
+struct MakeMatrixMulCaller {
+    template <typename Opr>
+    static VarNode* make(const cg::VarNodeArray& inputs,
+                         const typename MegDNNConv::Param& param,
+                         const megdnn::param::ExecutionPolicy& execution_policy,
+                         const OperatorNodeConfig& config) {
+        if (inputs.size() == 2) {
+            return Opr::make(inputs[0], inputs[1], param, execution_policy,
+                             config)
+                    .node();
+        }
+        return nullptr;
+    }
+};
+
+template <class Opr, class Maker, class MegDNNMatrixMul>
+struct MatrixMulLoadDumpImpl {
+    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
+        auto&& opr = opr_.cast_final_safe<Opr>();
+        ctx.write_param<megdnn::param::MatrixMul>(opr.param());
+        ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
+    }
+
+    static VarNode* make(const cg::VarNodeArray& inputs,
+                         const megdnn::param::MatrixMul& param,
+                         const megdnn::param::ExecutionPolicy& execution_policy,
+                         const OperatorNodeConfig& config) {
+        VarNode* ret = Maker::template make<Opr>(inputs, param,
+                                                 execution_policy, config);
+        mgb_assert(ret);
+        return ret;
+    }
+
+    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
+                                      const cg::VarNodeArray& inputs,
+                                      const OperatorNodeConfig& config) {
+        auto param = ctx.read_param<megdnn::param::MatrixMul>();
+        auto execution_policy =
+                ctx.read_param<megdnn::param::ExecutionPolicy>();
+        return make(inputs, param, execution_policy, config)->owner_opr();
+    }
+};
+
+template <>
+struct OprLoadDumpImpl<opr::MatrixMul, 2>
+        : public MatrixMulLoadDumpImpl<opr::MatrixMul,
+                                       MakeMatrixMulCaller<megdnn::MatrixMul>,
+                                       megdnn::MatrixMul> {};
+template <>
+struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>
+        : public MatrixMulLoadDumpImpl<
+                  opr::BatchedMatrixMul,
+                  MakeMatrixMulCaller<megdnn::BatchedMatrixMul>,
+                  megdnn::BatchedMatrixMul> {};
+
 } // namespace serialization
 
 namespace opr {
-using MatrixMulV2 = MatrixMul;
-using BatchedMatrixMulV2 = BatchedMatrixMul;
-MGB_SEREG_OPR(MatrixMulV2, 2);
-MGB_SEREG_OPR(BatchedMatrixMulV2, 2);
+using MatrixMulV3 = MatrixMul;
+using BatchedMatrixMulV3 = BatchedMatrixMul;
+MGB_SEREG_OPR(MatrixMulV3, 2);
+MGB_SEREG_OPR(BatchedMatrixMulV3, 2);
 MGB_SEREG_OPR(Dot, 2);
 MGB_SEREG_OPR(MatrixInverse, 1);
 MGB_SEREG_OPR(SVD, 1);
......
@@ -1636,6 +1636,5 @@ void BatchConvBiasForward::init_output_format() {
 }
 
 #undef IMPL_CONV
-#undef MGB_FOREACH_FASTRUN_OPR
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -98,7 +98,7 @@ namespace serialization {
                 return nullptr;
             }
         };
 
         template<class Opr, class Maker0, class MegDNNConv,
                  class Maker1=MakeConvCallerEmpty<MegDNNConv>,
@@ -292,7 +292,7 @@ namespace serialization {
                 return nullptr;
             }
         };
 
         template<class Opr, class Maker0, class MegDNNConv,
                  class Maker1=MakeLocalShareCallerEmpty<MegDNNConv>,
                  class Maker2=MakeLocalShareCallerEmpty<MegDNNConv>,
......
@@ -251,6 +251,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
     auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
             opr->owner_graph(), opr->comp_node(),
             opr->execution_policy().workspace_limit);
+    m_megdnn_opr->execution_policy() = {};
     return APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
                          args..., workspace_limit, reproducible),
                  m_layouts);
......
@@ -12,6 +12,7 @@
 #pragma once
 
 #include "megbrain/exception.h"
+#include "megbrain/opr/search_policy/algo_chooser_helper.h"
 #include "megbrain/tensor.h"
 #include "megbrain/graph.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
@@ -24,51 +25,58 @@ namespace opr {
 /*!
  * \brief matrix_mul(trans0(opr0), trans1(opr1))
  */
-MGB_DEFINE_OPR_CLASS(MatrixMul,
-        intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>) // {
-
-    public:
-        MatrixMul(VarNode *opr0, VarNode *opr1,
-                const Param &param, const OperatorNodeConfig &config);
-
-        static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
-                const Param &param = {},
-                const OperatorNodeConfig &config = {});
-    private:
-        void add_input_layout_constraint() override;
-        void scn_do_execute() override;
-        void init_output_dtype() override;
-        size_t get_workspace_size_bytes(
-                const TensorShapeArray &input_shapes,
-                const TensorShapeArray &output_shapes) const override;
-
-        static bool check_layout(const TensorLayout &layout, int transpose);
+MGB_DEFINE_OPR_CLASS(MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>,
+                     public mixin::AlgoChooserHelper) // {
+public:
+    using AlgorithmInfo = megdnn::detail::Algorithm::Info;
+    MatrixMul(VarNode* opr0, VarNode* opr1, const Param& param,
+              const ExecutionPolicy& policy, const OperatorNodeConfig& config);
+
+    static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
+                          const Param& param = {},
+                          const ExecutionPolicy& policy = {},
+                          const OperatorNodeConfig& config = {});
+
+private:
+    void add_input_layout_constraint() override;
+    void scn_do_execute() override;
+    void init_output_dtype() override;
+    size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
+                                    const TensorShapeArray& output_shapes)
+            const override;
+    static bool check_layout(const TensorLayout& layout, int transpose);
+
+    //! store the policy of all transpose situations
+    megdnn::MatrixMul::ExecutionPolicy m_cadidate_execution_policies[4];
 };
 
 /*!
  * \brief batched matrix multiplication on 3D inputs
  */
 MGB_DEFINE_OPR_CLASS(BatchedMatrixMul,
-        intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>) // {
-    public:
-        BatchedMatrixMul(VarNode *opr0, VarNode *opr1,
-                const Param &param, const OperatorNodeConfig &config);
-
-        static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
-                const Param &param = {},
-                const OperatorNodeConfig &config = {});
-    private:
-        void add_input_layout_constraint() override;
-        void init_output_dtype() override;
-        void scn_do_execute() override;
-        size_t get_workspace_size_bytes(
-                const TensorShapeArray &input_shapes,
-                const TensorShapeArray &output_shapes) const override;
-
-        static bool check_layout(const TensorLayout &layout, bool transpose);
+                     intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>,
+                     public mixin::AlgoChooserHelper) // {
+public:
+    using AlgorithmInfo = megdnn::detail::Algorithm::Info;
+    BatchedMatrixMul(VarNode* opr0, VarNode* opr1, const Param& param,
+                     const ExecutionPolicy& policy,
+                     const OperatorNodeConfig& config);
+
+    static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
+                          const Param& param = {},
+                          const ExecutionPolicy& policy = {},
+                          const OperatorNodeConfig& config = {});
+
+private:
+    void add_input_layout_constraint() override;
+    void init_output_dtype() override;
+    void scn_do_execute() override;
+    size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
+                                    const TensorShapeArray& output_shapes)
+            const override;
+    static bool check_layout(const TensorLayout& layout, bool transpose);
+
+    //! store the policy of all transpose situations
+    megdnn::BatchedMatrixMul::ExecutionPolicy m_cadidate_execution_policies[4];
 };
 
 /*!
@@ -109,4 +117,3 @@ MGB_DEFINE_OPR_CLASS(SVD, intl::MegDNNOprWrapperFwd<megdnn::SVD>) // {
 } // mgb
 
 // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -14,6 +14,7 @@
 #include "megbrain/opr/search_policy/profiler.h"
 #include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/blas.h"
 
 template <class MegDNNOpr>
 struct MegDNNOpr2MGBOpr;
......
@@ -18,26 +18,31 @@
 #include "megbrain/comp_node.h"
 #include "megdnn/basic_types.h"
+#include "megdnn/oprs/linalg.h"
 #include "megdnn/oprs/nn.h"
 
 namespace mgb {
 namespace opr {
 
-#define MGB_FOREACH_FASTRUN_OPR(cb)  \
-    cb(ConvolutionForward); \
-    cb(ConvBiasForward); \
-    cb(ConvolutionBackwardData); \
-    cb(ConvolutionBackwardFilter); \
-    cb(Convolution3DForward); \
-    cb(Convolution3DBackwardData); \
-    cb(Convolution3DBackwardFilter); \
-    cb(LocalShareForward); \
-    cb(LocalShareBackwardData); \
-    cb(LocalShareBackwardFilter); \
-    cb(DeformableConvForward); \
-    cb(DeformableConvBackwardFilter); \
-    cb(DeformableConvBackwardData); \
-    cb(BatchConvBiasForward);
+// clang-format off
+#define MGB_FOREACH_FASTRUN_OPR(cb)  \
+    cb(ConvolutionForward)           \
+    cb(ConvBiasForward)              \
+    cb(ConvolutionBackwardData)      \
+    cb(ConvolutionBackwardFilter)    \
+    cb(Convolution3DForward)         \
+    cb(Convolution3DBackwardData)    \
+    cb(Convolution3DBackwardFilter)  \
+    cb(LocalShareForward)            \
+    cb(LocalShareBackwardData)       \
+    cb(LocalShareBackwardFilter)     \
+    cb(DeformableConvForward)        \
+    cb(DeformableConvBackwardFilter) \
+    cb(DeformableConvBackwardData)   \
+    cb(BatchConvBiasForward)         \
+    cb(MatrixMul)                    \
+    cb(BatchedMatrixMul)
+// clang-format on
 
 template <typename Opr>
 struct OprArityTrait;
@@ -67,6 +72,8 @@ INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1);
 INST_ARITY(megdnn::BatchConvBiasForward, 4, 1);
 INST_ARITY(megdnn::ConvBias, 4, 1);
 INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3);
+INST_ARITY(megdnn::MatrixMul, 2, 1);
+INST_ARITY(megdnn::BatchedMatrixMul, 2, 1);
 
 #undef INST_ARITY
......
@@ -269,7 +269,7 @@ void run_trans_inp_test_case(bool trans_a, bool trans_b) {
     if (DTypeTrait<dt_dst>::enumv == DTypeEnum::Int16) {
         config.output_dtype(dtype::Int16());
     }
-    auto z = opr::MatrixMul::make(x, y, {}, config);
+    auto z = opr::MatrixMul::make(x, y, {}, {}, config);
     HostTensorND host_z;
     auto func = graph->compile({make_callback_copy(z, host_z)});
@@ -359,7 +359,7 @@ void run_bgemm_trans_inp_test_case(bool trans_a, bool trans_b) {
     trans_a ? (x = opr::Dimshuffle::make(x, {0, 2, 1})) : 0;
     trans_b ? (y = opr::Dimshuffle::make(y, {0, 2, 1})) : 0;
-    auto z = opr::BatchedMatrixMul::make(x, y, {}, OperatorNodeConfig{});
+    auto z = opr::BatchedMatrixMul::make(x, y, {}, {}, OperatorNodeConfig{});
     HostTensorND host_z;
     auto func = graph->compile({make_callback_copy(z, host_z)});
     auto run = [&](size_t B, size_t M, size_t K, size_t N) {
@@ -420,6 +420,43 @@ TEST(TestOprBlas, MatrixMul_TT) {
     run_sgemm_test(true, true);
 }
 
+TEST(TestOprDNN, MatrixMulExePolicy) {
+    using Param = opr::MatrixMul::Param;
+    Param param;
+    using Policy = opr::MatrixMul::ExecutionPolicy;
+    using S = Policy::Strategy;
+
+    auto cn = CompNode::load("cpux");
+
+#if MGB_ENABLE_FASTRUN
+    for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE,
+                          S::PROFILE_HEURISTIC}) {
+#else
+    for (auto strategy : {S::HEURISTIC, S::PROFILE_HEURISTIC}) {
+#endif
+
+        auto graph = ComputingGraph::make();
+        HostTensorGenerator<> gen;
+
+        auto mkvar = [&](const char* name, const TensorShape& shp) {
+            return opr::Host2DeviceCopy::make(*graph, gen(shp), cn)
+                    .rename(name);
+        };
+
+        auto A = mkvar("A", {32, 64});
+        auto B = mkvar("B", {64, 32});
+
+        Policy policy;
+        policy.strategy = strategy;
+
+        auto C = opr::MatrixMul::make(A, B, param, policy);
+        HostTensorND host_c;
+        auto func = graph->compile({make_callback_copy(C, host_c)});
+        func->execute();
+    }
+}
+
 TEST(TestOprBlas, BatchedMatrixMulFp32_NN) {
     run_batched_sgemm_test(false, false);
 }
......