/**
 * \file src/opr/impl/dnn/dnn.sereg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"

namespace mgb {

namespace serialization {
34 35 36
template <class MegDNNPooling = megdnn::Pooling>
struct MakePoolingCaller1 {
    template <typename Opr>
M
Megvii Engine Team 已提交
37 38
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
39
            const megdnn::param::ExecutionPolicy& execution_policy,
M
Megvii Engine Team 已提交
40
            const OperatorNodeConfig& config) {
41
        if (inputs.size() == 1) {
42
            return Opr::make(inputs[0], param, execution_policy, config).node();
43
        }
44 45 46 47 48 49 50
        return nullptr;
    }
};

template <class MegDNNROIALIGN = megdnn::ROIAlign>
struct MakeROIAlignCaller1 {
    template <typename Opr>
M
Megvii Engine Team 已提交
51 52 53
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNROIALIGN::Param& param,
            const OperatorNodeConfig& config) {
54 55 56
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, config).node();
        } else {
57 58
            return nullptr;
        }
59 60 61 62 63 64
    }
};

template <class MegDNNROIALIGN = megdnn::ROIAlignBackward>
struct MakeROIAlignCaller4 {
    template <typename Opr>
M
Megvii Engine Team 已提交
65 66 67
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNROIALIGN::Param& param,
            const OperatorNodeConfig& config) {
68
        if (inputs.size() == 4) {
M
Megvii Engine Team 已提交
69
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param, config)
70 71
                    .node();
        } else {
72 73
            return nullptr;
        }
74 75 76 77 78 79
    }
};

template <class MegDNNPooling = megdnn::PoolingBackward>
struct MakePoolingBackwardCaller3 {
    template <typename Opr>
M
Megvii Engine Team 已提交
80 81
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
82
            const megdnn::param::ExecutionPolicy& execution_policy,
M
Megvii Engine Team 已提交
83
            const OperatorNodeConfig& config) {
84
        if (inputs.size() == 3) {
85 86 87 88
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], param, execution_policy,
                           config)
                    .node();
89
        }
90 91 92 93 94 95 96
        return nullptr;
    }
};

template <class MegDNNPooling = megdnn::AdaptivePoolingBackward>
struct MakeAdaptivePoolingBackwardCaller3 {
    template <typename Opr>
M
Megvii Engine Team 已提交
97 98 99
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
            const OperatorNodeConfig& config) {
100
        if (inputs.size() == 4) {
M
Megvii Engine Team 已提交
101
            return Opr::make(inputs[0], inputs[1], inputs[2], inputs[3], param, config)
102
                    .node();
103
        }
104 105 106 107 108 109 110
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller2 {
    template <typename Opr>
M
Megvii Engine Team 已提交
111 112 113 114
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
115
        if (inputs.size() == 2) {
M
Megvii Engine Team 已提交
116
            return Opr::make(inputs[0], inputs[1], param, execution_policy, config)
117
                    .node();
118
        }
119 120 121 122 123 124 125
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller3 {
    template <typename Opr>
M
Megvii Engine Team 已提交
126 127 128 129
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
130
        if (inputs.size() == 3) {
M
Megvii Engine Team 已提交
131 132 133
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], param, execution_policy,
                           config)
134
                    .node();
135
        }
136 137 138 139 140 141 142
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller4 {
    template <typename Opr>
M
Megvii Engine Team 已提交
143 144 145 146
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
147
        if (inputs.size() == 4) {
M
Megvii Engine Team 已提交
148 149 150
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], inputs[3], param,
                           execution_policy, config)
151
                    .node();
152
        }
153 154 155 156 157 158 159
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCaller5 {
    template <typename Opr>
M
Megvii Engine Team 已提交
160 161 162 163
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
164
        if (inputs.size() == 5) {
M
Megvii Engine Team 已提交
165 166 167
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], param,
                           execution_policy, config)
168
                    .node();
169
        }
170 171 172 173 174 175 176
        return nullptr;
    }
};

template <class MegDNNConv = megdnn::Convolution>
struct MakeConvCallerEmpty {
    template <typename Opr>
M
Megvii Engine Team 已提交
177 178 179
    static VarNode* make(
            const cg::VarNodeArray&, const typename MegDNNConv::Param&,
            const megdnn::param::ExecutionPolicy&, const OperatorNodeConfig&) {
180 181 182 183
        return nullptr;
    }
};

M
Megvii Engine Team 已提交
184 185 186 187 188
template <
        class Opr, class Maker0, class MegDNNConv,
        class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
        class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
        typename ConvParam = megdnn::param::Convolution>
189 190 191 192
struct ConvLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<ConvParam>(opr.param());
M
Megvii Engine Team 已提交
193 194
        ctx.write_param<megdnn::param::ExecutionPolicy>(
                opr.execution_policy_transient());
195 196
    }

M
Megvii Engine Team 已提交
197 198 199 200 201 202
    static VarNode* make(
            const cg::VarNodeArray& inputs, const ConvParam& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        VarNode* ret =
                Maker0::template make<Opr>(inputs, param, execution_policy, config);
203
        if (!ret) {
M
Megvii Engine Team 已提交
204
            ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
205
        }
206
        if (!ret) {
M
Megvii Engine Team 已提交
207
            ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
M
Megvii Engine Team 已提交
208
        }
209 210 211 212
        mgb_assert(ret);
        return ret;
    }

M
Megvii Engine Team 已提交
213 214 215
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
216
        auto param = ctx.read_param<ConvParam>();
M
Megvii Engine Team 已提交
217
        auto execution_policy = ctx.read_param<megdnn::param::ExecutionPolicy>();
218 219 220 221
        return make(inputs, param, execution_policy, config)->owner_opr();
    }
};

M
Megvii Engine Team 已提交
222
template <class Opr, class Maker0, typename PoolingParam = megdnn::param::Pooling>
223 224 225 226 227 228
struct PoolingLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<PoolingParam>(opr.param());
    }

M
Megvii Engine Team 已提交
229 230
    static VarNode* make(
            const cg::VarNodeArray& inputs, const PoolingParam& param,
231
            const megdnn::param::ExecutionPolicy& execution_policy,
M
Megvii Engine Team 已提交
232
            const OperatorNodeConfig& config) {
233 234
        VarNode* ret =
                Maker0::template make<Opr>(inputs, param, execution_policy, config);
235 236 237 238
        mgb_assert(ret);
        return ret;
    }

M
Megvii Engine Team 已提交
239 240 241
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
242
        auto param = ctx.read_param<PoolingParam>();
243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265
        return make(inputs, param, {}, config)->owner_opr();
    }
};

//! Generic serialization implementation for oprs that carry a single param
//! and no execution policy (ROIAlign, AdaptivePooling, ...).
template <class Opr, class Maker0, typename GeneralOprParam = megdnn::param::ROIAlign>
struct GeneralOprLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<GeneralOprParam>(opr.param());
    }

    static VarNode* make(
            const cg::VarNodeArray& inputs, const GeneralOprParam& param,
            const OperatorNodeConfig& config) {
        VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
        mgb_assert(ret);
        return ret;
    }

    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto param = ctx.read_param<GeneralOprParam>();
        return make(inputs, param, config)->owner_opr();
    }
};

template <>
struct OprMaker<opr::TQTBackward, 3> {
    using Param = opr::TQTBackward::Param;
M
Megvii Engine Team 已提交
273 274 275
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
276 277 278 279 280 281 282
        MGB_MARK_USED_VAR(graph);
        return opr::TQTBackward::make(i[0], i[1], i[2], param, config)[0]
                .node()
                ->owner_opr();
    }
};

M
Megvii Engine Team 已提交
283 284 285
template <>
struct OprMaker<opr::LSQBackward, 5> {
    using Param = opr::LSQBackward::Param;
M
Megvii Engine Team 已提交
286 287 288
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
M
Megvii Engine Team 已提交
289
        MGB_MARK_USED_VAR(graph);
M
Megvii Engine Team 已提交
290
        return opr::LSQBackward::make(i[0], i[1], i[2], i[3], i[4], param, config)[0]
M
Megvii Engine Team 已提交
291 292 293 294
                .node()
                ->owner_opr();
    }
};
295 296
template <>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
297
        : public GeneralOprLoadDumpImpl<
M
Megvii Engine Team 已提交
298 299 300
                  opr::AdaptivePoolingBackward,
                  MakeAdaptivePoolingBackwardCaller3<megdnn::AdaptivePoolingBackward>,
                  megdnn::param::AdaptivePooling> {};
301 302 303

template <>
struct OprLoadDumpImpl<opr::AdaptivePooling, 0>
304
        : public GeneralOprLoadDumpImpl<
M
Megvii Engine Team 已提交
305
                  opr::AdaptivePooling, MakeROIAlignCaller1<megdnn::AdaptivePooling>,
306 307 308 309
                  megdnn::param::AdaptivePooling> {};

template <>
struct OprLoadDumpImpl<opr::ROIAlign, 0>
310
        : public GeneralOprLoadDumpImpl<
M
Megvii Engine Team 已提交
311 312
                  opr::ROIAlign, MakeROIAlignCaller1<megdnn::ROIAlign>,
                  megdnn::param::ROIAlign> {};
313 314 315

template <>
struct OprLoadDumpImpl<opr::ROIAlignBackward, 0>
316
        : public GeneralOprLoadDumpImpl<
M
Megvii Engine Team 已提交
317
                  opr::ROIAlignBackward, MakeROIAlignCaller4<megdnn::ROIAlignBackward>,
318 319 320 321
                  megdnn::param::ROIAlign> {};

template <>
struct OprLoadDumpImpl<opr::Pooling, 0>
M
Megvii Engine Team 已提交
322 323 324
        : public PoolingLoadDumpImpl<
                  opr::Pooling, MakePoolingCaller1<megdnn::Pooling>,
                  megdnn::param::Pooling> {};
325 326 327 328 329 330 331 332 333 334

template <>
struct OprLoadDumpImpl<opr::PoolingBackward, 0>
        : public PoolingLoadDumpImpl<
                  opr::PoolingBackward,
                  MakePoolingBackwardCaller3<megdnn::PoolingBackward>,
                  megdnn::param::Pooling> {};

template <>
struct OprLoadDumpImpl<opr::Convolution, 0>
M
Megvii Engine Team 已提交
335 336 337
        : public ConvLoadDumpImpl<
                  opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
                  megdnn::Convolution> {};
338 339
template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardData, 0>
M
Megvii Engine Team 已提交
340 341 342
        : public ConvLoadDumpImpl<
                  opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
                  megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
343 344
template <>
struct OprLoadDumpImpl<opr::ConvolutionBackwardFilter, 0>
M
Megvii Engine Team 已提交
345 346 347
        : public ConvLoadDumpImpl<
                  opr::ConvolutionBackwardFilter, MakeConvCaller3<megdnn::Convolution>,
                  megdnn::Convolution> {};
348 349 350

template <>
struct OprLoadDumpImpl<opr::Convolution3D, 0>
M
Megvii Engine Team 已提交
351 352 353 354 355
        : public ConvLoadDumpImpl<
                  opr::Convolution3D, MakeConvCaller2<megdnn::Convolution3D>,
                  megdnn::Convolution3D, MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
356 357
template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardData, 0>
M
Megvii Engine Team 已提交
358 359 360 361 362 363
        : public ConvLoadDumpImpl<
                  opr::Convolution3DBackwardData,
                  MakeConvCaller2<megdnn::Convolution3D>, megdnn::Convolution3D,
                  MakeConvCaller3<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
364 365
template <>
struct OprLoadDumpImpl<opr::Convolution3DBackwardFilter, 0>
M
Megvii Engine Team 已提交
366 367 368 369 370 371
        : public ConvLoadDumpImpl<
                  opr::Convolution3DBackwardFilter,
                  MakeConvCaller3<megdnn::Convolution3D>, megdnn::Convolution3D,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  MakeConvCallerEmpty<megdnn::Convolution3D>,
                  megdnn::param::Convolution3D> {};
372 373
template <>
struct OprLoadDumpImpl<opr::ConvBiasForward, 0>
M
Megvii Engine Team 已提交
374 375 376 377
        : public ConvLoadDumpImpl<
                  opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
                  megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
                  MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
378 379
template <>
struct OprLoadDumpImpl<opr::BatchConvBiasForward, 0>
M
Megvii Engine Team 已提交
380 381 382 383 384 385 386
        : public ConvLoadDumpImpl<
                  opr::BatchConvBiasForward,
                  MakeConvCaller2<megdnn::BatchConvBiasForward>,
                  megdnn::BatchConvBiasForward,
                  MakeConvCaller3<megdnn::BatchConvBiasForward>,
                  MakeConvCaller4<megdnn::BatchConvBiasForward>,
                  megdnn::param::BatchConvBias> {};
387 388 389 390

template <>
struct OprMaker<opr::BatchNorm, 0> {
    using Param = opr::BatchNorm::Param;
M
Megvii Engine Team 已提交
391 392 393
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
394 395 396 397 398 399 400
        MGB_MARK_USED_VAR(graph);
        if (i.size() == 3) {
            return opr::BatchNorm::make(i[0], i[1], i[2], param, config)[0]
                    .node()
                    ->owner_opr();
        } else {
            mgb_assert(i.size() == 5);
M
Megvii Engine Team 已提交
401
            return opr::BatchNorm::make(i[0], i[1], i[2], i[3], i[4], param, config)[0]
402 403
                    .node()
                    ->owner_opr();
404
        }
405 406 407
    }
};

408
// OprMaker in MGB_SEREG_OPR only support unique output opr
409
template <>
410
struct OprMaker<opr::BatchNormBackward, 6> {
411
    using Param = opr::BatchNormBackward::Param;
M
Megvii Engine Team 已提交
412 413 414
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
415
        MGB_MARK_USED_VAR(graph);
M
Megvii Engine Team 已提交
416 417
        return opr::BatchNormBackward::make(
                       i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
418 419 420 421 422 423 424 425
                .node()
                ->owner_opr();
    }
};

template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller2 {
    template <typename Opr>
M
Megvii Engine Team 已提交
426 427 428 429
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
430
        if (inputs.size() == 2) {
M
Megvii Engine Team 已提交
431
            return Opr::make(inputs[0], inputs[1], param, execution_policy, config)
432
                    .node();
433
        }
434 435 436 437 438 439
        return nullptr;
    }
};
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller3 {
    template <typename Opr>
M
Megvii Engine Team 已提交
440 441 442 443
    static VarNode* make(
            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
444
        if (inputs.size() == 3) {
M
Megvii Engine Team 已提交
445 446 447
            return Opr::make(
                           inputs[0], inputs[1], inputs[2], param, execution_policy,
                           config)
448
                    .node();
449
        }
450 451 452 453 454 455
        return nullptr;
    }
};
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCallerEmpty {
    template <typename Opr>
M
Megvii Engine Team 已提交
456 457 458
    static VarNode* make(
            const cg::VarNodeArray&, const typename MegDNNConv::Param&,
            const megdnn::param::ExecutionPolicy&, const OperatorNodeConfig&) {
459 460 461 462
        return nullptr;
    }
};

M
Megvii Engine Team 已提交
463 464 465 466 467
template <
        class Opr, class Maker0, class MegDNNConv,
        class Maker1 = MakeLocalShareCallerEmpty<MegDNNConv>,
        class Maker2 = MakeLocalShareCallerEmpty<MegDNNConv>,
        typename LocalShareParam = megdnn::param::LocalShare>
468 469 470 471 472 473 474
struct LocalShareLoadDumpImpl {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<LocalShareParam>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
    }

M
Megvii Engine Team 已提交
475 476 477 478 479 480
    static VarNode* make(
            const cg::VarNodeArray& inputs, const LocalShareParam& param,
            const megdnn::param::ExecutionPolicy& execution_policy,
            const OperatorNodeConfig& config) {
        VarNode* ret =
                Maker0::template make<Opr>(inputs, param, execution_policy, config);
481
        if (!ret) {
M
Megvii Engine Team 已提交
482
            ret = Maker1::template make<Opr>(inputs, param, execution_policy, config);
483
        }
484
        if (!ret) {
M
Megvii Engine Team 已提交
485
            ret = Maker2::template make<Opr>(inputs, param, execution_policy, config);
486
        }
487 488 489 490
        mgb_assert(ret);
        return ret;
    }

M
Megvii Engine Team 已提交
491 492 493
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
494
        auto param = ctx.read_param<LocalShareParam>();
M
Megvii Engine Team 已提交
495
        auto execution_policy = ctx.read_param<megdnn::param::ExecutionPolicy>();
496 497 498 499 500 501 502 503 504 505 506 507 508
        return make(inputs, param, execution_policy, config)->owner_opr();
    }
};

//! bind local-share and deformable-conv oprs to their load/dump impls
template <>
struct OprLoadDumpImpl<opr::LocalShare, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShare, MakeLocalShareCaller2<megdnn::LocalShare>,
                  megdnn::LocalShare> {};
template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardData, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShareBackwardData,
                  MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare> {};
template <>
struct OprLoadDumpImpl<opr::LocalShareBackwardFilter, 0>
        : public LocalShareLoadDumpImpl<
                  opr::LocalShareBackwardFilter,
                  MakeLocalShareCaller3<megdnn::LocalShare>, megdnn::LocalShare> {};
template <>
struct OprLoadDumpImpl<opr::DeformableConvForward, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvForward,
                  MakeConvCaller4<megdnn::DeformableConvForward>, megdnn::Convolution> {
};
template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardData, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvBackwardData,
                  MakeConvCaller5<megdnn::DeformableConvBackwardData>,
                  megdnn::Convolution> {};
template <>
struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0>
        : public ConvLoadDumpImpl<
                  opr::DeformableConvBackwardFilter,
                  MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
                  megdnn::Convolution> {};
533 534 535 536 537 538 539 540 541 542 543 544 545

template <typename Opr>
cg::OperatorNodeBase* opr_shallow_copy_conv(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);
    auto&& opr = opr_.cast_final_safe<Opr>();
    return OprLoadDumpImpl<Opr, 0>::make(
                   inputs, opr.param(), opr.execution_policy_transient(), config)
            ->owner_opr();
}

}  // namespace serialization

namespace opr {
549 550 551
using ConvolutionV2 = Convolution;
using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
552 553 554 555
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionV2, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionBackwardDataV2, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        ConvolutionBackwardFilterV2, 0, opr_shallow_copy_conv);
556 557 558 559

MGB_SEREG_OPR(Images2Neibs, 1);
MGB_SEREG_OPR(Images2NeibsBackward, 2);

560 561 562
MGB_SEREG_OPR(SlidingWindowTranspose, 1);
MGB_SEREG_OPR(SlidingWindowTransposeBackward, 2);

563 564 565 566 567 568 569 570 571 572 573 574 575
using LocalV2 = Local;
using LocalBackwardDataV2 = LocalBackwardData;
using LocalBackwardFilterV2 = LocalBackwardFilter;
MGB_SEREG_OPR(LocalV2, 2);
MGB_SEREG_OPR(LocalBackwardDataV2, 3);
MGB_SEREG_OPR(LocalBackwardFilterV2, 3);

using GroupLocalV2 = GroupLocal;
using GroupLocalBackwardDataV2 = GroupLocalBackwardData;
using GroupLocalBackwardFilterV2 = GroupLocalBackwardFilter;
MGB_SEREG_OPR(GroupLocalV2, 2);
MGB_SEREG_OPR(GroupLocalBackwardDataV2, 3);
MGB_SEREG_OPR(GroupLocalBackwardFilterV2, 3);
576 577 578 579 580

MGB_SEREG_OPR(LRN, 1);
MGB_SEREG_OPR(LRNBackward, 3);
using PoolingV1 = Pooling;
using PoolingBackwardV1 = PoolingBackward;
581 582
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingBackwardV1, 0, opr_shallow_copy_conv);
583 584 585 586 587 588 589 590
using AdaptivePoolingV1 = AdaptivePooling;
using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
MGB_SEREG_OPR(AdaptivePoolingV1, 2);
MGB_SEREG_OPR(AdaptivePoolingBackwardV1, 4);

MGB_SEREG_OPR(ROIPooling, 3);
MGB_SEREG_OPR(ROIPoolingBackward, 4);

591 592
using MaskConvolutionV2 = MaskConvolution;
MGB_SEREG_OPR(MaskConvolutionV2, 3);
593 594
MGB_SEREG_OPR(MaskPropagate, 1);

595 596 597 598
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3D, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3DBackwardData, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        Convolution3DBackwardFilter, 0, opr_shallow_copy_conv);
599 600

using ConvBiasForwardV4 = ConvBiasForward;
601
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvBiasForwardV4, 0, opr_shallow_copy_conv);
602

603 604 605 606
using BatchNormV1 = BatchNorm;
using BatchNormBackwardV1 = BatchNormBackward;
MGB_SEREG_OPR(BatchNormV1, 0);
MGB_SEREG_OPR(BatchNormBackwardV1, 6);
607 608 609 610

using LocalShareForwardV1 = LocalShareForward;
using LocalShareBackwardDataV1 = LocalShareBackwardData;
using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
611 612 613 614
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareForwardV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareBackwardDataV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        LocalShareBackwardFilterV1, 0, opr_shallow_copy_conv);
615 616 617 618 619

using ROIAlignV1 = ROIAlign;
using ROIAlignBackwardV1 = ROIAlignBackward;
MGB_SEREG_OPR(ROIAlignV1, 2);
MGB_SEREG_OPR(ROIAlignBackwardV1, 4);
620 621 622
using DeformableConvForwardV1 = DeformableConvForward;
using DeformableConvBackwardDataV1 = DeformableConvBackwardData;
using DeformableConvBackwardFilterV1 = DeformableConvBackwardFilter;
623 624 625 626 627
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(DeformableConvForwardV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        DeformableConvBackwardDataV1, 0, opr_shallow_copy_conv);
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
        DeformableConvBackwardFilterV1, 0, opr_shallow_copy_conv);
628

629 630 631 632
MGB_SEREG_OPR(CorrelationForward, 2);
MGB_SEREG_OPR(CorrelationBackwardData1, 3);
MGB_SEREG_OPR(CorrelationBackwardData2, 3);

633 634 635 636
MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);

using BatchConvBiasForwardV1 = BatchConvBiasForward;
637
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchConvBiasForwardV1, 0, opr_shallow_copy_conv);
638 639 640 641
MGB_SEREG_OPR(FakeQuant, 3);
MGB_SEREG_OPR(FakeQuantBackward, 4);
MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);
M
Megvii Engine Team 已提交
642 643
MGB_SEREG_OPR(LSQ, 4);
MGB_SEREG_OPR(LSQBackward, 5);
644 645 646
}  // namespace opr

}  // namespace mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}