提交 b8ccc6a2 编写于 作者: M Megvii Engine Team

fix(mgb): fix loss execution policy after opr shallow copy

GitOrigin-RevId: 4738136e4a8a3270a2b5343483539f46afe05c6d
上级 c27f6782
......@@ -183,7 +183,7 @@ namespace pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& pool = static_cast<const Pooling&>(def);
OperatorNodeConfig config{pool.make_name()};
return opr::Pooling::make(inputs[0], pool.param(), config);
return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config);
}
OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
} // namespace pooling
......
......@@ -63,7 +63,7 @@ def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, Executio
def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;
def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>;
def Pooling: MgbHashableOp<"Pooling", [PoolingParam, ExecutionPolicyParamBase<"policy">]>;
def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>;
......
......@@ -31,7 +31,21 @@ using namespace opr;
namespace {
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller2 {
struct MakeOprWithPolicyCaller1 {
template <typename Opr>
static VarNode* make(
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 1) {
return Opr::make(inputs[0], param, execution_policy, config).node();
}
return nullptr;
}
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeOprWithPolicyCaller2 {
template <typename Opr>
static VarNode* make(
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
......@@ -46,7 +60,7 @@ struct MakeConvCaller2 {
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller3 {
struct MakeOprWithPolicyCaller3 {
template <typename Opr>
static VarNode* make(
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
......@@ -63,7 +77,7 @@ struct MakeConvCaller3 {
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller4 {
struct MakeOprWithPolicyCaller4 {
template <typename Opr>
static VarNode* make(
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
......@@ -80,7 +94,7 @@ struct MakeConvCaller4 {
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller5 {
struct MakeOprWithPolicyCaller5 {
template <typename Opr>
static VarNode* make(
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
......@@ -97,7 +111,7 @@ struct MakeConvCaller5 {
};
template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCallerEmpty {
struct MakeOprWithPolicyCallerEmpty {
template <typename Opr>
static VarNode* make(
const cg::VarNodeArray&, const typename MegDNNConv::Param&,
......@@ -108,10 +122,10 @@ struct MakeConvCallerEmpty {
template <
class Opr, class Maker0, class MegDNNConv,
class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
typename ConvParam = megdnn::param::Convolution>
struct ConvMakerImpl {
class Maker1 = MakeOprWithPolicyCallerEmpty<MegDNNConv>,
class Maker2 = MakeOprWithPolicyCallerEmpty<MegDNNConv>,
typename ConvParam = typename MegDNNConv::Param>
struct OprWithPolicyMakerImpl {
static VarNode* make(
const cg::VarNodeArray& inputs, const ConvParam& param,
const megdnn::param::ExecutionPolicy& execution_policy,
......@@ -130,33 +144,43 @@ struct ConvMakerImpl {
};
template <typename Opr>
struct ConvMaker;
struct OprWithPolicyMaker;
template <>
struct OprWithPolicyMaker<opr::Pooling>
: public OprWithPolicyMakerImpl<
opr::Pooling, MakeOprWithPolicyCaller1<megdnn::Pooling>,
megdnn::Pooling> {};
template <>
struct ConvMaker<opr::Convolution>
: public ConvMakerImpl<
opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
struct OprWithPolicyMaker<opr::Convolution>
: public OprWithPolicyMakerImpl<
opr::Convolution, MakeOprWithPolicyCaller2<megdnn::Convolution>,
megdnn::Convolution> {};
template <>
struct ConvMaker<opr::ConvolutionBackwardData>
: public ConvMakerImpl<
opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
struct OprWithPolicyMaker<opr::ConvolutionBackwardData>
: public OprWithPolicyMakerImpl<
opr::ConvolutionBackwardData,
MakeOprWithPolicyCaller2<megdnn::Convolution>, megdnn::Convolution,
MakeOprWithPolicyCaller3<megdnn::Convolution>> {};
template <>
struct ConvMaker<opr::ConvBiasForward>
: public ConvMakerImpl<
opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
struct OprWithPolicyMaker<opr::ConvBiasForward>
: public OprWithPolicyMakerImpl<
opr::ConvBiasForward,
MakeOprWithPolicyCaller2<megdnn::ConvBiasForward>,
megdnn::ConvBiasForward,
MakeOprWithPolicyCaller3<megdnn::ConvBiasForward>,
MakeOprWithPolicyCaller4<megdnn::ConvBiasForward>,
megdnn::param::ConvBias> {};
template <>
struct ConvMaker<opr::BatchConvBiasForward>
: public ConvMakerImpl<
struct OprWithPolicyMaker<opr::BatchConvBiasForward>
: public OprWithPolicyMakerImpl<
opr::BatchConvBiasForward,
MakeConvCaller2<megdnn::BatchConvBiasForward>,
MakeOprWithPolicyCaller2<megdnn::BatchConvBiasForward>,
megdnn::BatchConvBiasForward,
MakeConvCaller3<megdnn::BatchConvBiasForward>,
MakeConvCaller4<megdnn::BatchConvBiasForward>,
MakeOprWithPolicyCaller3<megdnn::BatchConvBiasForward>,
MakeOprWithPolicyCaller4<megdnn::BatchConvBiasForward>,
megdnn::param::BatchConvBias> {};
#include "../../opr/impl/internal/invoke.h"
......@@ -254,7 +278,7 @@ struct OprFormatModifier;
auto&& opr = opr_->cast_final_safe<_Opr>(); \
auto param = opr.param(); \
param.format = opr_format; \
return ConvMaker<_Opr>::make( \
return OprWithPolicyMaker<_Opr>::make( \
i, param, opr.execution_policy(), opr.config()); \
MIDOUT_E \
} \
......@@ -263,6 +287,7 @@ INST(Convolution);
INST(ConvBiasForward);
INST(ConvolutionBackwardData);
INST(BatchConvBiasForward);
INST(Pooling);
#undef INST
template <>
......@@ -303,7 +328,6 @@ struct OprFormatModifier<WarpPerspective> {
MIDOUT_E \
} \
};
INST(PoolingForward, 1);
INST(Resize, 2);
#undef INST
......
......@@ -1492,7 +1492,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
}
auto new_param = pooling_opr.param();
new_param.format = megdnn::param::Pooling::Format::NHWCD4;
auto new_pooling_opr = opr::PoolingForward::make(inp, new_param, opr->config());
auto new_pooling_opr = opr::PoolingForward::make(
inp, new_param, pooling_opr.execution_policy(), opr->config());
return new_pooling_opr.node()->owner_opr();
};
......
......@@ -525,8 +525,8 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
}
auto new_param = pooling.param();
new_param.format = Format::NCHW32;
auto new_pooling =
opr::PoolingForward::make(new_inp_var, new_param, opr->config());
auto new_pooling = opr::PoolingForward::make(
new_inp_var, new_param, pooling.execution_policy(), opr->config());
return new_pooling.node()->owner_opr();
}
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
......@@ -795,8 +795,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
if (varshape_changed.count(new_inp[0])) {
auto new_param = pooling.param();
new_param.format = Format::CHWN4;
auto new_pooling =
opr::PoolingForward::make(new_inp[0], new_param, opr->config());
auto new_pooling = opr::PoolingForward::make(
new_inp[0], new_param, pooling.execution_policy(), opr->config());
varshape_changed.insert(new_pooling.node());
return new_pooling.node()->owner_opr();
}
......@@ -1174,8 +1174,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
auto new_param = pooling.param();
new_param.format = Format::NCHW4;
auto new_pooling =
opr::PoolingForward::make(new_inp[0], new_param, opr->config());
auto new_pooling = opr::PoolingForward::make(
new_inp[0], new_param, pooling.execution_policy(), opr->config());
mgb_assert(
new_pooling.shape().ndim == 5,
"out var of Pooling opr after transform must be 5 (got: "
......@@ -1646,8 +1646,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
if (inp->shape().ndim == 5) {
auto new_param = pooling_opr.param();
new_param.format = pooling_format;
auto new_pooling_opr =
opr::PoolingForward::make(inp, new_param, opr->config());
auto new_pooling_opr = opr::PoolingForward::make(
inp, new_param, pooling_opr.execution_policy(), opr->config());
mgb_assert(
new_pooling_opr.shape().ndim == 5,
"The pooling dst dim is not trans to nchwxx");
......@@ -3003,7 +3003,8 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC;
auto param = pooling.param();
param.format = target_format;
auto new_pool = opr::PoolingForward::make(inps[0], param, pooling.config());
auto new_pool = opr::PoolingForward::make(
inps[0], param, pooling.execution_policy(), pooling.config());
auto ret = new_pool.node()->owner_opr();
format_map.insert(std::make_pair(ret, target_format));
return ret;
......@@ -3055,7 +3056,8 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
auto param = pooling.param();
param.format = out_format;
auto new_pool = opr::PoolingForward::make(inps[0], param, pooling.config());
auto new_pool = opr::PoolingForward::make(
inps[0], param, pooling.execution_policy(), pooling.config());
auto ret = new_pool.node()->owner_opr();
format_map.insert(std::make_pair(ret, out_format));
return ret;
......
此差异由.gitattributes 抑制。
......@@ -281,7 +281,7 @@ TEST(TestLayoutTransform, Resnet18_QS4) {
auto new_out_var = new_output[0];
/// check global layout transform pass
auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
ASSERT_EQ(nr_dimshuffle, 3u);
ASSERT_EQ(nr_dimshuffle, 5u);
/// check pass fuse conv bias with z
auto nr_elemwise_mult_type = find_opr_num<opr::ElemwiseMultiType>(new_out_var);
ASSERT_EQ(nr_elemwise_mult_type, 4u);
......@@ -822,7 +822,7 @@ TEST(TestLayoutTransform, Resnet18_F16) {
auto new_out_var = new_output[0];
/// check global layout transform pass
auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
ASSERT_EQ(nr_dimshuffle, 4u);
ASSERT_EQ(nr_dimshuffle, 2u);
/// check pass fuse conv bias with z
auto nr_elemwise = find_opr_num<opr::Elemwise>(new_out_var);
ASSERT_EQ(nr_elemwise, 4u);
......
......@@ -80,14 +80,26 @@ struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>
opr::BatchedMatrixMul, MakeMatrixMulCaller<megdnn::BatchedMatrixMul>,
megdnn::BatchedMatrixMul> {};
template <typename Opr>
cg::OperatorNodeBase* opr_shallow_copy_matmul(
const serialization::OprShallowCopyContext& ctx,
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(ctx);
auto&& opr = opr_.cast_final_safe<Opr>();
return OprLoadDumpImpl<Opr, 2>::make(
inputs, opr.param(), opr.execution_policy_transient(), config)
->owner_opr();
}
} // namespace serialization
namespace opr {
using MatrixMulV2 = MatrixMul;
using BatchedMatrixMulV2 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV2, 2);
MGB_SEREG_OPR(BatchedMatrixMulV2, 2);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(MatrixMulV2, 2, opr_shallow_copy_matmul);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchedMatrixMulV2, 2, opr_shallow_copy_matmul);
MGB_SEREG_OPR(Dot, 2);
MGB_SEREG_OPR(MatrixInverse, 1);
MGB_SEREG_OPR(SVD, 1);
......
......@@ -36,9 +36,10 @@ struct MakePoolingCaller1 {
template <typename Opr>
static VarNode* make(
const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 1) {
return Opr::make(inputs[0], param, config).node();
return Opr::make(inputs[0], param, execution_policy, config).node();
}
return nullptr;
}
......@@ -78,9 +79,13 @@ struct MakePoolingBackwardCaller3 {
template <typename Opr>
static VarNode* make(
const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
if (inputs.size() == 3) {
return Opr::make(inputs[0], inputs[1], inputs[2], param, config).node();
return Opr::make(
inputs[0], inputs[1], inputs[2], param, execution_policy,
config)
.node();
}
return nullptr;
}
......@@ -223,8 +228,10 @@ struct PoolingLoadDumpImpl {
static VarNode* make(
const cg::VarNodeArray& inputs, const PoolingParam& param,
const megdnn::param::ExecutionPolicy& execution_policy,
const OperatorNodeConfig& config) {
VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
VarNode* ret =
Maker0::template make<Opr>(inputs, param, execution_policy, config);
mgb_assert(ret);
return ret;
}
......@@ -233,6 +240,29 @@ struct PoolingLoadDumpImpl {
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<PoolingParam>();
return make(inputs, param, {}, config)->owner_opr();
}
};
template <class Opr, class Maker0, typename GeneralOprParam = megdnn::param::ROIAlign>
struct GeneralOprLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<GeneralOprParam>(opr.param());
}
static VarNode* make(
const cg::VarNodeArray& inputs, const GeneralOprParam& param,
const OperatorNodeConfig& config) {
VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
mgb_assert(ret);
return ret;
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<GeneralOprParam>();
return make(inputs, param, config)->owner_opr();
}
};
......@@ -264,26 +294,26 @@ struct OprMaker<opr::LSQBackward, 5> {
};
template <>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
: public PoolingLoadDumpImpl<
: public GeneralOprLoadDumpImpl<
opr::AdaptivePoolingBackward,
MakeAdaptivePoolingBackwardCaller3<megdnn::AdaptivePoolingBackward>,
megdnn::param::AdaptivePooling> {};
template <>
struct OprLoadDumpImpl<opr::AdaptivePooling, 0>
: public PoolingLoadDumpImpl<
: public GeneralOprLoadDumpImpl<
opr::AdaptivePooling, MakeROIAlignCaller1<megdnn::AdaptivePooling>,
megdnn::param::AdaptivePooling> {};
template <>
struct OprLoadDumpImpl<opr::ROIAlign, 0>
: public PoolingLoadDumpImpl<
: public GeneralOprLoadDumpImpl<
opr::ROIAlign, MakeROIAlignCaller1<megdnn::ROIAlign>,
megdnn::param::ROIAlign> {};
template <>
struct OprLoadDumpImpl<opr::ROIAlignBackward, 0>
: public PoolingLoadDumpImpl<
: public GeneralOprLoadDumpImpl<
opr::ROIAlignBackward, MakeROIAlignCaller4<megdnn::ROIAlignBackward>,
megdnn::param::ROIAlign> {};
......@@ -500,15 +530,29 @@ struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0>
opr::DeformableConvBackwardFilter,
MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
megdnn::Convolution> {};
template <typename Opr>
cg::OperatorNodeBase* opr_shallow_copy_conv(
const serialization::OprShallowCopyContext& ctx,
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(ctx);
auto&& opr = opr_.cast_final_safe<Opr>();
return OprLoadDumpImpl<Opr, 0>::make(
inputs, opr.param(), opr.execution_policy_transient(), config)
->owner_opr();
}
} // namespace serialization
namespace opr {
using ConvolutionV2 = Convolution;
using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
MGB_SEREG_OPR(ConvolutionV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardDataV2, 0);
MGB_SEREG_OPR(ConvolutionBackwardFilterV2, 0);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionV2, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionBackwardDataV2, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
ConvolutionBackwardFilterV2, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR(Images2Neibs, 1);
MGB_SEREG_OPR(Images2NeibsBackward, 2);
......@@ -534,8 +578,8 @@ MGB_SEREG_OPR(LRN, 1);
MGB_SEREG_OPR(LRNBackward, 3);
using PoolingV1 = Pooling;
using PoolingBackwardV1 = PoolingBackward;
MGB_SEREG_OPR(PoolingV1, 1);
MGB_SEREG_OPR(PoolingBackwardV1, 3);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingBackwardV1, 0, opr_shallow_copy_conv);
using AdaptivePoolingV1 = AdaptivePooling;
using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
MGB_SEREG_OPR(AdaptivePoolingV1, 2);
......@@ -548,12 +592,13 @@ using MaskConvolutionV2 = MaskConvolution;
MGB_SEREG_OPR(MaskConvolutionV2, 3);
MGB_SEREG_OPR(MaskPropagate, 1);
MGB_SEREG_OPR(Convolution3D, 0);
MGB_SEREG_OPR(Convolution3DBackwardData, 0);
MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3D, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3DBackwardData, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
Convolution3DBackwardFilter, 0, opr_shallow_copy_conv);
using ConvBiasForwardV4 = ConvBiasForward;
MGB_SEREG_OPR(ConvBiasForwardV4, 0);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvBiasForwardV4, 0, opr_shallow_copy_conv);
using BatchNormV1 = BatchNorm;
using BatchNormBackwardV1 = BatchNormBackward;
......@@ -563,9 +608,10 @@ MGB_SEREG_OPR(BatchNormBackwardV1, 6);
using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;
using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
MGB_SEREG_OPR(LocalShareForwardV1, 0);
MGB_SEREG_OPR(LocalShareBackwardDataV1, 0);
MGB_SEREG_OPR(LocalShareBackwardFilterV1, 0);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareForwardV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareBackwardDataV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
LocalShareBackwardFilterV1, 0, opr_shallow_copy_conv);
using ROIAlignV1 = ROIAlign;
using ROIAlignBackwardV1 = ROIAlignBackward;
......@@ -574,9 +620,11 @@ MGB_SEREG_OPR(ROIAlignBackwardV1, 4);
using DeformableConvForwardV1 = DeformableConvForward;
using DeformableConvBackwardDataV1 = DeformableConvBackwardData;
using DeformableConvBackwardFilterV1 = DeformableConvBackwardFilter;
MGB_SEREG_OPR(DeformableConvForwardV1, 0);
MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0);
MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(DeformableConvForwardV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
DeformableConvBackwardDataV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
DeformableConvBackwardFilterV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR(CorrelationForward, 2);
MGB_SEREG_OPR(CorrelationBackwardData1, 3);
......@@ -586,7 +634,7 @@ MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);
using BatchConvBiasForwardV1 = BatchConvBiasForward;
MGB_SEREG_OPR(BatchConvBiasForwardV1, 0);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchConvBiasForwardV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
MGB_SEREG_OPR(TQT, 2);
......
......@@ -32,8 +32,8 @@ PoolingForward::PoolingForward(
}
SymbolVar PoolingForward::make(
SymbolVar i0, const Param& param, const OperatorNodeConfig& config,
const ExecutionPolicy& policy) {
SymbolVar i0, const Param& param, const ExecutionPolicy& policy,
const OperatorNodeConfig& config) {
intl::MegDNNOprInitInputsModifier<PoolingForward>::apply(param, {&i0});
return i0.insert_single_output_opr<PoolingForward>(
i0.node(), param, policy, config);
......@@ -75,12 +75,13 @@ PoolingBackward::PoolingBackward(
0, true) {
init_megdnn_opr(*this, param);
add_input({i0, i1, i2});
m_policy = policy;
intl::MegDNNOprInitPostCtor<PoolingBackward>::apply(*this);
}
SymbolVar PoolingBackward::make(
SymbolVar i0, SymbolVar i1, SymbolVar i2, const Param& param,
const OperatorNodeConfig& config, const ExecutionPolicy& policy) {
const ExecutionPolicy& policy, const OperatorNodeConfig& config) {
intl::MegDNNOprInitInputsModifier<PoolingBackward>::apply(param, {&i0, &i1, &i2});
return i0.insert_single_output_opr<PoolingBackward>(
i0.node(), i1.node(), i2.node(), param, policy, config);
......
......@@ -26,8 +26,8 @@ MGE_WIN_DECLSPEC_FUC PoolingForward(
VarNode* src, const Param& param, const ExecutionPolicy& policy,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, const Param& param, const OperatorNodeConfig& config = {},
const ExecutionPolicy& policy = {});
SymbolVar src, const Param& param, const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});
void init_output_static_infer_desc() override;
......@@ -47,7 +47,7 @@ MGE_WIN_DECLSPEC_FUC PoolingBackward(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param,
const OperatorNodeConfig& config = {}, const ExecutionPolicy& policy = {});
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
......
......@@ -15,7 +15,9 @@
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
......@@ -32,39 +34,24 @@ using namespace mgb;
namespace {
#if MGB_CUDA
#if MGB_ENABLE_FASTRUN
template <typename MgbOpr, int arith>
struct GraphMaker;
template <typename MgbOpr>
struct GraphMaker<MgbOpr, 2> {
SymbolVar operator()(
const std::array<cg::SymbolVar, 2>& inputs, typename MgbOpr::Param& param,
typename MgbOpr::ExecutionPolicy& policy) {
return MgbOpr::make(inputs[0], inputs[1], param, policy);
}
};
template <>
struct GraphMaker<opr::ConvolutionBackwardData, 2> {
struct GraphMaker<opr::Pooling, 1> {
SymbolVar operator()(
const std::array<cg::SymbolVar, 2>& inputs,
opr::ConvolutionBackwardData::Param& param,
opr::ConvolutionBackwardData::ExecutionPolicy& policy) {
return opr::ConvolutionBackwardData::make_deconv(
inputs[0], inputs[1], param, policy);
const std::array<cg::SymbolVar, 1>& inputs, opr::Pooling::Param& param,
opr::Pooling::ExecutionPolicy& policy) {
return opr::Pooling::make(inputs[0], param, policy);
}
};
template <>
struct GraphMaker<opr::Convolution3DBackwardData, 2> {
template <typename MgbOpr>
struct GraphMaker<MgbOpr, 2> {
SymbolVar operator()(
const std::array<cg::SymbolVar, 2>& inputs,
opr::Convolution3DBackwardData::Param& param,
opr::Convolution3DBackwardData::ExecutionPolicy& policy) {
return opr::Convolution3DBackwardData::make_deconv(
inputs[0], inputs[1], param, policy);
const std::array<cg::SymbolVar, 2>& inputs, typename MgbOpr::Param& param,
typename MgbOpr::ExecutionPolicy& policy) {
return MgbOpr::make(inputs[0], inputs[1], param, policy);
}
};
......@@ -98,6 +85,37 @@ struct GraphMaker<MgbOpr, 5> {
}
};
template <typename MgbOpr, int arith, typename dtype = dtype::Float32>
void test_execution_policy_shallow_copy(
std::array<TensorShape, arith> shapes, typename MgbOpr::Param param = {}) {
using Policy = typename MgbOpr::ExecutionPolicy;
Policy policy;
policy.strategy = Policy::Strategy::PROFILE;
auto cn = CompNode::load("cpu0");
auto graph0 = ComputingGraph::make(), graph1 = ComputingGraph::make();
std::array<cg::SymbolVar, arith> inputs0;
VarNodeArray inputs1;
for (size_t i = 0; i < arith; ++i) {
HostTensorND hi{cn, shapes[i], dtype()};
inputs0[i] = opr::ImmutableTensor::make(*graph0, hi);
inputs1.push_back(opr::ImmutableTensor::make(*graph1, hi).node());
}
GraphMaker<MgbOpr, arith> graph_maker;
auto opr0 = graph_maker(inputs0, param, policy).node()->owner_opr();
auto opr1 = serialization::copy_opr_shallow(*opr0, inputs1, OperatorNodeConfig{});
auto m0 = &(opr0->template cast_final<MgbOpr>());
auto m1 = &(opr1->template cast_final<MgbOpr>());
ASSERT_EQ(policy.strategy, m0->execution_policy().strategy);
ASSERT_EQ(policy.strategy, m1->execution_policy().strategy);
}
#if MGB_CUDA
#if MGB_ENABLE_FASTRUN
template <typename MgbOpr, int arith, typename dtype = dtype::Float32>
void test_fastrun_opr(
std::array<TensorShape, arith> inps0, std::array<TensorShape, arith> inps1,
......@@ -162,16 +180,24 @@ void test_fastrun_opr(
size_t nr_set_total = expect_nr_cache_set_inp1 + nr_set_inp0;
ASSERT_EQ(cache_set_history.size(), nr_set_total);
}
#endif // MGB_ENABLE_FASTRUN
#endif // MGB_CUDA
} // anonymous namespace
#if MGB_CUDA
#if MGB_ENABLE_FASTRUN
TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution) {
REQUIRE_GPU(1);
test_fastrun_opr<opr::Convolution, 2>(
{TensorShape{12, 3, 36, 36}, TensorShape{4, 3, 3, 3}},
{TensorShape{1, 3, 36, 36}, TensorShape{4, 3, 3, 3}});
test_fastrun_opr<opr::ConvolutionBackwardData, 2>(
{TensorShape{12, 4, 23, 29}, TensorShape{4, 5, 3, 2}},
{TensorShape{2, 4, 23, 29}, TensorShape{4, 5, 3, 2}});
test_fastrun_opr<opr::ConvolutionBackwardData, 3>(
{TensorShape{4, 5, 3, 2}, TensorShape{12, 4, 23, 29},
TensorShape{12, 5, 25, 30}},
{TensorShape{4, 5, 3, 2}, TensorShape{2, 4, 23, 29},
TensorShape{2, 5, 25, 30}});
test_fastrun_opr<opr::ConvolutionBackwardFilter, 3>(
{TensorShape{12, 4, 23, 29}, TensorShape{12, 5, 21, 28},
......@@ -195,9 +221,11 @@ TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution3D) {
{TensorShape{8, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}},
{TensorShape{3, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}});
test_fastrun_opr<opr::Convolution3DBackwardData, 2>(
{TensorShape{14, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}},
{TensorShape{4, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}});
test_fastrun_opr<opr::Convolution3DBackwardData, 3>(
{TensorShape{5, 5, 3, 3, 3}, TensorShape{14, 5, 12, 12, 16},
TensorShape{14, 5, 14, 14, 18}},
{TensorShape{5, 5, 3, 3, 3}, TensorShape{4, 5, 12, 12, 16},
TensorShape{4, 5, 14, 14, 18}});
test_fastrun_opr<opr::Convolution3DBackwardFilter, 3>(
{TensorShape{64, 16, 18, 18, 18}, TensorShape{64, 16, 18, 18, 18},
......@@ -295,6 +323,87 @@ TEST(TestOprDNN, FastrunIgnoreBatchSizeBatchedMatrixMul) {
#endif // MGB_ENABLE_FASTRUN
#endif // MGB_CUDA
} // anonymous namespace
TEST(TestOprDNN, ExecutionPolicyShallowCopyConvolution) {
test_execution_policy_shallow_copy<opr::Convolution, 2>(
{TensorShape{12, 3, 36, 36}, TensorShape{4, 3, 3, 3}});
test_execution_policy_shallow_copy<opr::ConvolutionBackwardData, 3>(
{TensorShape{4, 5, 3, 2}, TensorShape{12, 4, 23, 29},
TensorShape{12, 5, 25, 30}});
test_execution_policy_shallow_copy<opr::ConvolutionBackwardFilter, 3>(
{TensorShape{12, 4, 23, 29}, TensorShape{12, 5, 21, 28},
TensorShape{5, 4, 3, 2}});
}
TEST(TestOprDNN, ExecutionPolicyShallowCopyConvBias) {
test_execution_policy_shallow_copy<opr::ConvBias, 3>(
{TensorShape{20, 16, 50, 50}, TensorShape{24, 16, 3, 3},
TensorShape{1, 24, 1, 1}});
}
TEST(TestOprDNN, ExecutionPolicyShallowCopyConvolution3D) {
test_execution_policy_shallow_copy<opr::Convolution3D, 2>(
{TensorShape{8, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}});
test_execution_policy_shallow_copy<opr::Convolution3DBackwardData, 3>(
{TensorShape{5, 5, 3, 3, 3}, TensorShape{14, 5, 12, 12, 16},
TensorShape{14, 5, 14, 14, 18}});
test_execution_policy_shallow_copy<opr::Convolution3DBackwardFilter, 3>(
{TensorShape{64, 16, 18, 18, 18}, TensorShape{64, 16, 18, 18, 18},
TensorShape{16, 16, 1, 1, 1}});
}
TEST(TestOprDNN, ExecutionPolicyShallowCopyLocalShare) {
opr::LocalShare::Param local_share_param;
local_share_param.mode = opr::LocalShare::Param::Mode::CROSS_CORRELATION;
local_share_param.pad_h = local_share_param.pad_w = 1;
local_share_param.stride_h = local_share_param.stride_w = 1;
local_share_param.spatial_groups_h = local_share_param.spatial_groups_w = 2;
test_execution_policy_shallow_copy<opr::LocalShareForward, 2>(
{TensorShape{32, 2, 23, 23}, TensorShape{2, 2, 2, 2, 2, 7}},
local_share_param);
test_execution_policy_shallow_copy<opr::LocalShareBackwardData, 3>(
{TensorShape{3, 3, 128, 1, 1, 128}, TensorShape{32, 128, 24, 24},
TensorShape{32, 128, 24, 24}});
test_execution_policy_shallow_copy<opr::LocalShareBackwardFilter, 3>(
{TensorShape{12, 3, 36, 36}, TensorShape{12, 4, 35, 35},
TensorShape{3, 3, 3, 3, 3, 4}});
}
TEST(TestOprDNN, ExecutionPolicyShallowCopyDeformableConv) {
test_execution_policy_shallow_copy<opr::DeformableConvForward, 4>(
{TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18}});
test_execution_policy_shallow_copy<opr::DeformableConvBackwardData, 5>(
{TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18},
TensorShape{12, 6, 18, 18}});
test_execution_policy_shallow_copy<opr::DeformableConvBackwardFilter, 5>(
{TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18},
TensorShape{12, 6, 18, 18}});
}
TEST(TestOprDNN, ExecutionPolicyShallowCopyMatrixMul) {
test_execution_policy_shallow_copy<opr::MatrixMul, 2>(
{TensorShape{10, 12}, TensorShape{12, 12}});
test_execution_policy_shallow_copy<opr::BatchedMatrixMul, 2>(
{TensorShape{12, 6, 8}, TensorShape{12, 8, 4}});
}
TEST(TestOprDNN, ExecutionPolicyShallowCopyPooling) {
test_execution_policy_shallow_copy<opr::Pooling, 1>({TensorShape{1, 20, 24, 24}});
test_execution_policy_shallow_copy<opr::PoolingBackward, 3>(
{TensorShape{1, 20, 24, 24}, TensorShape{1, 20, 12, 12},
TensorShape{1, 20, 12, 12}});
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -135,7 +135,7 @@ TEST(TestOprDNN, PoolingExePolicy) {
Policy policy;
policy.strategy = strategy;
auto pooling = opr::PoolingForward::make(input, param, {}, policy);
auto pooling = opr::PoolingForward::make(input, param, policy);
auto loss0 = opr::reduce_sum_sqr(pooling, pooling.make_scalar(1));
auto grad = cg::grad(loss0, input, true, false);
......@@ -187,7 +187,7 @@ TEST(TestOprDNN, PoolingForwardFastrun) {
Policy policy;
policy.strategy = strategy;
auto pooling = opr::PoolingForward::make(input, param, {}, policy);
auto pooling = opr::PoolingForward::make(input, param, policy);
auto func = graph->compile({make_callback_copy(pooling, host_y)});
func->execute().wait();
......
......@@ -253,4 +253,11 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};
__caller_OprRegShallowCopy##_cls##_ins; \
}
/*!
* \brief register opr serialization and shallow copy methods
*/
#define MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(_cls, _arity, _copy) \
MGB_SEREG_OPR(_cls, _arity) \
MGB_REG_OPR_SHALLOW_COPY(_cls, ::mgb::serialization::_copy<_cls>)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册