specializations.cpp 23.2 KB
Newer Older
1
/**
2
 * \file imperative/src/impl/ops/specializations.cpp
3 4
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
6 7 8
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
M
Megvii Engine Team 已提交
9 10
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
11 12 13 14 15
 */

// FIXME: split this file into separate files for each specialized op

#include "megbrain/imperative/ops/autogen.h"
M
Megvii Engine Team 已提交
16 17
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
18
#include "megbrain/opr/dnn/adaptive_pooling.h"
M
Megvii Engine Team 已提交
19 20
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
21
#include "megbrain/opr/dnn/fake_quant.h"
M
Megvii Engine Team 已提交
22
#include "megbrain/opr/dnn/images2neibs.h"
23
#include "megbrain/opr/dnn/local.h"
M
Megvii Engine Team 已提交
24 25
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
26 27
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
M
Megvii Engine Team 已提交
28
#include "megbrain/opr/dnn/tqt.h"
29 30 31 32 33 34 35 36 37
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
38 39
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
40 41 42 43 44

#include "../op_trait.h"

namespace mgb::imperative {

M
Megvii Engine Team 已提交
45 46
namespace {
namespace dimshuffle {
47 48 49
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
    std::vector<int> pattern(node->param().pattern_len);
M
Megvii Engine Team 已提交
50
    for (size_t i = 0; i < node->param().pattern_len; ++i) {
51 52 53 54 55
        pattern[i] = node->param().pattern[i];
    }
    return Dimshuffle::make(pattern);
}

M
Megvii Engine Team 已提交
56
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
57
    auto&& ds = static_cast<const Dimshuffle&>(def);
58 59
    OperatorNodeConfig config{ds.make_name()};
    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
60 61 62
}

OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
M
Megvii Engine Team 已提交
63 64 65 66 67
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace dimshuffle
}  // namespace
68

M
Megvii Engine Team 已提交
69 70 71
namespace {
namespace add_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
72 73 74 75 76 77
    auto&& add_axis = static_cast<const AddAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : add_axis.axis) {
        param.push_back(Desc::make_add(i));
    }
78 79
    OperatorNodeConfig config{add_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
80 81
}

M
Megvii Engine Team 已提交
82 83 84
OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace add_axis
}  // namespace
85

M
Megvii Engine Team 已提交
86 87 88
namespace {
namespace remove_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
89 90 91 92 93 94
    auto&& remove_axis = static_cast<const RemoveAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : remove_axis.axis) {
        param.push_back(Desc::make_remove(i));
    }
95 96
    OperatorNodeConfig config{remove_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
97 98 99
}

OP_TRAIT_REG(RemoveAxis, RemoveAxis)
M
Megvii Engine Team 已提交
100 101 102 103
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace remove_axis
}  // namespace
104

M
Megvii Engine Team 已提交
105 106 107
namespace {
namespace top_k {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
108
    auto&& topk = static_cast<const TopK&>(def);
109 110
    OperatorNodeConfig config{topk.make_name()};
    return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
M
Megvii Engine Team 已提交
111 112
            .node()
            ->owner_opr();
113 114
}

M
Megvii Engine Team 已提交
115 116 117
OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace top_k
}  // namespace
118

M
Megvii Engine Team 已提交
119 120 121
namespace {
namespace reduce {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
122
    auto&& reduce = static_cast<const Reduce&>(def);
123
    OperatorNodeConfig config{reduce.make_name()};
124
    if (inputs.size() > 1) {
125
        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
126
    } else {
M
Megvii Engine Team 已提交
127 128
        return opr::Reduce::make(inputs[0], reduce.param(),
                                 (cg::VarNode*)nullptr, config);
129 130 131
    }
}

132 133 134 135 136 137
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Reduce>();
    return Reduce::make(node->param());
}

OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
M
Megvii Engine Team 已提交
138 139 140 141 142
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace reduce
}  // namespace
143

M
Megvii Engine Team 已提交
144 145 146
namespace {
namespace adaptive_pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
147
    auto&& pool = static_cast<const AdaptivePooling&>(def);
148
    OperatorNodeConfig config{pool.make_name()};
M
Megvii Engine Team 已提交
149 150
    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(),
                                      config);
