提交 935fe781 编写于 作者: M Megvii Engine Team

chore(megbrain): format sereg.h code

GitOrigin-RevId: 0fe1cf6a8b52e9f9e4f53de65434c965d5d272d7
上级 55042195
......@@ -28,579 +28,561 @@
namespace mgb {
namespace serialization {
//! Builds a single-input pooling-style operator from deserialized inputs.
//! Returns nullptr when the arity does not match so a chained maker can
//! try next.
template <class MegDNNPooling = megdnn::Pooling>
struct MakePoolingCaller1 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNPooling::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() != 1)
            return nullptr;
        return Opr::make(inputs[0], param, config).node();
    }
};
//! Builds a two-input ROIAlign-style operator (data, rois).
//! Yields nullptr for any other arity so the loader can fall through.
template <class MegDNNROIALIGN = megdnn::ROIAlign>
struct MakeROIAlignCaller1 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNROIALIGN::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() != 2)
            return nullptr;
        return Opr::make(inputs[0], inputs[1], param, config).node();
    }
};
//! Builds a four-input ROIAlign backward operator.
//! Yields nullptr for any other arity so the loader can fall through.
template <class MegDNNROIALIGN = megdnn::ROIAlignBackward>
struct MakeROIAlignCaller4 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNROIALIGN::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() != 4)
            return nullptr;
        return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
                         config)
                .node();
    }
};
//! Builds a three-input pooling backward operator (src, dst, diff).
//! Returns nullptr when the arity does not match.
template <class MegDNNPooling = megdnn::PoolingBackward>
struct MakePoolingBackwardCaller3 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNPooling::Param& param,
                         const OperatorNodeConfig& config) {
        if (inputs.size() != 3)
            return nullptr;
        return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
                .node();
    }
};
template <class MegDNNPooling = megdnn::AdaptivePoolingBackward>
struct MakeAdaptivePoolingBackwardCaller3 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNPooling::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 4) {
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
param, config)
.node();
}
return nullptr;
template <class MegDNNPooling = megdnn::Pooling>
struct MakePoolingCaller1 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNPooling::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 1) {
return Opr::make(inputs[0], param, config).node();
}
};
//! Builds a two-input convolution-style operator (src, filter) with an
//! execution policy. Returns nullptr when the arity does not match.
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller2 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() != 2)
            return nullptr;
        return Opr::make(inputs[0], inputs[1], param, execution_policy, config)
                .node();
    }
};
template <class MegDNNROIALIGN = megdnn::ROIAlign>
struct MakeROIAlignCaller1 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNROIALIGN::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 2) {
return Opr::make(inputs[0], inputs[1], param, config).node();
} else {
return nullptr;
}
};
//! Builds a three-input convolution-style operator with an execution
//! policy. Returns nullptr when the arity does not match so a chained
//! maker can try next.
template<class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller3 {
    template<typename Opr>
    static VarNode* make(const cg::VarNodeArray &inputs,
            const typename MegDNNConv::Param &param,
            const megdnn::param::ExecutionPolicy &execution_policy,
            const OperatorNodeConfig &config) {
        if (inputs.size() == 3) {
            return Opr::make(
                    inputs[0], inputs[1], inputs[2], param,
                    execution_policy, config).node();
        }
        // Fix: the original flowed off the end of this value-returning
        // function when inputs.size() != 3, which is undefined behavior.
        // Report "not handled" via nullptr like the sibling Make*Caller
        // helpers so ConvLoadDumpImpl can try the next maker.
        return nullptr;
    }
};
template <class MegDNNROIALIGN = megdnn::ROIAlignBackward>
struct MakeROIAlignCaller4 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNROIALIGN::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 4) {
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
config)
.node();
} else {
return nullptr;
}
};
//! Builds a four-input convolution-style operator with an execution
//! policy. Returns nullptr when the arity does not match.
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller4 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() != 4)
            return nullptr;
        return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
                         execution_policy, config)
                .node();
    }
};
//! Builds a five-input convolution-style operator with an execution
//! policy. Returns nullptr when the arity does not match.
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller5 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() != 5)
            return nullptr;
        return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
                         param, execution_policy, config)
                .node();
    }
};
//! Sentinel maker: matches no input arity and always returns nullptr.
//! Used to fill the optional Maker1/Maker2 slots of ConvLoadDumpImpl.
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCallerEmpty {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& /*inputs*/,
                         const typename MegDNNConv::Param& /*param*/,
                         const megdnn::param::ExecutionPolicy& /*policy*/,
                         const OperatorNodeConfig& /*config*/) {
        return nullptr;
    }
};
//! Generic dump/make logic for convolution-like operators.
//! Maker0..Maker2 are tried in order until one accepts the input arity;
//! unused slots default to the never-matching MakeConvCallerEmpty.
//! NOTE(review): unlike the reformatted copy later in this diff, this
//! copy has no `load` method — confirm against the canonical header.
template<class Opr, class Maker0, class MegDNNConv,
class Maker1=MakeConvCallerEmpty<MegDNNConv>,
class Maker2=MakeConvCallerEmpty<MegDNNConv>,
typename ConvParam = megdnn::param::Convolution >
struct ConvLoadDumpImpl {
// Serializes the operator's param followed by its execution policy.
static void dump(OprDumpContext &ctx,
const cg::OperatorNodeBase &opr_) {
auto &&opr = opr_.cast_final_safe<Opr>();
ctx.write_param<ConvParam>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(
opr.execution_policy());
}
// Rebuilds the operator: each maker returns nullptr when the input
// arity does not match, so the chain falls through to the next one.
static VarNode* make(
const cg::VarNodeArray& inputs, const ConvParam& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
VarNode* ret = Maker0::template make<Opr>(inputs, param,
execution_policy, config);
if (!ret) {
ret = Maker1::template make<Opr>(inputs, param,
execution_policy, config);
}
if (!ret) {
ret = Maker2::template make<Opr>(inputs, param,
execution_policy, config);
}
// No maker accepted the input count: the serialized graph is invalid.
mgb_assert(ret);
return ret;
}
};
template <class MegDNNPooling = megdnn::PoolingBackward>
struct MakePoolingBackwardCaller3 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNPooling::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node();
}
static cg::OperatorNodeBase* load(
OprLoadContext &ctx, const cg::VarNodeArray &inputs,
const OperatorNodeConfig &config) {
auto param = ctx.read_param<ConvParam>();
auto execution_policy =
ctx.read_param<megdnn::param::ExecutionPolicy>();
return make(inputs, param, execution_policy, config)->owner_opr();
return nullptr;
}
};
template <class MegDNNPooling = megdnn::AdaptivePoolingBackward>
struct MakeAdaptivePoolingBackwardCaller3 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNPooling::Param& param,
const OperatorNodeConfig& config) {
if (inputs.size() == 4) {
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
config)
.node();
}
};
template <class Opr, class Maker0,
typename PoolingParam = megdnn::param::Pooling>
struct PoolingLoadDumpImpl {
static void dump(OprDumpContext& ctx,
const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<PoolingParam>(opr.param());
return nullptr;
}
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller2 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNConv::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 2) {
return Opr::make(inputs[0], inputs[1], param, execution_policy,
config)
.node();
}
static VarNode* make(
const cg::VarNodeArray& inputs, const PoolingParam& param,
const OperatorNodeConfig& config) {
VarNode* ret = Maker0::template make<Opr>(inputs, param,
config);
mgb_assert(ret);
return ret;
return nullptr;
}
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller3 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNConv::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param,
execution_policy, config)
.node();
}
static cg::OperatorNodeBase* load(OprLoadContext& ctx,
const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<PoolingParam>();
return make(inputs, param, config)->owner_opr();
return nullptr;
}
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller4 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNConv::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 4) {
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
execution_policy, config)
.node();
}
};
//! Registers AdaptivePoolingBackward load/dump: AdaptivePooling param,
//! 4-input maker.
template<>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>:
public PoolingLoadDumpImpl<opr::AdaptivePoolingBackward,
MakeAdaptivePoolingBackwardCaller3<megdnn::AdaptivePoolingBackward>,
megdnn::param::AdaptivePooling>
{};
//! Registers AdaptivePooling load/dump: AdaptivePooling param, 2-input
//! maker (reuses the ROIAlign 2-input caller).
template<>
struct OprLoadDumpImpl<opr::AdaptivePooling, 0>:
public PoolingLoadDumpImpl<opr::AdaptivePooling,
MakeROIAlignCaller1<megdnn::AdaptivePooling>,
megdnn::param::AdaptivePooling>
{};
//! Registers ROIAlign load/dump: ROIAlign param, 2-input maker.
template<>
struct OprLoadDumpImpl<opr::ROIAlign, 0>:
public PoolingLoadDumpImpl<opr::ROIAlign,
MakeROIAlignCaller1<megdnn::ROIAlign>,
megdnn::param::ROIAlign>
{};
//! Registers ROIAlignBackward load/dump: ROIAlign param, 4-input maker.
template<>
struct OprLoadDumpImpl<opr::ROIAlignBackward, 0>:
public PoolingLoadDumpImpl<opr::ROIAlignBackward,
MakeROIAlignCaller4<megdnn::ROIAlignBackward>,
megdnn::param::ROIAlign>
{};
//! Registers Pooling load/dump: Pooling param, 1-input maker.
template<>
struct OprLoadDumpImpl<opr::Pooling, 0>:
public PoolingLoadDumpImpl<opr::Pooling,
MakePoolingCaller1<megdnn::Pooling>,
megdnn::param::Pooling>
{};
//! Registers PoolingBackward load/dump: Pooling param, 3-input maker.
template<>
struct OprLoadDumpImpl<opr::PoolingBackward, 0>:
public PoolingLoadDumpImpl<opr::PoolingBackward,
MakePoolingBackwardCaller3<megdnn::PoolingBackward>,
megdnn::param::Pooling>
{};
//! Registers Convolution load/dump: Convolution param + execution
//! policy, 2-input maker (src, filter).
template<>
struct OprLoadDumpImpl<opr::Convolution, 0>:
public ConvLoadDumpImpl<opr::Convolution,
MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution>
{};
//! Registers ConvolutionBackwardData load/dump: accepts either a
//! 2-input or a 3-input serialized form (Maker0 then Maker1).
template<>
struct OprLoadDumpImpl<opr::ConvolutionBackwardData, 0>:
public ConvLoadDumpImpl<opr::ConvolutionBackwardData,
MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution,
MakeConvCaller3<megdnn::Convolution> >
{};
//! Registers ConvolutionBackwardFilter load/dump: 3-input maker only.
template<>
struct OprLoadDumpImpl<opr::ConvolutionBackwardFilter, 0>:
public ConvLoadDumpImpl<opr::ConvolutionBackwardFilter,
MakeConvCaller3<megdnn::Convolution>,
megdnn::Convolution>
{};
//! Registers Convolution3D load/dump: Convolution3D param, 2-input
//! maker; the extra maker slots are explicitly left empty.
template<>
struct OprLoadDumpImpl<opr::Convolution3D, 0>:
public ConvLoadDumpImpl<opr::Convolution3D,
MakeConvCaller2<megdnn::Convolution3D>,
megdnn::Convolution3D,
MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D>
{};
//! Registers Convolution3DBackwardData load/dump: 2-input or 3-input
//! serialized form, Convolution3D param.
template<>
struct OprLoadDumpImpl<opr::Convolution3DBackwardData, 0>:
public ConvLoadDumpImpl<opr::Convolution3DBackwardData,
MakeConvCaller2<megdnn::Convolution3D>,
megdnn::Convolution3D,
MakeConvCaller3<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D>
{};
//! Registers Convolution3DBackwardFilter load/dump: 3-input maker only,
//! Convolution3D param.
template<>
struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0>:
public ConvLoadDumpImpl<opr::Convolution3DBackwardFilter,
MakeConvCaller3<megdnn::Convolution3D>,
megdnn::Convolution3D,
MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D>
{};
//! Registers ConvBiasForward load/dump: accepts 2-, 3- or 4-input
//! serialized forms (src/filter [+ bias [+ z]]), ConvBias param.
template<>
struct OprLoadDumpImpl<opr::ConvBiasForward, 0>:
public ConvLoadDumpImpl<opr::ConvBiasForward,
MakeConvCaller2<megdnn::ConvBiasForward>,
megdnn::ConvBiasForward,
MakeConvCaller3<megdnn::ConvBiasForward>,
MakeConvCaller4<megdnn::ConvBiasForward>,
megdnn::param::ConvBias>
{};
//! Registers BatchConvBiasForward load/dump: accepts 2-, 3- or 4-input
//! serialized forms, BatchConvBias param.
template <>
struct OprLoadDumpImpl<opr::BatchConvBiasForward, 0>
: public ConvLoadDumpImpl<
opr::BatchConvBiasForward,
MakeConvCaller2<megdnn::BatchConvBiasForward>,
megdnn::BatchConvBiasForward,
MakeConvCaller3<megdnn::BatchConvBiasForward>,
MakeConvCaller4<megdnn::BatchConvBiasForward>,
megdnn::param::BatchConvBias> {};
//! Custom loader for BatchNorm, which is serialized with a variadic
//! input count: 3 inputs (data, scale, bias) or 5 inputs (additionally
//! running mean/variance).
template <>
struct OprMaker<opr::BatchNorm, 0> {
    using Param = opr::BatchNorm::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        // The graph is implied by the input vars; parameter kept for the
        // uniform OprMaker signature.
        MGB_MARK_USED_VAR(graph);
        if (i.size() == 3) {
            return opr::BatchNorm::make(i[0], i[1], i[2],
                    param, config)[0].node()->owner_opr();
        } else {
            // Any arity other than 3 must be the full 5-input form.
            mgb_assert(i.size() == 5);
            return opr::BatchNorm::make(i[0], i[1], i[2], i[3], i[4],
                    param, config)[0].node()->owner_opr();
        }
        // Fix: removed an unreachable `return nullptr;` that followed the
        // exhaustive if/else above (dead code left by a merge).
    }
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller5 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNConv::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 5) {
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
inputs[4], param, execution_policy, config)
.node();
}
};
//! Custom loader for BatchNormBackward: always exactly 5 inputs, so the
//! registered arity is fixed at 5.
template <>
struct OprMaker<opr::BatchNormBackward, 5> {
    using Param = opr::BatchNormBackward::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        // Parameter kept only for the uniform OprMaker signature.
        MGB_MARK_USED_VAR(graph);
        // Fix: removed an unreachable `return nullptr;` that followed this
        // return statement (dead code left by a merge).
        return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4],
                param, config)[0].node()->owner_opr();
    }
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCallerEmpty {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray&,
const typename MegDNNConv::Param&,
const megdnn::param::ExecutionPolicy&,
const OperatorNodeConfig&) {
return nullptr;
}
};
template <class Opr, class Maker0, class MegDNNConv,
class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
typename ConvParam = megdnn::param::Convolution>
struct ConvLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<ConvParam>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
}
static VarNode* make(const cg::VarNodeArray& inputs, const ConvParam& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
VarNode* ret = Maker0::template make<Opr>(inputs, param,
execution_policy, config);
if (!ret) {
ret = Maker1::template make<Opr>(inputs, param, execution_policy,
config);
}
};
template <>
struct OprMaker<opr::TQTBackward, 3> {
using Param = opr::TQTBackward::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
return opr::TQTBackward::make(i[0], i[1], i[2], param, config)[0]
.node()
->owner_opr();
if (!ret) {
ret = Maker2::template make<Opr>(inputs, param, execution_policy,
config);
}
};
template<class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller2 {
template<typename Opr>
static VarNode* make(const cg::VarNodeArray &inputs,
const typename MegDNNConv::Param &param,
const megdnn::param::ExecutionPolicy &execution_policy,
const OperatorNodeConfig &config) {
if (inputs.size() == 2) {
return Opr::make(
inputs[0], inputs[1], param,
execution_policy, config).node();
}
return nullptr;
mgb_assert(ret);
return ret;
}
static cg::OperatorNodeBase* load(OprLoadContext& ctx,
const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<ConvParam>();
auto execution_policy =
ctx.read_param<megdnn::param::ExecutionPolicy>();
return make(inputs, param, execution_policy, config)->owner_opr();
}
};
//! Generic dump/make/load logic for pooling-like operators (no
//! execution policy, single maker).
template <class Opr, class Maker0,
typename PoolingParam = megdnn::param::Pooling>
struct PoolingLoadDumpImpl {
// Serializes only the operator's param.
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<PoolingParam>(opr.param());
}
// Rebuilds the operator; asserts if the maker rejects the input arity.
static VarNode* make(const cg::VarNodeArray& inputs,
const PoolingParam& param,
const OperatorNodeConfig& config) {
VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
mgb_assert(ret);
return ret;
}
// Reads the param back from the stream and reconstructs the operator.
static cg::OperatorNodeBase* load(OprLoadContext& ctx,
const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<PoolingParam>();
return make(inputs, param, config)->owner_opr();
}
};
//! Custom loader for TQTBackward: fixed 3-input form; returns the owner
//! operator of the first output var.
template <>
struct OprMaker<opr::TQTBackward, 3> {
using Param = opr::TQTBackward::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& i,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
// Parameter kept only for the uniform OprMaker signature.
MGB_MARK_USED_VAR(graph);
return opr::TQTBackward::make(i[0], i[1], i[2], param, config)[0]
.node()
->owner_opr();
}
};
template <>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
: public PoolingLoadDumpImpl<opr::AdaptivePoolingBackward,
MakeAdaptivePoolingBackwardCaller3<
megdnn::AdaptivePoolingBackward>,
megdnn::param::AdaptivePooling> {};
template <>
struct OprLoadDumpImpl<opr::AdaptivePooling, 0>
: public PoolingLoadDumpImpl<
opr::AdaptivePooling,
MakeROIAlignCaller1<megdnn::AdaptivePooling>,
megdnn::param::AdaptivePooling> {};
template <>
struct OprLoadDumpImpl<opr::ROIAlign, 0>
: public PoolingLoadDumpImpl<opr::ROIAlign,
MakeROIAlignCaller1<megdnn::ROIAlign>,
megdnn::param::ROIAlign> {};
template <>
struct OprLoadDumpImpl<opr::ROIAlignBackward, 0>
: public PoolingLoadDumpImpl<
opr::ROIAlignBackward,
MakeROIAlignCaller4<megdnn::ROIAlignBackward>,
megdnn::param::ROIAlign> {};
template <>
struct OprLoadDumpImpl<opr::Pooling, 0>
: public PoolingLoadDumpImpl<opr::Pooling,
MakePoolingCaller1<megdnn::Pooling>,
megdnn::param::Pooling> {};
template <>
struct OprLoadDumpImpl<opr::PoolingBackward, 0>
: public PoolingLoadDumpImpl<
opr::PoolingBackward,
MakePoolingBackwardCaller3<megdnn::PoolingBackward>,
megdnn::param::Pooling> {};
template <>
struct OprLoadDumpImpl<opr::Convolution, 0>
: public ConvLoadDumpImpl<opr::Convolution,
MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardData, 0>
: public ConvLoadDumpImpl<opr::ConvolutionBackwardData,
MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution,
MakeConvCaller3<megdnn::Convolution> > {};
template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardFilter, 0>
: public ConvLoadDumpImpl<opr::ConvolutionBackwardFilter,
MakeConvCaller3<megdnn::Convolution>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImpl<opr::Convolution3D, 0>
: public ConvLoadDumpImpl<opr::Convolution3D,
MakeConvCaller2<megdnn::Convolution3D>,
megdnn::Convolution3D,
MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardData, 0>
: public ConvLoadDumpImpl<opr::Convolution3DBackwardData,
MakeConvCaller2<megdnn::Convolution3D>,
megdnn::Convolution3D,
MakeConvCaller3<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0>
: public ConvLoadDumpImpl<opr::Convolution3DBackwardFilter,
MakeConvCaller3<megdnn::Convolution3D>,
megdnn::Convolution3D,
MakeConvCallerEmpty<megdnn::Convolution3D>,
MakeConvCallerEmpty<megdnn::Convolution3D>,
megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImpl<opr::ConvBiasForward, 0>
: public ConvLoadDumpImpl<opr::ConvBiasForward,
MakeConvCaller2<megdnn::ConvBiasForward>,
megdnn::ConvBiasForward,
MakeConvCaller3<megdnn::ConvBiasForward>,
MakeConvCaller4<megdnn::ConvBiasForward>,
megdnn::param::ConvBias> {};
template <>
struct OprLoadDumpImpl<opr::BatchConvBiasForward, 0>
: public ConvLoadDumpImpl<opr::BatchConvBiasForward,
MakeConvCaller2<megdnn::BatchConvBiasForward>,
megdnn::BatchConvBiasForward,
MakeConvCaller3<megdnn::BatchConvBiasForward>,
MakeConvCaller4<megdnn::BatchConvBiasForward>,
megdnn::param::BatchConvBias> {};
template <>
struct OprMaker<opr::BatchNorm, 0> {
using Param = opr::BatchNorm::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& i,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 3) {
return opr::BatchNorm::make(i[0], i[1], i[2], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(i.size() == 5);
return opr::BatchNorm::make(i[0], i[1], i[2], i[3], i[4], param,
config)[0]
.node()
->owner_opr();
}
};
//! Builds a three-input LocalShare-style operator with an execution
//! policy. Returns nullptr when the arity does not match.
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller3 {
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() != 3)
            return nullptr;
        return Opr::make(inputs[0], inputs[1], inputs[2], param,
                         execution_policy, config)
                .node();
    }
};
template <>
struct OprMaker<opr::BatchNormBackward, 5> {
using Param = opr::BatchNormBackward::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& i,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], param,
config)[0]
.node()
->owner_opr();
}
};
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller2 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNConv::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 2) {
return Opr::make(inputs[0], inputs[1], param, execution_policy,
config)
.node();
}
};
//! Sentinel maker for LocalShare: matches no input arity, always
//! returns nullptr. Fills the optional maker slots of
//! LocalShareLoadDumpImpl.
template<class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCallerEmpty {
    template<typename Opr>
    static VarNode* make(const cg::VarNodeArray &,
            const typename MegDNNConv::Param &,
            const megdnn::param::ExecutionPolicy &,
            const OperatorNodeConfig &) {
        // Fix: removed a duplicated, unreachable `return nullptr;` left
        // behind by a bad merge.
        return nullptr;
    }
};
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller3 {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray& inputs,
const typename MegDNNConv::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param,
execution_policy, config)
.node();
}
};
template<class Opr, class Maker0, class MegDNNConv,
class Maker1=MakeLocalShareCallerEmpty<MegDNNConv>,
class Maker2=MakeLocalShareCallerEmpty<MegDNNConv>,
typename LocalShareParam = megdnn::param::LocalShare >
struct LocalShareLoadDumpImpl {
static void dump(OprDumpContext &ctx,
const cg::OperatorNodeBase &opr_) {
auto &&opr = opr_.cast_final_safe<Opr>();
ctx.write_param<LocalShareParam>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(
opr.execution_policy());
return nullptr;
}
};
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCallerEmpty {
template <typename Opr>
static VarNode* make(const cg::VarNodeArray&,
const typename MegDNNConv::Param&,
const megdnn::param::ExecutionPolicy&,
const OperatorNodeConfig&) {
return nullptr;
}
};
template <class Opr, class Maker0, class MegDNNConv,
class Maker1 = MakeLocalShareCallerEmpty<MegDNNConv>,
class Maker2 = MakeLocalShareCallerEmpty<MegDNNConv>,
typename LocalShareParam = megdnn::param::LocalShare>
struct LocalShareLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<LocalShareParam>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
}
static VarNode* make(const cg::VarNodeArray& inputs,
const LocalShareParam& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
VarNode* ret = Maker0::template make<Opr>(inputs, param,
execution_policy, config);
if (!ret) {
ret = Maker1::template make<Opr>(inputs, param, execution_policy,
config);
}
static VarNode* make(
const cg::VarNodeArray& inputs, const LocalShareParam& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
VarNode* ret = Maker0::template make<Opr>(inputs, param,
execution_policy, config);
if (!ret) {
ret = Maker1::template make<Opr>(inputs, param,
execution_policy, config);
}
if (!ret) {
ret = Maker2::template make<Opr>(inputs, param,
execution_policy, config);
}
mgb_assert(ret);
return ret;
if (!ret) {
ret = Maker2::template make<Opr>(inputs, param, execution_policy,
config);
}
static cg::OperatorNodeBase* load(
OprLoadContext &ctx, const cg::VarNodeArray &inputs,
const OperatorNodeConfig &config) {
auto param = ctx.read_param<LocalShareParam>();
auto execution_policy =
mgb_assert(ret);
return ret;
}
static cg::OperatorNodeBase* load(OprLoadContext& ctx,
const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<LocalShareParam>();
auto execution_policy =
ctx.read_param<megdnn::param::ExecutionPolicy>();
return make(inputs, param, execution_policy, config)->owner_opr();
}
};
//! Registers LocalShare load/dump: 2-input maker (src, filter).
template <>
struct OprLoadDumpImpl<opr::LocalShare, 0>
: public LocalShareLoadDumpImpl<
opr::LocalShare,
MakeLocalShareCaller2<megdnn::LocalShare>,
megdnn::LocalShare> {};
//! Registers LocalShareBackwardData load/dump: 3-input maker.
template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardData, 0>
: public LocalShareLoadDumpImpl<
opr::LocalShareBackwardData,
MakeLocalShareCaller3<megdnn::LocalShare>,
megdnn::LocalShare> {};
//! Registers LocalShareBackwardFilter load/dump: 3-input maker.
template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardFilter, 0>
: public LocalShareLoadDumpImpl<
opr::LocalShareBackwardFilter,
MakeLocalShareCaller3<megdnn::LocalShare>,
megdnn::LocalShare> {};
//! Registers DeformableConvForward load/dump: 4-input maker
//! (reuses ConvLoadDumpImpl with Convolution param).
template<>
struct OprLoadDumpImpl<opr::DeformableConvForward, 0>:
public ConvLoadDumpImpl<opr::DeformableConvForward,
MakeConvCaller4<megdnn::DeformableConvForward>,
megdnn::Convolution>
{};
//! Registers DeformableConvBackwardData load/dump: 5-input maker.
template<>
struct OprLoadDumpImpl<opr::DeformableConvBackwardData, 0>:
public ConvLoadDumpImpl<opr::DeformableConvBackwardData,
MakeConvCaller5<megdnn::DeformableConvBackwardData>,
megdnn::Convolution>
{};
//! Registers DeformableConvBackwardFilter load/dump: 5-input maker.
template<>
struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0>:
public ConvLoadDumpImpl<opr::DeformableConvBackwardFilter,
MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
megdnn::Convolution>
{};
} // namespace serialization
return make(inputs, param, execution_policy, config)->owner_opr();
}
};
template <>
struct OprLoadDumpImpl<opr::LocalShare, 0>
: public LocalShareLoadDumpImpl<
opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
megdnn::LocalShare> {};
template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardData, 0>
: public LocalShareLoadDumpImpl<
opr::LocalShareBackwardData,
MakeLocalShareCaller3<megdnn::LocalShare>,
megdnn::LocalShare> {};
template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardFilter, 0>
: public LocalShareLoadDumpImpl<
opr::LocalShareBackwardFilter,
MakeLocalShareCaller3<megdnn::LocalShare>,
megdnn::LocalShare> {};
template <>
struct OprLoadDumpImpl<opr::DeformableConvForward, 0>
: public ConvLoadDumpImpl<
opr::DeformableConvForward,
MakeConvCaller4<megdnn::DeformableConvForward>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardData, 0>
: public ConvLoadDumpImpl<
opr::DeformableConvBackwardData,
MakeConvCaller5<megdnn::DeformableConvBackwardData>,
megdnn::Convolution> {};
template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0>
: public ConvLoadDumpImpl<
opr::DeformableConvBackwardFilter,
MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
megdnn::Convolution> {};
} // namespace serialization
namespace opr {
//! Serialization registrations. The `*V<n>` aliases pin the persisted
//! type name to a versioned identifier so models saved before a rename
//! keep loading. NOTE(review): the integer argument of MGB_SEREG_OPR
//! appears to be the expected input arity (0 for operators with a
//! custom/variadic loader) — confirm against the MGB_SEREG_OPR macro
//! definition in sereg.h.
using ConvolutionV2 = Convolution;
using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
MGB_SEREG_OPR(ConvolutionV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardDataV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardFilterV2, 0);
MGB_SEREG_OPR(Images2Neibs, 1);
MGB_SEREG_OPR(Images2NeibsBackward, 2);
using LocalV1 = Local;
using LocalBackwardDataV1 = LocalBackwardData;
using LocalBackwardFilterV1 = LocalBackwardFilter;
MGB_SEREG_OPR(LocalV1, 2);
MGB_SEREG_OPR(LocalBackwardDataV1, 3);
MGB_SEREG_OPR(LocalBackwardFilterV1, 3);
using GroupLocalV1 = GroupLocal;
using GroupLocalBackwardDataV1 = GroupLocalBackwardData;
using GroupLocalBackwardFilterV1 = GroupLocalBackwardFilter;
MGB_SEREG_OPR(GroupLocalV1, 2);
MGB_SEREG_OPR(GroupLocalBackwardDataV1, 3);
MGB_SEREG_OPR(GroupLocalBackwardFilterV1, 3);
MGB_SEREG_OPR(LRN, 1);
MGB_SEREG_OPR(LRNBackward, 3);
using PoolingV1 = Pooling;
using PoolingBackwardV1 = PoolingBackward;
MGB_SEREG_OPR(PoolingV1, 1);
MGB_SEREG_OPR(PoolingBackwardV1, 3);
using AdaptivePoolingV1 = AdaptivePooling;
using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
MGB_SEREG_OPR(AdaptivePoolingV1, 2);
MGB_SEREG_OPR(AdaptivePoolingBackwardV1, 4);
MGB_SEREG_OPR(ROIPooling, 3);
MGB_SEREG_OPR(ROIPoolingBackward, 4);
using MaskConvolutionV1 = MaskConvolution;
MGB_SEREG_OPR(MaskConvolutionV1, 3);
MGB_SEREG_OPR(MaskPropagate, 1);
MGB_SEREG_OPR(Convolution3D, 0);
MGB_SEREG_OPR(Convolution3DBackwardData, 0);
MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
using ConvBiasForwardV4 = ConvBiasForward;
MGB_SEREG_OPR(ConvBiasForwardV4, 0);
MGB_SEREG_OPR(BatchNorm, 0);
MGB_SEREG_OPR(BatchNormBackward, 5);
using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;
using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
MGB_SEREG_OPR(LocalShareForwardV1, 0);
MGB_SEREG_OPR(LocalShareBackwardDataV1, 0);
MGB_SEREG_OPR(LocalShareBackwardFilterV1, 0);
using ROIAlignV1=ROIAlign;
using ROIAlignBackwardV1=ROIAlignBackward;
MGB_SEREG_OPR(ROIAlignV1, 2);
MGB_SEREG_OPR(ROIAlignBackwardV1, 4);
MGB_SEREG_OPR(DeformableConvForward, 0);
MGB_SEREG_OPR(DeformableConvBackwardData, 0);
MGB_SEREG_OPR(DeformableConvBackwardFilter, 0);
MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);
using BatchConvBiasForwardV1 = BatchConvBiasForward;
MGB_SEREG_OPR(BatchConvBiasForwardV1, 0);
MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);
} // namespace opr
} // namespace mgb
//! Serializer registrations for the dnn operators. MGB_SEREG_OPR(name, arity)
//! registers a serializer for `name` with a fixed input count of `arity`;
//! arity 0 marks variadic-input operators that are built by a matching
//! OprMaker<Opr, 0> specialization. The V1/V2/V4 `using` aliases register
//! each operator under a versioned name — NOTE(review): presumably for
//! backward compatibility with previously serialized graphs; confirm against
//! the serialization registry's versioning rules.
// convolution family
using ConvolutionV2 = Convolution;
using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
MGB_SEREG_OPR(ConvolutionV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardDataV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardFilterV2, 0);
MGB_SEREG_OPR(Images2Neibs, 1);
MGB_SEREG_OPR(Images2NeibsBackward, 2);
// local (unshared) convolution
using LocalV1 = Local;
using LocalBackwardDataV1 = LocalBackwardData;
using LocalBackwardFilterV1 = LocalBackwardFilter;
MGB_SEREG_OPR(LocalV1, 2);
MGB_SEREG_OPR(LocalBackwardDataV1, 3);
MGB_SEREG_OPR(LocalBackwardFilterV1, 3);
using GroupLocalV1 = GroupLocal;
using GroupLocalBackwardDataV1 = GroupLocalBackwardData;
using GroupLocalBackwardFilterV1 = GroupLocalBackwardFilter;
MGB_SEREG_OPR(GroupLocalV1, 2);
MGB_SEREG_OPR(GroupLocalBackwardDataV1, 3);
MGB_SEREG_OPR(GroupLocalBackwardFilterV1, 3);
MGB_SEREG_OPR(LRN, 1);
MGB_SEREG_OPR(LRNBackward, 3);
// pooling family
using PoolingV1 = Pooling;
using PoolingBackwardV1 = PoolingBackward;
MGB_SEREG_OPR(PoolingV1, 1);
MGB_SEREG_OPR(PoolingBackwardV1, 3);
using AdaptivePoolingV1 = AdaptivePooling;
using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
MGB_SEREG_OPR(AdaptivePoolingV1, 2);
MGB_SEREG_OPR(AdaptivePoolingBackwardV1, 4);
MGB_SEREG_OPR(ROIPooling, 3);
MGB_SEREG_OPR(ROIPoolingBackward, 4);
using MaskConvolutionV1 = MaskConvolution;
MGB_SEREG_OPR(MaskConvolutionV1, 3);
MGB_SEREG_OPR(MaskPropagate, 1);
MGB_SEREG_OPR(Convolution3D, 0);
MGB_SEREG_OPR(Convolution3DBackwardData, 0);
MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
using ConvBiasForwardV4 = ConvBiasForward;
MGB_SEREG_OPR(ConvBiasForwardV4, 0);
MGB_SEREG_OPR(BatchNorm, 0);
MGB_SEREG_OPR(BatchNormBackward, 5);
using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;
using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
MGB_SEREG_OPR(LocalShareForwardV1, 0);
MGB_SEREG_OPR(LocalShareBackwardDataV1, 0);
MGB_SEREG_OPR(LocalShareBackwardFilterV1, 0);
using ROIAlignV1 = ROIAlign;
using ROIAlignBackwardV1 = ROIAlignBackward;
MGB_SEREG_OPR(ROIAlignV1, 2);
MGB_SEREG_OPR(ROIAlignBackwardV1, 4);
MGB_SEREG_OPR(DeformableConvForward, 0);
MGB_SEREG_OPR(DeformableConvBackwardData, 0);
MGB_SEREG_OPR(DeformableConvBackwardFilter, 0);
MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);
using BatchConvBiasForwardV1 = BatchConvBiasForward;
MGB_SEREG_OPR(BatchConvBiasForwardV1, 0);
MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);
} // namespace opr
} // namespace mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -9,195 +9,194 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <type_traits>
//#include <type_traits>
#include "megbrain/opr/imgproc.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
namespace mgb {
namespace serialization {
//! OprMaker implementation for operators with variadic arguments
template<>
struct OprMaker<opr::WarpPerspective, 0> {
using Opr = opr::WarpPerspective;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node()
->owner_opr();
} else {
mgb_assert(inputs.size() == 4);
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
param, config)
.node()
->owner_opr();
}
//! OprMaker implementation for operators with variadic arguments
template <>
struct OprMaker<opr::WarpPerspective, 0> {
using Opr = opr::WarpPerspective;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node()
->owner_opr();
} else {
mgb_assert(inputs.size() == 4);
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
config)
.node()
->owner_opr();
}
};
//! Deserializer for Remap: the operator is reconstructed from exactly two
//! serialized input vars; any other count aborts deserialization by
//! returning nullptr.
template <>
struct OprMaker<opr::Remap, 0> {
    using Opr = opr::Remap;
    using Param = Opr::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& inputs,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        // guard clause: only the two-input form is valid
        if (inputs.size() != 2)
            return nullptr;
        auto result = Opr::make(inputs[0], inputs[1], param, config);
        return result.node()->owner_opr();
    }
};
template <>
struct OprMaker<opr::Remap, 0> {
using Opr = opr::Remap;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 2) {
return Opr::make(inputs[0], inputs[1], param, config)
.node()
->owner_opr();
} else {
return nullptr;
}
};
//! Deserializer for RemapBackwardMat: requires exactly three serialized
//! input vars; returns nullptr for any other input count.
template <>
struct OprMaker<opr::RemapBackwardMat, 0> {
    using Opr = opr::RemapBackwardMat;
    using Param = Opr::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& inputs,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        // guard clause: only the three-input form is valid
        if (inputs.size() != 3)
            return nullptr;
        auto result = Opr::make(inputs[0], inputs[1], inputs[2], param, config);
        return result.node()->owner_opr();
    }
};
template <>
struct OprMaker<opr::RemapBackwardMat, 0> {
using Opr = opr::RemapBackwardMat;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node()
->owner_opr();
} else {
return nullptr;
}
};
//! Deserializer for RemapBackwardData: requires exactly three serialized
//! input vars; returns nullptr for any other input count.
template <>
struct OprMaker<opr::RemapBackwardData, 0> {
    using Opr = opr::RemapBackwardData;
    using Param = Opr::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& inputs,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        // guard clause: only the three-input form is valid
        if (inputs.size() != 3)
            return nullptr;
        auto result = Opr::make(inputs[0], inputs[1], inputs[2], param, config);
        return result.node()->owner_opr();
    }
};
template <>
struct OprMaker<opr::RemapBackwardData, 0> {
using Opr = opr::RemapBackwardData;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node()
->owner_opr();
} else {
return nullptr;
}
};
//! Deserializer for DctChannelSelectForward: the serialized form carries
//! either three input vars or a single one; any other count trips the
//! assertion.
template <>
struct OprMaker<opr::DctChannelSelectForward, 0> {
    using Opr = opr::DctChannelSelectForward;
    using Param = Opr::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& inputs,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (inputs.size() == 3) {
            auto result =
                    Opr::make(inputs[0], inputs[1], inputs[2], param, config);
            return result.node()->owner_opr();
        }
        // otherwise only the single-input form is acceptable
        mgb_assert(inputs.size() == 1);
        return Opr::make(inputs[0], param, config).node()->owner_opr();
    }
};
template <>
struct OprMaker<opr::DctChannelSelectForward, 0> {
using Opr = opr::DctChannelSelectForward;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node()
->owner_opr();
} else {
mgb_assert(inputs.size() == 1);
return Opr::make(inputs[0], param, config).node()->owner_opr();
}
};
//! Deserializer for WarpPerspectiveBackwardData: rebuilds the operator from
//! either three or four serialized input vars; any other count trips the
//! assertion.
template <>
struct OprMaker<opr::WarpPerspectiveBackwardData, 0> {
    using Opr = opr::WarpPerspectiveBackwardData;
    using Param = Opr::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& inputs,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (inputs.size() == 3) {
            auto result =
                    Opr::make(inputs[0], inputs[1], inputs[2], param, config);
            return result.node()->owner_opr();
        }
        // remaining valid form: four inputs
        mgb_assert(inputs.size() == 4);
        auto result = Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
                                param, config);
        return result.node()->owner_opr();
    }
};
template <>
struct OprMaker<opr::WarpPerspectiveBackwardData, 0> {
using Opr = opr::WarpPerspectiveBackwardData;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node()
->owner_opr();
} else {
mgb_assert(inputs.size() == 4);
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
config)
.node()
->owner_opr();
}
};
//! Deserializer for WarpPerspectiveBackwardMat: rebuilds the operator from
//! either three or four serialized input vars; any other count trips the
//! assertion.
template <>
struct OprMaker<opr::WarpPerspectiveBackwardMat, 0> {
    using Opr = opr::WarpPerspectiveBackwardMat;
    using Param = Opr::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& inputs,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (inputs.size() == 3) {
            auto result =
                    Opr::make(inputs[0], inputs[1], inputs[2], param, config);
            return result.node()->owner_opr();
        }
        // remaining valid form: four inputs
        mgb_assert(inputs.size() == 4);
        auto result = Opr::make(inputs[0], inputs[1], inputs[2], inputs[3],
                                param, config);
        return result.node()->owner_opr();
    }
};
template <>
struct OprMaker<opr::WarpPerspectiveBackwardMat, 0> {
using Opr = opr::WarpPerspectiveBackwardMat;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config)
.node()
->owner_opr();
} else {
mgb_assert(inputs.size() == 4);
return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
config)
.node()
->owner_opr();
}
};
} // namespace serialization
}
};
} // namespace serialization
namespace opr {
//! Serializer registrations for the imgproc operators; the macro's second
//! argument is the fixed input count (0 = variadic, handled by the
//! OprMaker<Opr, 0> specializations above). Versioned aliases register the
//! operators under a V1/V2 name — NOTE(review): presumably for backward
//! compatibility with older serialized graphs; confirm against the
//! serialization registry's versioning rules.
using WarpPerspectiveV2=WarpPerspective;
using WarpPerspectiveBackwardDataV2=WarpPerspectiveBackwardData;
using WarpPerspectiveBackwardMatV2=WarpPerspectiveBackwardMat;
MGB_SEREG_OPR(WarpPerspectiveV2, 0);
MGB_SEREG_OPR(WarpPerspectiveBackwardDataV2, 0);
MGB_SEREG_OPR(WarpPerspectiveBackwardMatV2, 0);
MGB_SEREG_OPR(Rotate, 1);
MGB_SEREG_OPR(CvtColor, 1);
MGB_SEREG_OPR(GaussianBlur, 1);
MGB_SEREG_OPR(ResizeBackward, 2);
using RemapV1=Remap;
using RemapBackwardDataV1=RemapBackwardData;
using RemapBackwardMatV1=RemapBackwardMat;
MGB_SEREG_OPR(RemapV1, 2);
MGB_SEREG_OPR(RemapBackwardDataV1, 3);
MGB_SEREG_OPR(RemapBackwardMatV1, 3);
//! current warp affine version
using WarpAffineV2 = opr::WarpAffine;
MGB_SEREG_OPR(WarpAffineV2, 3);
//! current resize version
using ResizeV2 = opr::Resize;
MGB_SEREG_OPR(ResizeV2, 2);
using DctChannelSelectV1 = opr::DctChannelSelect;
MGB_SEREG_OPR(DctChannelSelectV1, 0);
} // namespace opr
} // namespace mgb
//! Serializer registrations for the imgproc operators; the macro's second
//! argument is the fixed input count (0 = variadic, handled by the
//! OprMaker<Opr, 0> specializations above). Versioned aliases register the
//! operators under a V1/V2 name — NOTE(review): presumably for backward
//! compatibility with older serialized graphs; confirm against the
//! serialization registry's versioning rules.
using WarpPerspectiveV2 = WarpPerspective;
using WarpPerspectiveBackwardDataV2 = WarpPerspectiveBackwardData;
using WarpPerspectiveBackwardMatV2 = WarpPerspectiveBackwardMat;
MGB_SEREG_OPR(WarpPerspectiveV2, 0);
MGB_SEREG_OPR(WarpPerspectiveBackwardDataV2, 0);
MGB_SEREG_OPR(WarpPerspectiveBackwardMatV2, 0);
MGB_SEREG_OPR(Rotate, 1);
MGB_SEREG_OPR(CvtColor, 1);
MGB_SEREG_OPR(GaussianBlur, 1);
MGB_SEREG_OPR(ResizeBackward, 2);
using RemapV1 = Remap;
using RemapBackwardDataV1 = RemapBackwardData;
using RemapBackwardMatV1 = RemapBackwardMat;
MGB_SEREG_OPR(RemapV1, 2);
MGB_SEREG_OPR(RemapBackwardDataV1, 3);
MGB_SEREG_OPR(RemapBackwardMatV1, 3);
//! current warp affine version
using WarpAffineV2 = opr::WarpAffine;
MGB_SEREG_OPR(WarpAffineV2, 3);
//! current resize version
using ResizeV2 = opr::Resize;
MGB_SEREG_OPR(ResizeV2, 2);
using DctChannelSelectV1 = opr::DctChannelSelect;
MGB_SEREG_OPR(DctChannelSelectV1, 0);
} // namespace opr
} // namespace mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册