提交 5063a206 编写于 作者: M Megvii Engine Team

feat(mgb): add fastrun support for matmul/batchedmatrixmul

GitOrigin-RevId: a48ea9bff6d852ffecbbe96982e83a0738c0d366
上级 409a8772
......@@ -16,6 +16,7 @@
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/opr/search_policy/profiler.h"
#include "megbrain/utils/shared_set.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/opr/basic_arith.h"
......@@ -149,15 +150,6 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,
} // anonymous namespace
#define MGB_FOREACH_FASTRUN_OPR(cb) \
cb(ConvolutionForward), cb(ConvBiasForward), cb(ConvolutionBackwardData), \
cb(ConvolutionBackwardFilter), cb(Convolution3DForward), \
cb(Convolution3DBackwardData), cb(Convolution3DBackwardFilter), \
cb(LocalShareForward), cb(LocalShareBackwardData), \
cb(LocalShareBackwardFilter), cb(DeformableConvForward), \
cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \
cb(BatchConvBiasForward),
void gopt::modify_opr_algo_strategy_inplace(
const VarNodeArrayView& dest_vars,
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
......@@ -171,7 +163,7 @@ void gopt::modify_opr_algo_strategy_inplace(
modifiers = {
#define CONV(t) \
{opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \
std::placeholders::_1, strategy)}
std::placeholders::_1, strategy)},
MGB_FOREACH_FASTRUN_OPR(CONV)
#undef CONV
};
......@@ -209,7 +201,7 @@ void gopt::set_opr_algo_workspace_limit_inplace(
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)>
modifiers = {
#define CONV(t) \
{opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>}
{opr::t::typeinfo(), &inplace_conv_opr_workspace_limit_modifier<opr::t>},
MGB_FOREACH_FASTRUN_OPR(CONV)
#undef CONV
};
......@@ -226,7 +218,6 @@ void gopt::set_opr_algo_workspace_limit_inplace(
dep_iter.add(i);
}
}
#undef MGB_FOREACH_FASTRUN_OPR
/* ================ ParamRedistributePass ================ */
const char* ParamRedistributePass::name() const {
......@@ -790,8 +781,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
new_inp[1]->name().c_str(),
new_inp[1]->owner_opr()->name().c_str());
auto new_deconv_opr = opr::ConvolutionBackwardData::make(
new_inp[0], new_inp[1], new_param, deconv_opr.execution_policy(),
deconv_opr.config());
new_inp[0], new_inp[1], new_param,
deconv_opr.execution_policy(), deconv_opr.config());
return new_deconv_opr.node()->owner_opr();
};
......@@ -813,20 +804,20 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
new_inp[1]->owner_opr()->name().c_str());
if(opr->input().size() == 2) {
auto new_conv_opr = opr::ConvBias::make(
new_inp[0], new_inp[1], new_param, convbias_opr.execution_policy(),
convbias_opr.config());
new_inp[0], new_inp[1], new_param,
convbias_opr.execution_policy(), convbias_opr.config());
return new_conv_opr.node()->owner_opr();
} else if(opr->input().size() == 3) {
auto new_conv_opr = opr::ConvBias::make(
new_inp[0], new_inp[1], new_inp[2], new_param, convbias_opr.execution_policy(),
convbias_opr.config());
new_inp[0], new_inp[1], new_inp[2], new_param,
convbias_opr.execution_policy(), convbias_opr.config());
return new_conv_opr.node()->owner_opr();
} else {
mgb_assert(opr->input().size() == 4, "invalid input size %zu",
opr->input().size());
auto new_conv_opr = opr::ConvBias::make(
new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param, convbias_opr.execution_policy(),
convbias_opr.config());
new_inp[0], new_inp[1], new_inp[2], new_inp[3], new_param,
convbias_opr.execution_policy(), convbias_opr.config());
return new_conv_opr.node()->owner_opr();
}
};
......@@ -841,7 +832,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
megdnn::param::MatrixMul::ComputeMode::FLOAT32;
}
auto new_matmul_opr = opr::MatrixMul::make(
new_inp[0], new_inp[1], new_param, matmul_opr.config());
new_inp[0], new_inp[1], new_param,
matmul_opr.execution_policy(), matmul_opr.config());
return new_matmul_opr.node()->owner_opr();
};
......@@ -864,7 +856,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
new_inp[1]->name().c_str(),
new_inp[1]->owner_opr()->name().c_str());
auto new_matmul_opr = opr::BatchedMatrixMul::make(
new_inp[0], new_inp[1], new_param, matmul_opr.config());
new_inp[0], new_inp[1], new_param,
matmul_opr.execution_policy(), matmul_opr.config());
return new_matmul_opr.node()->owner_opr();
};
......@@ -915,8 +908,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
new_mat->owner_opr()->input(0)->dtype() == dtype::Float32())
new_mat = new_mat->owner_opr()->input(0);
else
new_mat =
opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
new_mat = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
.node();
}
SymbolVar new_warp;
if (new_inp.size() == 3) {
......@@ -944,8 +937,8 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
new_map->owner_opr()->input(0)->dtype() == dtype::Float32())
new_map = new_map->owner_opr()->input(0);
else
new_map =
opr::TypeCvt::make(new_inp[1], dtype::Float32(), {}).node();
new_map = opr::TypeCvt::make(new_inp[1], dtype::Float32(), {})
.node();
}
SymbolVar new_remap;
......
......@@ -18,15 +18,41 @@
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/search_policy/algo_chooser.h"
#include "megbrain/opr/search_policy/profiler.h"
#include "./internal/megdnn_opr_wrapper.inl"
#include "./search_policy/workspace_need_limit_getter.inl"
#include "megdnn/oprs/linalg.h"
using namespace mgb;
using namespace opr;
namespace {
int get_mask_from_matmul(const megdnn::param::MatrixMul& param) {
return static_cast<int>(param.transposeA) +
(static_cast<int>(param.transposeB) * 2);
}
}
/* ================= MatrixMul ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixMul);
MEGDNN_OPR_INIT2(MatrixMul, "matrix_mul")
MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param,
const ExecutionPolicy& policy,
const OperatorNodeConfig& config)
: Super{a->owner_graph(), config, "matrix_mul", {a, b}} {
init_megdnn_opr(*this, param);
m_policy = policy;
add_input({a, b});
}
SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
const ExecutionPolicy& policy,
const OperatorNodeConfig& config) {
return a.insert_single_output_opr<MatrixMul>(a.node(), b.node(), param,
policy, config);
}
void MatrixMul::init_output_dtype() {
DType output_dtype = config().output_dtype();
......@@ -72,13 +98,32 @@ size_t MatrixMul::get_workspace_size_bytes(
param ^= 1;
};
MGB_TRY {
a = mo->get_workspace_in_bytes(i0, i1, out);
a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
//! Here we just want to save the execution policy got from setup_algo,
//! while change the delaration of get_workspace_in_bytes may cause
//! many changes.
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
transpose(i0, tparam.transposeA);
b = mo->get_workspace_in_bytes(i0, i1, out);
b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
transpose(i1, tparam.transposeB);
c = mo->get_workspace_in_bytes(i0, i1, out);
c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
transpose(i0, tparam.transposeA);
d = mo->get_workspace_in_bytes(i0, i1, out);
d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
}
MGB_FINALLY({ tparam = this->param(); });
return std::max(std::max(a, b), std::max(c, d));
......@@ -100,6 +145,8 @@ void MatrixMul::scn_do_execute() {
MGB_TRY {
transpose(inp0.layout, tparam.transposeA);
transpose(inp1.layout, tparam.transposeB);
megdnn_opr()->execution_policy() =
m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
megdnn_opr()->exec(inp0, inp1, out,
intl::get_megdnn_workspace_from_var(output(1)));
}
......@@ -134,7 +181,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
/* ================= BatchedMatrixMul ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchedMatrixMul);
MEGDNN_OPR_INIT2(BatchedMatrixMul, "batched_matrix_mul")
BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param,
const ExecutionPolicy& policy,
const OperatorNodeConfig& config)
: Super{a->owner_graph(), config, "batched_matrix_mul", {a, b}} {
init_megdnn_opr(*this, param);
m_policy = policy;
add_input({a, b});
}
SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
const ExecutionPolicy& policy,
const OperatorNodeConfig& config) {
return a.insert_single_output_opr<BatchedMatrixMul>(a.node(), b.node(),
param, policy, config);
}
void BatchedMatrixMul::add_input_layout_constraint() {
auto check = [](const TensorLayout& ly) {
......@@ -191,13 +252,29 @@ size_t BatchedMatrixMul::get_workspace_size_bytes(
param ^= 1;
};
MGB_TRY {
a = mo->get_workspace_in_bytes(i0, i1, out);
a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
transpose(i0, tparam.transposeA);
b = mo->get_workspace_in_bytes(i0, i1, out);
b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
transpose(i1, tparam.transposeB);
c = mo->get_workspace_in_bytes(i0, i1, out);
c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
transpose(i0, tparam.transposeA);
d = mo->get_workspace_in_bytes(i0, i1, out);
d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
}
MGB_FINALLY({ tparam = this->param(); });
return std::max(std::max(a, b), std::max(c, d));
......@@ -220,6 +297,8 @@ void BatchedMatrixMul::scn_do_execute() {
MGB_TRY {
transpose(inp0.layout, tparam.transposeA);
transpose(inp1.layout, tparam.transposeB);
megdnn_opr()->execution_policy() =
m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
megdnn_opr()->exec(inp0, inp1, out,
intl::get_megdnn_workspace_from_var(output(1)));
}
......
......@@ -14,12 +14,14 @@ decl_opr('BatchedMatrixMul',
'performed and output shape is (n, a, c)')
decl_opr('MatrixMul',
pyname='matrix_mul_v2',
inputs=['opr0', 'opr1'],
params='MatrixMul',
desc='matrix multiplication',
version=2, has_out_dtype=True)
decl_opr('BatchedMatrixMul',
pyname='batched_matrix_mul_v2',
inputs=['opr0', 'opr1'],
params='MatrixMul',
desc='batched matrix multiplication: input shapes should be '
......@@ -28,6 +30,23 @@ decl_opr('BatchedMatrixMul',
'performed and output shape is (n, a, c)',
version=2, has_out_dtype=True)
decl_opr('MatrixMul',
inputs=['opr0', 'opr1'],
params=[('param', 'MatrixMul'),
('execution_polity', 'ExecutionPolicy')],
desc='matrix multiplication',
version=3, has_out_dtype=True)
decl_opr('BatchedMatrixMul',
inputs=['opr0', 'opr1'],
params=[('param', 'MatrixMul'),
('execution_polity', 'ExecutionPolicy')],
desc='batched matrix multiplication: input shapes should be '
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are '
'False); then :math:`n` independent matrix multiplications would be '
'performed and output shape is (n, a, c)',
version=3, has_out_dtype=True)
decl_opr('Dot',
inputs=['opr0', 'opr1'],
params='Empty',
......
......@@ -10,7 +10,10 @@
*/
#include "megbrain/opr/blas.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/linalg.h"
namespace mgb {
namespace serialization {
......@@ -27,14 +30,70 @@ struct OprMaker<opr::SVD, 1> {
}
};
template <class MegDNNConv = megdnn::MatrixMul>
struct MakeMatrixMulCaller {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNConv::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 2) {
return Opr::make(inputs[0], inputs[1], param, execution_policy,
config)
.node();
}
return nullptr;
}
};
template <class Opr, class Maker, class MegDNNMatrixMul>
struct MatrixMulLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<megdnn::param::MatrixMul>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
}
static VarNode* make(const cg::VarNodeArray& inputs,
const megdnn::param::MatrixMul& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
VarNode* ret = Maker::template make<Opr>(inputs, param,
execution_policy, config);
mgb_assert(ret);
return ret;
}
static cg::OperatorNodeBase* load(OprLoadContext& ctx,
const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<megdnn::param::MatrixMul>();
auto execution_policy =
ctx.read_param<megdnn::param::ExecutionPolicy>();
return make(inputs, param, execution_policy, config)->owner_opr();
}
};
template <>
struct OprLoadDumpImpl<opr::MatrixMul, 2>
: public MatrixMulLoadDumpImpl<opr::MatrixMul,
MakeMatrixMulCaller<megdnn::MatrixMul>,
megdnn::MatrixMul> {};
template <>
struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>
: public MatrixMulLoadDumpImpl<
opr::BatchedMatrixMul,
MakeMatrixMulCaller<megdnn::BatchedMatrixMul>,
megdnn::BatchedMatrixMul> {};
} // namespace serialization
namespace opr {
using MatrixMulV2 = MatrixMul;
using BatchedMatrixMulV2 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV2, 2);
MGB_SEREG_OPR(BatchedMatrixMulV2, 2);
using MatrixMulV3 = MatrixMul;
using BatchedMatrixMulV3 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV3, 2);
MGB_SEREG_OPR(BatchedMatrixMulV3, 2);
MGB_SEREG_OPR(Dot, 2);
MGB_SEREG_OPR(MatrixInverse, 1);
MGB_SEREG_OPR(SVD, 1);
......
......@@ -1636,6 +1636,5 @@ void BatchConvBiasForward::init_output_format() {
}
#undef IMPL_CONV
#undef MGB_FOREACH_FASTRUN_OPR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -98,7 +98,7 @@ namespace serialization {
return nullptr;
}
};
template<class Opr, class Maker0, class MegDNNConv,
class Maker1=MakeConvCallerEmpty<MegDNNConv>,
......@@ -292,7 +292,7 @@ namespace serialization {
return nullptr;
}
};
template<class Opr, class Maker0, class MegDNNConv,
class Maker1=MakeLocalShareCallerEmpty<MegDNNConv>,
class Maker2=MakeLocalShareCallerEmpty<MegDNNConv>,
......
......@@ -251,6 +251,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
opr->owner_graph(), opr->comp_node(),
opr->execution_policy().workspace_limit);
m_megdnn_opr->execution_policy() = {};
return APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, reproducible),
m_layouts);
......
......@@ -12,6 +12,7 @@
#pragma once
#include "megbrain/exception.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/tensor.h"
#include "megbrain/graph.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
......@@ -24,51 +25,58 @@ namespace opr {
/*!
* \brief matrix_mul(trans0(opr0), trans1(opr1))
*/
MGB_DEFINE_OPR_CLASS(MatrixMul,
intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>) // {
public:
MatrixMul(VarNode *opr0, VarNode *opr1,
const Param &param, const OperatorNodeConfig &config);
static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
const Param &param = {},
const OperatorNodeConfig &config = {});
private:
void add_input_layout_constraint() override;
void scn_do_execute() override;
void init_output_dtype() override;
size_t get_workspace_size_bytes(
const TensorShapeArray &input_shapes,
const TensorShapeArray &output_shapes) const override;
static bool check_layout(const TensorLayout &layout, int transpose);
MGB_DEFINE_OPR_CLASS(MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>,
public mixin::AlgoChooserHelper) // {
public:
using AlgorithmInfo = megdnn::detail::Algorithm::Info;
MatrixMul(VarNode* opr0, VarNode* opr1, const Param& param,
const ExecutionPolicy& policy, const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
const Param& param = {},
const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});
private:
void add_input_layout_constraint() override;
void scn_do_execute() override;
void init_output_dtype() override;
size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes)
const override;
static bool check_layout(const TensorLayout& layout, int transpose);
//! store the policy of all transpose situations
megdnn::MatrixMul::ExecutionPolicy m_cadidate_execution_policies[4];
};
/*!
* \brief batched matrix multiplication on 3D inputs
*/
MGB_DEFINE_OPR_CLASS(BatchedMatrixMul,
intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>) // {
public:
BatchedMatrixMul(VarNode *opr0, VarNode *opr1,
const Param &param, const OperatorNodeConfig &config);
static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
const Param &param = {},
const OperatorNodeConfig &config = {});
private:
void add_input_layout_constraint() override;
void init_output_dtype() override;
void scn_do_execute() override;
size_t get_workspace_size_bytes(
const TensorShapeArray &input_shapes,
const TensorShapeArray &output_shapes) const override;
static bool check_layout(const TensorLayout &layout, bool transpose);
intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>,
public mixin::AlgoChooserHelper) // {
public:
using AlgorithmInfo = megdnn::detail::Algorithm::Info;
BatchedMatrixMul(VarNode* opr0, VarNode* opr1, const Param& param,
const ExecutionPolicy& policy,
const OperatorNodeConfig& config);
static SymbolVar make(SymbolVar opr0, SymbolVar opr1,
const Param& param = {},
const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});
private:
void add_input_layout_constraint() override;
void init_output_dtype() override;
void scn_do_execute() override;
size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes)
const override;
static bool check_layout(const TensorLayout& layout, bool transpose);
//! store the policy of all transpose situations
megdnn::BatchedMatrixMul::ExecutionPolicy m_cadidate_execution_policies[4];
};
/*!
......@@ -109,4 +117,3 @@ MGB_DEFINE_OPR_CLASS(SVD, intl::MegDNNOprWrapperFwd<megdnn::SVD>) // {
} // mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -14,6 +14,7 @@
#include "megbrain/opr/search_policy/profiler.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/blas.h"
template <class MegDNNOpr>
struct MegDNNOpr2MGBOpr;
......
......@@ -18,26 +18,31 @@
#include "megbrain/comp_node.h"
#include "megdnn/basic_types.h"
#include "megdnn/oprs/linalg.h"
#include "megdnn/oprs/nn.h"
namespace mgb {
namespace opr {
#define MGB_FOREACH_FASTRUN_OPR(cb) \
cb(ConvolutionForward); \
cb(ConvBiasForward); \
cb(ConvolutionBackwardData); \
cb(ConvolutionBackwardFilter); \
cb(Convolution3DForward); \
cb(Convolution3DBackwardData); \
cb(Convolution3DBackwardFilter); \
cb(LocalShareForward); \
cb(LocalShareBackwardData); \
cb(LocalShareBackwardFilter); \
cb(DeformableConvForward); \
cb(DeformableConvBackwardFilter); \
cb(DeformableConvBackwardData); \
cb(BatchConvBiasForward);
// clang-format off
#define MGB_FOREACH_FASTRUN_OPR(cb) \
cb(ConvolutionForward) \
cb(ConvBiasForward) \
cb(ConvolutionBackwardData) \
cb(ConvolutionBackwardFilter) \
cb(Convolution3DForward) \
cb(Convolution3DBackwardData) \
cb(Convolution3DBackwardFilter) \
cb(LocalShareForward) \
cb(LocalShareBackwardData) \
cb(LocalShareBackwardFilter) \
cb(DeformableConvForward) \
cb(DeformableConvBackwardFilter) \
cb(DeformableConvBackwardData) \
cb(BatchConvBiasForward) \
cb(MatrixMul) \
cb(BatchedMatrixMul)
// clang-format on
template <typename Opr>
struct OprArityTrait;
......@@ -67,6 +72,8 @@ INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1);
INST_ARITY(megdnn::BatchConvBiasForward, 4, 1);
INST_ARITY(megdnn::ConvBias, 4, 1);
INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3);
INST_ARITY(megdnn::MatrixMul, 2, 1);
INST_ARITY(megdnn::BatchedMatrixMul, 2, 1);
#undef INST_ARITY
......
......@@ -269,7 +269,7 @@ void run_trans_inp_test_case(bool trans_a, bool trans_b) {
if (DTypeTrait<dt_dst>::enumv == DTypeEnum::Int16) {
config.output_dtype(dtype::Int16());
}
auto z = opr::MatrixMul::make(x, y, {}, config);
auto z = opr::MatrixMul::make(x, y, {}, {}, config);
HostTensorND host_z;
auto func = graph->compile({make_callback_copy(z, host_z)});
......@@ -359,7 +359,7 @@ void run_bgemm_trans_inp_test_case(bool trans_a, bool trans_b) {
trans_a ? (x = opr::Dimshuffle::make(x, {0, 2, 1})) : 0;
trans_b ? (y = opr::Dimshuffle::make(y, {0, 2, 1})) : 0;
auto z = opr::BatchedMatrixMul::make(x, y, {}, OperatorNodeConfig{});
auto z = opr::BatchedMatrixMul::make(x, y, {}, {}, OperatorNodeConfig{});
HostTensorND host_z;
auto func = graph->compile({make_callback_copy(z, host_z)});
auto run = [&](size_t B, size_t M, size_t K, size_t N) {
......@@ -420,6 +420,43 @@ TEST(TestOprBlas, MatrixMul_TT) {
run_sgemm_test(true, true);
}
TEST(TestOprDNN, MatrixMulExePolicy) {
using Param = opr::MatrixMul::Param;
Param param;
using Policy = opr::MatrixMul::ExecutionPolicy;
using S = Policy::Strategy;
auto cn = CompNode::load("cpux");
#if MGB_ENABLE_FASTRUN
for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE,
S::PROFILE_HEURISTIC}) {
#else
for (auto strategy: {S:HEURISTIC, S::PROFILE_HEURISTIC}) {
#endif
auto graph = ComputingGraph::make();
HostTensorGenerator<> gen;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp), cn)
.rename(name);
};
auto A = mkvar("A", {32, 64});
auto B = mkvar("B", {64, 32});
Policy policy;
policy.strategy = strategy;
auto C = opr::MatrixMul::make(A, B, param, policy);
HostTensorND host_c;
auto func = graph->compile({make_callback_copy(C, host_c)});
func->execute();
}
}
TEST(TestOprBlas, BatchedMatrixMulFp32_NN) {
run_batched_sgemm_test(false, false);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册