151 152 153
}

OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
M
Megvii Engine Team 已提交
154 155 156 157
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace adaptive_pooling
}  // namespace
158

M
Megvii Engine Team 已提交
159 160 161
namespace {
namespace conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
162 163
    auto&& conv = static_cast<const ConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
164
    config.name(conv.make_name());
165
    if (inputs.size() == 2) {
M
Megvii Engine Team 已提交
166 167
        return opr::ConvBias::make(inputs[0], inputs[1], conv.param(),
                                   conv.policy(), config);
168
    } else if (inputs.size() == 3) {
M
Megvii Engine Team 已提交
169 170
        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2],
                                   conv.param(), conv.policy(), config);
171
    } else if (inputs.size() == 4) {
M
Megvii Engine Team 已提交
172 173
        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3],
                                   conv.param(), conv.policy(), config);
174 175 176 177 178
    }
    mgb_assert(0);
}

OP_TRAIT_REG(ConvBias, ConvBias)
M
Megvii Engine Team 已提交
179 180 181 182
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace conv_bias
}  // namespace
183

M
Megvii Engine Team 已提交
184 185 186
namespace {
namespace batch_conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
187 188
    auto&& conv = static_cast<const BatchConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
189
    config.name(conv.make_name());
190
    if (inputs.size() == 2) {
M
Megvii Engine Team 已提交
191 192
        return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(),
                                        conv.policy(), config);
193
    } else if (inputs.size() == 3) {
M
Megvii Engine Team 已提交
194 195
        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
                                        conv.param(), conv.policy(), config);
196
    } else if (inputs.size() == 4) {
M
Megvii Engine Team 已提交
197 198 199
        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2],
                                        inputs[3], conv.param(), conv.policy(),
                                        config);
200 201 202 203 204
    }
    mgb_assert(0);
}

OP_TRAIT_REG(BatchConvBias, BatchConvBias)
M
Megvii Engine Team 已提交
205 206 207 208
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace batch_conv_bias
}  // namespace
209

M
Megvii Engine Team 已提交
210 211 212
namespace {
namespace pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
213
    auto&& pool = static_cast<const Pooling&>(def);
214 215
    OperatorNodeConfig config{pool.make_name()};
    return opr::Pooling::make(inputs[0], pool.param(), config);
216
}
M
Megvii Engine Team 已提交
217 218 219
OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace pooling
}  // namespace
220

M
Megvii Engine Team 已提交
221 222 223
namespace {
namespace matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
224 225
    auto&& matmul = static_cast<const MatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
226
    OperatorNodeConfig config{matmul.make_name()};
227
    return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
228
                                matmul.policy(), config);
229 230
}
OP_TRAIT_REG(MatrixMul, MatrixMul)
M
Megvii Engine Team 已提交
231 232 233 234
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace matrix_mul
}  // namespace
235

M
Megvii Engine Team 已提交
236 237 238
namespace {
namespace batched_matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
239 240
    auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
241
    OperatorNodeConfig config{matmul.make_name()};
242
    return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
243
                                       matmul.policy(), config);
244 245
}
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
M
Megvii Engine Team 已提交
246 247 248 249
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace batched_matrix_mul
}  // namespace
250

M
Megvii Engine Team 已提交
251 252 253
namespace {
namespace dot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
254
    auto&& op = def.cast_final_safe<Dot>();
255
    mgb_assert(inputs.size() == 2);
256 257
    OperatorNodeConfig config{op.make_name()};
    return opr::Dot::make(inputs[0], inputs[1], config);
258
}
M
Megvii Engine Team 已提交
259 260 261
OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace dot
}  // namespace
262

M
Megvii Engine Team 已提交
263 264 265
namespace {
namespace argsort {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
266
    auto&& argsort = static_cast<const Argsort&>(def);
267 268
    OperatorNodeConfig config{argsort.make_name()};
    return opr::Argsort::make(inputs[0], argsort.param(), config);
269
}
M
Megvii Engine Team 已提交
270 271 272
OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argsort
}  // namespace
273

