#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/rnn.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"

namespace mgb {

namespace serialization {
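
// Each Maker* helper below builds a concrete operator from a deserialized
// VarNode list. A helper accepts exactly one input arity and returns
// nullptr otherwise, so a load/dump implementation can chain several makers
// to cover operators whose input count differs across serialized versions.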
template <class MegDNNPooling = megdnn::Pooling>
struct MakePoolingCaller1 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 1) {
            return Opr::make(inputs[0], param, execution_policy, config).node();
        }
        return nullptr;
    }
};

template <class MegDNNROIALIGN = megdnn::ROIAlign>
struct MakeROIAlignCaller1 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNROIALIGN::Param& param,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, config).node();
        } else {
            return nullptr;
        }
    }
};

template <class MegDNNROIALIGN = megdnn::ROIAlignBackward>
struct MakeROIAlignCaller4 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNROIALIGN::Param& param,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 4) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param, config)
                    .node();
        } else {
            return nullptr;
        }
    }
};

template <class MegDNNPooling = megdnn::PoolingBackward>
struct MakePoolingBackwardCaller3 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 3) {
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], param, execution_policy,
                           config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNPooling = megdnn::AdaptivePoolingBackward>
struct MakeAdaptivePoolingBackwardCaller3 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 4) {
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param, config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller2 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller3 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 3) {
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], param, execution_policy,
                           config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller4 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 4) {
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], inputs[3], param,
                           execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller5 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 5) {
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], param,
                           execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCallerEmpty {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray&, const typename MegDNNConv::Param&,
            const megdnn::param::ExecutionPolicy&, const OperatorNodeConfig&) {
        return nullptr;
    }
};

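// Load/dump support for convolution-like operators: dump() writes the param
// followed by the execution policy; make() tries Maker0, Maker1 and Maker2
// in turn until one accepts the input arity, and asserts that one succeeds.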
template <
        class Opr, class Maker0, class MegDNNConv,
        class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
        class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
        typename ConvParam = megdnn::param::Convolution>
struct ConvLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<ConvParam>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(
                opr.execution_policy_transient());
    }

    static VarNode* make(
            const cg::VarNodeArray& inputs, const ConvParam& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        VarNode* ret =
                Maker0::template make<Opr>(inputs, param, execution_policy, config);
        if (!ret) {
            ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
        }
        if (!ret) {
            ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
        }
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto param = ctx.read_param<ConvParam>();
        auto execution_policy = ctx.read_param<megdnn::param::ExecutionPolicy>();
        return make(inputs, param, execution_policy, config)->owner_opr();
    }
};

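// Pooling serializes only its param; no execution policy is written, so
// load() passes a default-constructed policy ({}) when rebuilding the opr.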
template <class Opr, class Maker0, typename PoolingParam = megdnn::param::Pooling>
struct PoolingLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<PoolingParam>(opr.param());
    }

    static VarNode* make(
            const cg::VarNodeArray& inputs, const PoolingParam& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        VarNode* ret =
                Maker0::template make<Opr>(inputs, param, execution_policy, config);
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto param = ctx.read_param<PoolingParam>();
        return make(inputs, param, {}, config)->owner_opr();
    }
};

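// Generic load/dump for operators fully described by a single param, with
// no execution policy. As a sketch, a hypothetical param-only operator Foo
// (opr::Foo and MakeFooCaller are placeholders, not existing types) would
// be wired up as:
//
//   template <>
//   struct OprLoadDumpImpl<opr::Foo, 0>
//           : public GeneralOprLoadDumpImpl<
//                     opr::Foo, MakeFooCaller<megdnn::Foo>,
//                     megdnn::param::Foo> {};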
template <class Opr, class Maker0, typename GeneralOprParam = megdnn::param::ROIAlign>
struct GeneralOprLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<GeneralOprParam>(opr.param());
    }

    static VarNode* make(
            const cg::VarNodeArray& inputs, const GeneralOprParam& param,
            const OperatorNodeConfig& config) {
        VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto param = ctx.read_param<GeneralOprParam>();
        return make(inputs, param, config)->owner_opr();
    }
};

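// The default OprMaker only handles single-output operators, so backward
// oprs with multiple outputs get explicit specializations below: each one
// picks output [0] and returns its owner operator, which owns all outputs.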
template <>
struct OprMaker<opr::TQTBackward, 3> {
    using Param = opr::TQTBackward::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::TQTBackward::make(i[0], i[1], i[2], param, config)[0]
                .node()
                ->owner_opr();
    }
};

template <>
struct OprMaker<opr::LSQBackward, 5> {
    using Param = opr::LSQBackward::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::LSQBackward::make(i[0], i[1], i[2], i[3], i[4], param, config)[0]
                .node()
                ->owner_opr();
    }
};

template <>
struct OprMaker<opr::RNNBackward, 7> {
    using Param = opr::RNNBackward::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::RNNBackward::make(
                       i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config)[0]
                .node()
                ->owner_opr();
    }
};

template <>
struct OprMaker<opr::LSTMBackward, 9> {
    using Param = opr::LSTMBackward::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::LSTMBackward::make(
                       i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7], i[8], param,
                       config)[0]
                .node()
                ->owner_opr();
    }
};

template <>
struct OprMaker<opr::SoftmaxBackward, 2> {
    using Param = opr::SoftmaxBackward::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::SoftmaxBackward::make(i[0], i[1], param, config)
                .node()
                ->owner_opr();
    }
};

template <>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
        : public GeneralOprLoadDumpImpl<
                  opr::AdaptivePoolingBackward,
                  MakeAdaptivePoolingBackwardCaller3<megdnn::AdaptivePoolingBackward>,
                  megdnn::param::AdaptivePooling> {};

template <>
struct OprLoadDumpImpl<opr::AdaptivePooling, 0>
        : public GeneralOprLoadDumpImpl<
                  opr::AdaptivePooling, MakeROIAlignCaller1<megdnn::AdaptivePooling>,
                  megdnn::param::AdaptivePooling> {};

template <>
struct OprLoadDumpImpl<opr::ROIAlign, 0>
        : public GeneralOprLoadDumpImpl<
                  opr::ROIAlign, MakeROIAlignCaller1<megdnn::ROIAlign>,
                  megdnn::param::ROIAlign> {};

template <>
struct OprLoadDumpImpl<opr::ROIAlignBackward, 0>
        : public GeneralOprLoadDumpImpl<
                  opr::ROIAlignBackward, MakeROIAlignCaller4<megdnn::ROIAlignBackward>,
                  megdnn::param::ROIAlign> {};

template <>
struct OprLoadDumpImpl<opr::Pooling, 0>
        : public PoolingLoadDumpImpl<
                  opr::Pooling, MakePoolingCaller1<megdnn::Pooling>,
                  megdnn::param::Pooling> {};

template <>
struct OprLoadDumpImpl<opr::PoolingBackward, 0>
        : public PoolingLoadDumpImpl<
                  opr::PoolingBackward,
                  MakePoolingBackwardCaller3<megdnn::PoolingBackward>,
                  megdnn::param::Pooling> {};

template <>
struct OprLoadDumpImpl<opr::Convolution, 0>
        : public ConvLoadDumpImpl<
                  opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
                  megdnn::Convolution> {};
template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardData, 0>
        : public ConvLoadDumpImpl<
                  opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
                  megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardFilter, 0>
        : public ConvLoadDumpImpl<
                  opr::ConvolutionBackwardFilter, MakeConvCaller3<megdnn::Convolution>,
                  megdnn::Convolution> {};

template <>
struct OprLoadDumpImpl<opr::Convolution3D, 0>
        : public ConvLoadDumpImpl<
                  opr::Convolution3D, MakeConvCaller2<megdnn::Convolution3D>,
                  megdnn::Convolution3D, MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardData, 0>
        : public ConvLoadDumpImpl<
                  opr::Convolution3DBackwardData,
                  MakeConvCaller2<megdnn::Convolution3D>, megdnn::Convolution3D,
                  MakeConvCaller3<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0>
        : public ConvLoadDumpImpl<
                  opr::Convolution3DBackwardFilter,
                  MakeConvCaller3<megdnn::Convolution3D>, megdnn::Convolution3D,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
template <>
struct OprLoadDumpImpl<opr::ConvBiasForward, 0>
        : public ConvLoadDumpImpl<
                  opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
                  megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
                  MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
template <>
struct OprLoadDumpImpl<opr::BatchConvBiasForward, 0>
        : public ConvLoadDumpImpl<
                  opr::BatchConvBiasForward,
                  MakeConvCaller2<megdnn::BatchConvBiasForward>,
                  megdnn::BatchConvBiasForward,
                  MakeConvCaller3<megdnn::BatchConvBiasForward>,
                  MakeConvCaller4<megdnn::BatchConvBiasForward>,
                  megdnn::param::BatchConvBias> {};

template <>
struct OprMaker<opr::BatchNorm, 0> {
    using Param = opr::BatchNorm::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (i.size() == 3) {
            return opr::BatchNorm::make(i[0], i[1], i[2], param, config)[0]
                    .node()
                    ->owner_opr();
        } else {
            mgb_assert(i.size() == 5);
            return opr::BatchNorm::make(i[0], i[1], i[2], i[3], i[4], param, config)[0]
                    .node()
                    ->owner_opr();
        }
    }
};

// OprMaker in MGB_SEREG_OPR only supports oprs with a single output
template <>
struct OprMaker<opr::BatchNormBackward, 6> {
    using Param = opr::BatchNormBackward::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::BatchNormBackward::make(
                       i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
                .node()
                ->owner_opr();
    }
};

template <>
struct OprMaker<opr::LayerNorm, 0> {
    using Param = opr::LayerNorm::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (i.size() == 3) {
            return opr::LayerNorm::make(i[0], i[1], i[2], param, config)[0]
                    .node()
                    ->owner_opr();
        } else {
            mgb_assert(i.size() == 1);
            return opr::LayerNorm::make(i[0], param, config)[0].node()->owner_opr();
        }
    }
};

// OprMaker in MGB_SEREG_OPR only supports oprs with a single output
template <>
struct OprMaker<opr::LayerNormBackward, 0> {
    using Param = opr::LayerNormBackward::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (i.size() == 5) {
            return opr::LayerNormBackward::make(
                           i[0], i[1], i[2], i[3], i[4], param, config)[0]
                    .node()
                    ->owner_opr();
        } else {
            mgb_assert(i.size() == 4);
            return opr::LayerNormBackward::make(
                           i[0], i[1], i[2], i[3], param, config)[0]
                    .node()
                    ->owner_opr();
        }
    }
};

template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller2 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, execution_policy, config)
                    .node();
        }
        return nullptr;
    }
};
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller3 {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        if (inputs.size() == 3) {
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], param, execution_policy,
                           config)
                    .node();
        }
        return nullptr;
    }
};
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCallerEmpty {
    template <typename Opr>
    static VarNode* make(
            const cg::VarNodeArray&, const typename MegDNNConv::Param&,
            const megdnn::param::ExecutionPolicy&, const OperatorNodeConfig&) {
        return nullptr;
    }
};

template <
        class Opr, class Maker0, class MegDNNConv,
        class Maker1 = MakeLocalShareCallerEmpty<MegDNNConv>,
        class Maker2 = MakeLocalShareCallerEmpty<MegDNNConv>,
        typename LocalShareParam = megdnn::param::LocalShare>
struct LocalShareLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<LocalShareParam>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
    }

    static VarNode* make(
            const cg::VarNodeArray& inputs, const LocalShareParam& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        VarNode* ret =
                Maker0::template make<Opr>(inputs, param, execution_policy, config);
        if (!ret) {
            ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
        }
        if (!ret) {
            ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
        }
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto param = ctx.read_param<LocalShareParam>();
        auto execution_policy = ctx.read_param<megdnn::param::ExecutionPolicy>();
        return make(inputs, param, execution_policy, config)->owner_opr();
    }
};

template <>
struct OprLoadDumpImpl<opr::LocalShare, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
                  megdnn::LocalShare> {};
template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardData, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShareBackwardData,
                  MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare> {};
template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardFilter, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShareBackwardFilter,
                  MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare> {};
template <>
struct OprLoadDumpImpl<opr::DeformableConvForward, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvForward,
                  MakeConvCaller4<megdnn::DeformableConvForward>, megdnn::Convolution> {
};
template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardData, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvBackwardData,
                  MakeConvCaller5<megdnn::DeformableConvBackwardData>,
                  megdnn::Convolution> {};
template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvBackwardFilter,
                  MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
                  megdnn::Convolution> {};

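// Shallow copy for conv-like operators: reuse the corresponding
// OprLoadDumpImpl::make() so the copy keeps the original param and the
// transient execution policy, instead of a full dump/load round trip.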
template <typename Opr>
cg::OperatorNodeBase* opr_shallow_copy_conv(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);
    auto&& opr = opr_.cast_final_safe<Opr>();
    return OprLoadDumpImpl<Opr, 0>::make(
                   inputs, opr.param(), opr.execution_policy_transient(), config)
            ->owner_opr();
}

}  // namespace serialization

namespace opr {
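// The identifier passed to MGB_SEREG_OPR becomes the serialized operator
// name, so the version-suffixed aliases below (e.g. ConvolutionV2) pin the
// on-disk name when a param layout changes. The numeric argument is the
// expected input arity; 0 means the input count is not fixed and is
// handled by the loader instead.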
using ConvolutionV2 = Convolution;
using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionV2, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionBackwardDataV2, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        ConvolutionBackwardFilterV2, 0, opr_shallow_copy_conv);

MGB_SEREG_OPR(Images2Neibs, 1);
MGB_SEREG_OPR(Images2NeibsBackward, 2);

MGB_SEREG_OPR(SlidingWindowTranspose, 1);
MGB_SEREG_OPR(SlidingWindowTransposeBackward, 2);

using LocalV2 = Local;
using LocalBackwardDataV2 = LocalBackwardData;
using LocalBackwardFilterV2 = LocalBackwardFilter;
MGB_SEREG_OPR(LocalV2, 2);
MGB_SEREG_OPR(LocalBackwardDataV2, 3);
MGB_SEREG_OPR(LocalBackwardFilterV2, 3);

using GroupLocalV2 = GroupLocal;
using GroupLocalBackwardDataV2 = GroupLocalBackwardData;
using GroupLocalBackwardFilterV2 = GroupLocalBackwardFilter;
MGB_SEREG_OPR(GroupLocalV2, 2);
MGB_SEREG_OPR(GroupLocalBackwardDataV2, 3);
MGB_SEREG_OPR(GroupLocalBackwardFilterV2, 3);

MGB_SEREG_OPR(LRN, 1);
MGB_SEREG_OPR(LRNBackward, 3);
using PoolingV1 = Pooling;
using PoolingBackwardV1 = PoolingBackward;
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingBackwardV1, 0, opr_shallow_copy_conv);
using AdaptivePoolingV1 = AdaptivePooling;
using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
MGB_SEREG_OPR(AdaptivePoolingV1, 2);
MGB_SEREG_OPR(AdaptivePoolingBackwardV1, 4);

MGB_SEREG_OPR(ROIPooling, 3);
MGB_SEREG_OPR(ROIPoolingBackward, 4);

using MaskConvolutionV2 = MaskConvolution;
MGB_SEREG_OPR(MaskConvolutionV2, 3);
MGB_SEREG_OPR(MaskPropagate, 1);

MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3D, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3DBackwardData, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        Convolution3DBackwardFilter, 0, opr_shallow_copy_conv);

using ConvBiasForwardV4 = ConvBiasForward;
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvBiasForwardV4, 0, opr_shallow_copy_conv);

using BatchNormV1 = BatchNorm;
using BatchNormBackwardV1 = BatchNormBackward;
MGB_SEREG_OPR(BatchNormV1, 0);
MGB_SEREG_OPR(BatchNormBackwardV1, 6);

using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;
using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareForwardV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareBackwardDataV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        LocalShareBackwardFilterV1, 0, opr_shallow_copy_conv);

using ROIAlignV1 = ROIAlign;
using ROIAlignBackwardV1 = ROIAlignBackward;
MGB_SEREG_OPR(ROIAlignV1, 2);
MGB_SEREG_OPR(ROIAlignBackwardV1, 4);
using DeformableConvForwardV1 = DeformableConvForward;
using DeformableConvBackwardDataV1 = DeformableConvBackwardData;
using DeformableConvBackwardFilterV1 = DeformableConvBackwardFilter;
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(DeformableConvForwardV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        DeformableConvBackwardDataV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        DeformableConvBackwardFilterV1, 0, opr_shallow_copy_conv);

MGB_SEREG_OPR(CorrelationForward, 2);
MGB_SEREG_OPR(CorrelationBackwardData1, 3);
MGB_SEREG_OPR(CorrelationBackwardData2, 3);

MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);

using BatchConvBiasForwardV1 = BatchConvBiasForward;
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchConvBiasForwardV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);
MGB_SEREG_OPR(LSQ, 4);
MGB_SEREG_OPR(LSQBackward, 5);
MGB_SEREG_OPR(LayerNorm, 0);
MGB_SEREG_OPR(LayerNormBackward, 0);
MGB_SEREG_OPR(RNNForward, 3);
MGB_SEREG_OPR(RNNBackward, 7);
MGB_SEREG_OPR(LSTMForward, 4);
MGB_SEREG_OPR(LSTMBackward, 9);
MGB_SEREG_OPR(Softmax, 1);
MGB_SEREG_OPR(SoftmaxBackward, 2);
}  // namespace opr

}  // namespace mgb

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}