M
Megvii Engine Team 已提交
274 275 276
namespace {
namespace argmax {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
277
    auto&& argmax = static_cast<const Argmax&>(def);
278 279
    OperatorNodeConfig config{argmax.make_name()};
    return opr::Argmax::make(inputs[0], argmax.param(), config);
280
}
M
Megvii Engine Team 已提交
281 282 283
OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmax
}  // namespace
284

M
Megvii Engine Team 已提交
285 286 287
namespace {
namespace argmin {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
288
    auto&& argmin = static_cast<const Argmin&>(def);
289 290
    OperatorNodeConfig config{argmin.make_name()};
    return opr::Argmin::make(inputs[0], argmin.param(), config);
291
}
M
Megvii Engine Team 已提交
292 293 294
OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmin
}  // namespace
295

M
Megvii Engine Team 已提交
296 297 298
namespace {
namespace warp_perspective {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
299
    auto&& warp = static_cast<const WarpPerspective&>(def);
300
    OperatorNodeConfig config{warp.make_name()};
301
    if (inputs.size() == 3) {
M
Megvii Engine Team 已提交
302 303
        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2],
                                          warp.param(), config);
304 305
    } else {
        mgb_assert(inputs.size() == 4);
M
Megvii Engine Team 已提交
306 307
        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2],
                                          inputs[3], warp.param(), config);
308 309 310
    }
}
OP_TRAIT_REG(WarpPerspective, WarpPerspective)
M
Megvii Engine Team 已提交
311 312 313 314
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace warp_perspective
}  // namespace
315

M
Megvii Engine Team 已提交
316 317 318
namespace {
namespace group_local {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
319 320
    auto&& local = static_cast<const GroupLocal&>(def);
    mgb_assert(inputs.size() == 2);
321 322
    OperatorNodeConfig config{local.make_name()};
    return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
323 324
}
OP_TRAIT_REG(GroupLocal, GroupLocal)
M
Megvii Engine Team 已提交
325 326 327 328
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace group_local
}  // namespace
329

M
Megvii Engine Team 已提交
330 331 332
namespace {
namespace indexing_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
333 334
    auto&& op = static_cast<const IndexingOneHot&>(def);
    mgb_assert(inputs.size() == 2);
335 336
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
337 338
}
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
M
Megvii Engine Team 已提交
339 340 341 342
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace indexing_one_hot
}  // namespace
343

M
Megvii Engine Team 已提交
344 345 346
namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
347 348
    auto&& op = static_cast<const IndexingSetOneHot&>(def);
    mgb_assert(inputs.size() == 3);
349
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
350 351
    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2],
                                        op.param(), config);
352 353
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
M
Megvii Engine Team 已提交
354 355 356 357
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace indexing_set_one_hot
}  // namespace
358

M
Megvii Engine Team 已提交
359 360 361
namespace {
namespace typecvt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
362 363
    auto&& op = static_cast<const TypeCvt&>(def);
    mgb_assert(inputs.size() == 1);
364 365
    OperatorNodeConfig config{op.make_name()};
    return opr::TypeCvt::make(inputs[0], op.dtype, config);
366
}
M
Megvii Engine Team 已提交
367 368 369
OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace typecvt
}  // namespace
370

M
Megvii Engine Team 已提交
371 372 373
namespace {
namespace concat {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
374 375
    auto&& op = static_cast<const Concat&>(def);
    cg::OperatorNodeConfig config{op.comp_node};
376
    config.name(op.make_name());
377 378
    return opr::Concat::make(inputs, op.axis, config);
}
M
Megvii Engine Team 已提交
379 380 381
OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace concat
}  // namespace
382

M
Megvii Engine Team 已提交
383 384 385
namespace {
namespace copy {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
386 387 388
    auto&& op = static_cast<const Copy&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
389
    config.name(op.make_name());
390 391
    return opr::Copy::make(inputs[0], config);
}
M
Megvii Engine Team 已提交
392 393 394
OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace copy
}  // namespace
395

396 397
namespace { namespace assert_equal {
auto apply_on_var_node(
398 399 400 401 402 403 404 405
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = def.cast_final<AssertEqual>();
    if (inputs.size() == 2) {
        return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
    } else {
        // workaround for MiniGraph, which only allow one opr in the graph
        mgb_assert(inputs.size() == 3);
M
Megvii Engine Team 已提交
406 407
        return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2],
                                      op.param(), {});
408
    }
409
}
410

411
OP_TRAIT_REG(AssertEqual, AssertEqual)
M
Megvii Engine Team 已提交
412 413 414 415
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace assert_equal
}  // namespace
416

M
Megvii Engine Team 已提交
417 418 419
namespace {
namespace roi_align {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
420 421
    auto&& op = static_cast<const ROIAlign&>(def);
    mgb_assert(inputs.size() == 2);
422
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
423 424 425
    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
                        .node()
                        ->owner_opr();
426
    return {opr->output(0), opr->output(1)};
427 428
}
OP_TRAIT_REG(ROIAlign, ROIAlign)
M
Megvii Engine Team 已提交
429 430 431 432
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace roi_align
}  // namespace
433

M
Megvii Engine Team 已提交
434 435 436
namespace {
namespace correlation {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
437 438 439
    auto&& op = static_cast<const Correlation&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
440
    return opr::Correlation::make(inputs[0], inputs[1], op.param(), config);
441 442
}
OP_TRAIT_REG(Correlation, Correlation)
M
Megvii Engine Team 已提交
443 444 445 446
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace correlation
}  // namespace
447

448
#if MGB_CUDA
M
Megvii Engine Team 已提交
449 450 451
namespace {
namespace nvof {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
452 453
    auto&& op = static_cast<const NvOf&>(def);
    mgb_assert(inputs.size() == 1);
454 455
    OperatorNodeConfig config{op.make_name()};
    return opr::NvOf::make(inputs[0], op.param(), config);
456
}
M
Megvii Engine Team 已提交
457 458 459
OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace nvof
}  // namespace
460 461
#endif

M
Megvii Engine Team 已提交
462 463 464
namespace {
namespace linspace {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
465 466 467
    auto&& op = static_cast<const Linspace&>(def);
    mgb_assert(inputs.size() == 3);
    cg::OperatorNodeConfig config{op.comp_node};
468
    config.name(op.make_name());
M
Megvii Engine Team 已提交
469 470
    return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(),
                               config);
471 472
}
OP_TRAIT_REG(Linspace, Linspace)
M
Megvii Engine Team 已提交
473 474 475 476
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace linspace
}  // namespace
477

M
Megvii Engine Team 已提交
478 479 480
namespace {
namespace eye {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
481 482 483
    auto&& op = static_cast<const Eye&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
484
    config.name(op.make_name());
485 486 487
    opr::Eye::Param param{op.k, op.dtype.enumv()};
    return opr::Eye::make(inputs[0], param, config);
}
M
Megvii Engine Team 已提交
488 489 490
OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace eye
}  // namespace
491

M
Megvii Engine Team 已提交
492 493 494
namespace {
namespace roi_pooling {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
495 496
    auto&& op = static_cast<const ROIPooling&>(def);
    mgb_assert(inputs.size() == 3);
497
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
498 499 500 501
    auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2],
                                      op.param(), config)
                        .node()
                        ->owner_opr();
502
    return {opr->output(0), opr->output(1)};
503 504
}
OP_TRAIT_REG(ROIPooling, ROIPooling)
M
Megvii Engine Team 已提交
505 506 507 508
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace roi_pooling
}  // namespace
509

M
Megvii Engine Team 已提交
510 511 512
namespace {
namespace remap {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
513 514
    auto&& op = static_cast<const Remap&>(def);
    mgb_assert(inputs.size() == 2);
515 516
    OperatorNodeConfig config{op.make_name()};
    return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
517
}
M
Megvii Engine Team 已提交
518 519 520
OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace remap
}  // namespace
521 522 523

namespace {
auto get_index(
M
Megvii Engine Team 已提交
524 525
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
526 527
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
M
Megvii Engine Team 已提交
528
    for (size_t i = 0; i < length; ++i) {
529 530 531 532 533 534
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = inputs[vidx++];
        } else {
            mgb_assert(begin || end || step);
M
Megvii Engine Team 已提交
535 536 537 538 539 540
            if (begin)
                ret[i].begin = inputs[vidx++];
            if (end)
                ret[i].end = inputs[vidx++];
            if (step)
                ret[i].step = inputs[vidx++];
541 542 543 544 545 546 547 548
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]

M
Megvii Engine Team 已提交
549 550 551 552 553 554 555 556 557 558 559 560 561
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT)                                    \
    namespace NAME##_impl {                                                    \
        auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { \
            auto&& op = static_cast<const NAME&>(def);                         \
            OperatorNodeConfig config{op.make_name()};                         \
            return opr::NAME::make(IN##NR_INPUT,                               \
                                   get_index(inputs, NR_INPUT, op.items),      \
                                   config);                                    \
        }                                                                      \
        OP_TRAIT_REG(NAME, NAME)                                               \
                .apply_on_var_node(apply_on_var_node)                          \
                .fallback();                                                   \
    }
562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578

FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)

#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
M
Megvii Engine Team 已提交
579
}  // anonymous namespace
580

M
Megvii Engine Team 已提交
581 582 583
namespace {
namespace fake_quant {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
584 585
    auto&& op = static_cast<const FakeQuant&>(def);
    mgb_assert(inputs.size() == 3);
586
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
587 588
    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(),
                                config);
589 590
}
OP_TRAIT_REG(FakeQuant, FakeQuant)
M
Megvii Engine Team 已提交
591 592 593 594
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace fake_quant
}  // namespace
595

M
Megvii Engine Team 已提交
596 597 598
namespace {
namespace tqt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
M
Megvii Engine Team 已提交
599 600
    auto&& op = static_cast<const TQT&>(def);
    mgb_assert(inputs.size() == 2);
601 602
    OperatorNodeConfig config{op.make_name()};
    return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
M
Megvii Engine Team 已提交
603
}
M
Megvii Engine Team 已提交
604 605 606
OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace tqt
}  // namespace
607

M
Megvii Engine Team 已提交
608 609 610
namespace {
namespace elemwise_multi_type {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
611 612
    auto&& op = static_cast<const ElemwiseMultiType&>(def);
    OperatorNodeConfig config{op.dtype};
613
    config.name(op.make_name());
614 615 616
    return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
M
Megvii Engine Team 已提交
617 618 619 620
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace elemwise_multi_type
}  // namespace
621

M
Megvii Engine Team 已提交
622 623 624
namespace {
namespace svd {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
625 626
    auto&& op = static_cast<const SVD&>(def);
    mgb_assert(inputs.size() == 1);
627 628
    OperatorNodeConfig config{op.make_name()};
    return opr::SVD::make(inputs[0], op.param(), config)[0]
M
Megvii Engine Team 已提交
629 630 631
            .node()
            ->owner_opr()
            ->usable_output();
632
}
M
Megvii Engine Team 已提交
633 634 635
OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace svd
}  // namespace
636

M
Megvii Engine Team 已提交
637 638 639
namespace {
namespace images2neibs {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
640 641 642 643 644
    auto&& op = static_cast<const Images2Neibs&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::Images2Neibs::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(Images2Neibs, Images2Neibs)
M
Megvii Engine Team 已提交
645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace images2neibs
}  // namespace

namespace {
namespace lsq {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LSQ&>(def);
    mgb_assert(inputs.size() == 4);
    OperatorNodeConfig config{op.make_name()};
    return opr::LSQ::make(inputs[0], inputs[1], inputs[2], inputs[3],
                          op.param(), config);
}
OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace lsq
}  // namespace
662

663 664 665 666 667 668 669 670 671 672 673 674 675 676
namespace { namespace sliding_window_transpose {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const SlidingWindowTranspose&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::SlidingWindowTranspose::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // sliding_window_transpose

} // namespace mgb::imperative