specializations.cpp 23.0 KB
Newer Older
1
/**
2
 * \file imperative/src/impl/ops/specialzations.cpp
3 4
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
6 7 8
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
M
Megvii Engine Team 已提交
9 10
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
11 12 13 14 15
 */

// FIXME: split this file into separate files for each specialized op

#include "megbrain/imperative/ops/autogen.h"
M
Megvii Engine Team 已提交
16 17
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
18
#include "megbrain/opr/dnn/adaptive_pooling.h"
M
Megvii Engine Team 已提交
19 20
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
21
#include "megbrain/opr/dnn/fake_quant.h"
M
Megvii Engine Team 已提交
22
#include "megbrain/opr/dnn/images2neibs.h"
23
#include "megbrain/opr/dnn/layer_norm.h"
24
#include "megbrain/opr/dnn/local.h"
25
#include "megbrain/opr/dnn/lrn.h"
M
Megvii Engine Team 已提交
26 27
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
28 29
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
M
Megvii Engine Team 已提交
30
#include "megbrain/opr/dnn/sliding_window_transpose.h"
M
Megvii Engine Team 已提交
31
#include "megbrain/opr/dnn/tqt.h"
32 33 34 35 36 37 38 39 40 41 42 43 44 45
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include "../op_trait.h"

namespace mgb::imperative {

M
Megvii Engine Team 已提交
46 47
namespace {
namespace dimshuffle {
48 49 50
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
    std::vector<int> pattern(node->param().pattern_len);
M
Megvii Engine Team 已提交
51
    for (size_t i = 0; i < node->param().pattern_len; ++i) {
52 53 54 55 56
        pattern[i] = node->param().pattern[i];
    }
    return Dimshuffle::make(pattern);
}

M
Megvii Engine Team 已提交
57
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
58
    auto&& ds = static_cast<const Dimshuffle&>(def);
59 60
    OperatorNodeConfig config{ds.make_name()};
    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
61 62 63
}

OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
M
Megvii Engine Team 已提交
64 65 66 67 68
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace dimshuffle
}  // namespace
69

M
Megvii Engine Team 已提交
70 71 72
namespace {
namespace add_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
73 74 75 76 77 78
    auto&& add_axis = static_cast<const AddAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : add_axis.axis) {
        param.push_back(Desc::make_add(i));
    }
79 80
    OperatorNodeConfig config{add_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
81 82
}

M
Megvii Engine Team 已提交
83 84 85
OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace add_axis
}  // namespace
86

M
Megvii Engine Team 已提交
87 88 89
namespace {
namespace remove_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
90 91 92 93 94 95
    auto&& remove_axis = static_cast<const RemoveAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : remove_axis.axis) {
        param.push_back(Desc::make_remove(i));
    }
96 97
    OperatorNodeConfig config{remove_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
98 99
}

M
Megvii Engine Team 已提交
100
OP_TRAIT_REG(RemoveAxis, RemoveAxis).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
101 102
}  // namespace remove_axis
}  // namespace
103

M
Megvii Engine Team 已提交
104 105 106
namespace {
namespace top_k {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
107
    auto&& topk = static_cast<const TopK&>(def);
108 109
    OperatorNodeConfig config{topk.make_name()};
    return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
M
Megvii Engine Team 已提交
110 111
            .node()
            ->owner_opr();
112 113
}

M
Megvii Engine Team 已提交
114 115 116
OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace top_k
}  // namespace
117

M
Megvii Engine Team 已提交
118 119 120
namespace {
namespace adaptive_pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
121
    auto&& pool = static_cast<const AdaptivePooling&>(def);
122
    OperatorNodeConfig config{pool.make_name()};
M
Megvii Engine Team 已提交
123
    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
124 125 126
}

OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
M
Megvii Engine Team 已提交
127 128 129 130
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace adaptive_pooling
}  // namespace
131

M
Megvii Engine Team 已提交
132 133 134
namespace {
namespace conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
135 136
    auto&& conv = static_cast<const ConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
137
    config.name(conv.make_name());
138
    if (inputs.size() == 2) {
M
Megvii Engine Team 已提交
139 140
        return opr::ConvBias::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
141
    } else if (inputs.size() == 3) {
M
Megvii Engine Team 已提交
142 143
        return opr::ConvBias::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
144
    } else if (inputs.size() == 4) {
M
Megvii Engine Team 已提交
145 146 147
        return opr::ConvBias::make(
                inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(),
                config);
148 149 150 151
    }
    mgb_assert(0);
}

M
Megvii Engine Team 已提交
152
OP_TRAIT_REG(ConvBias, ConvBias).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
153 154
}  // namespace conv_bias
}  // namespace
155

M
Megvii Engine Team 已提交
156 157 158
namespace {
namespace batch_conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
159 160
    auto&& conv = static_cast<const BatchConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
161
    config.name(conv.make_name());
162
    if (inputs.size() == 2) {
M
Megvii Engine Team 已提交
163 164
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
165
    } else if (inputs.size() == 3) {
M
Megvii Engine Team 已提交
166 167
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
168
    } else if (inputs.size() == 4) {
M
Megvii Engine Team 已提交
169 170 171
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(),
                config);
172 173 174 175 176
    }
    mgb_assert(0);
}

OP_TRAIT_REG(BatchConvBias, BatchConvBias)
M
Megvii Engine Team 已提交
177 178 179 180
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace batch_conv_bias
}  // namespace
181

M
Megvii Engine Team 已提交
182 183 184
namespace {
namespace pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
185
    auto&& pool = static_cast<const Pooling&>(def);
186
    OperatorNodeConfig config{pool.make_name()};
187
    return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config);
188
}
M
Megvii Engine Team 已提交
189 190 191
OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace pooling
}  // namespace
192

M
Megvii Engine Team 已提交
193 194 195
namespace {
namespace matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
196 197
    auto&& matmul = static_cast<const MatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
198
    OperatorNodeConfig config{matmul.make_name()};
M
Megvii Engine Team 已提交
199 200
    return opr::MatrixMul::make(
            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
201
}
M
Megvii Engine Team 已提交
202
OP_TRAIT_REG(MatrixMul, MatrixMul).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
203 204
}  // namespace matrix_mul
}  // namespace
205

M
Megvii Engine Team 已提交
206 207 208
namespace {
namespace batched_matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
209 210
    auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
211
    OperatorNodeConfig config{matmul.make_name()};
M
Megvii Engine Team 已提交
212 213
    return opr::BatchedMatrixMul::make(
            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
214 215
}
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
M
Megvii Engine Team 已提交
216 217 218 219
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace batched_matrix_mul
}  // namespace
220

M
Megvii Engine Team 已提交
221 222 223
namespace {
namespace dot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
224
    auto&& op = def.cast_final_safe<Dot>();
225
    mgb_assert(inputs.size() == 2);
226 227
    OperatorNodeConfig config{op.make_name()};
    return opr::Dot::make(inputs[0], inputs[1], config);
228
}
M
Megvii Engine Team 已提交
229 230 231
OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace dot
}  // namespace
232

M
Megvii Engine Team 已提交
233 234 235
namespace {
namespace argsort {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
236
    auto&& argsort = static_cast<const Argsort&>(def);
237 238
    OperatorNodeConfig config{argsort.make_name()};
    return opr::Argsort::make(inputs[0], argsort.param(), config);
239
}
M
Megvii Engine Team 已提交
240 241 242
OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argsort
}  // namespace
243

M
Megvii Engine Team 已提交
244 245 246
namespace {
namespace argmax {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
247
    auto&& argmax = static_cast<const Argmax&>(def);
248 249
    OperatorNodeConfig config{argmax.make_name()};
    return opr::Argmax::make(inputs[0], argmax.param(), config);
250
}
M
Megvii Engine Team 已提交
251 252 253
OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmax
}  // namespace
254

M
Megvii Engine Team 已提交
255 256 257
namespace {
namespace argmin {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
258
    auto&& argmin = static_cast<const Argmin&>(def);
259 260
    OperatorNodeConfig config{argmin.make_name()};
    return opr::Argmin::make(inputs[0], argmin.param(), config);
261
}
M
Megvii Engine Team 已提交
262 263 264
OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmin
}  // namespace
265

M
Megvii Engine Team 已提交
266 267 268
namespace {
namespace warp_perspective {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
269
    auto&& warp = static_cast<const WarpPerspective&>(def);
270
    OperatorNodeConfig config{warp.make_name()};
271
    if (inputs.size() == 3) {
M
Megvii Engine Team 已提交
272 273
        return opr::WarpPerspective::make(
                inputs[0], inputs[1], inputs[2], warp.param(), config);
274 275
    } else {
        mgb_assert(inputs.size() == 4);
M
Megvii Engine Team 已提交
276 277
        return opr::WarpPerspective::make(
                inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
278 279 280
    }
}
OP_TRAIT_REG(WarpPerspective, WarpPerspective)
M
Megvii Engine Team 已提交
281 282 283 284
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace warp_perspective
}  // namespace
285

M
Megvii Engine Team 已提交
286 287 288
namespace {
namespace group_local {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
289 290
    auto&& local = static_cast<const GroupLocal&>(def);
    mgb_assert(inputs.size() == 2);
291 292
    OperatorNodeConfig config{local.make_name()};
    return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
293
}
M
Megvii Engine Team 已提交
294
OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
295 296
}  // namespace group_local
}  // namespace
297

M
Megvii Engine Team 已提交
298 299 300
namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
301 302
    auto&& op = static_cast<const IndexingSetOneHot&>(def);
    mgb_assert(inputs.size() == 3);
303
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
304 305
    return opr::IndexingSetOneHot::make(
            inputs[0], inputs[1], inputs[2], op.param(), config);
306 307
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
M
Megvii Engine Team 已提交
308 309 310 311
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace indexing_set_one_hot
}  // namespace
312

M
Megvii Engine Team 已提交
313 314 315
namespace {
namespace typecvt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
316 317
    auto&& op = static_cast<const TypeCvt&>(def);
    mgb_assert(inputs.size() == 1);
318 319
    OperatorNodeConfig config{op.make_name()};
    return opr::TypeCvt::make(inputs[0], op.dtype, config);
320
}
M
Megvii Engine Team 已提交
321 322 323
OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace typecvt
}  // namespace
324

M
Megvii Engine Team 已提交
325 326 327
namespace {
namespace concat {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
328 329
    auto&& op = static_cast<const Concat&>(def);
    cg::OperatorNodeConfig config{op.comp_node};
330
    config.name(op.make_name());
331 332
    return opr::Concat::make(inputs, op.axis, config);
}
M
Megvii Engine Team 已提交
333 334 335
OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace concat
}  // namespace
336

M
Megvii Engine Team 已提交
337 338 339
namespace {
namespace copy {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
340 341 342
    auto&& op = static_cast<const Copy&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
343
    config.name(op.make_name());
344 345
    return opr::Copy::make(inputs[0], config);
}
M
Megvii Engine Team 已提交
346 347 348
OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace copy
}  // namespace
349

M
Megvii Engine Team 已提交
350 351 352
namespace {
namespace assert_equal {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
353 354 355 356 357 358
    auto&& op = def.cast_final<AssertEqual>();
    if (inputs.size() == 2) {
        return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
    } else {
        // workaround for MiniGraph, which only allow one opr in the graph
        mgb_assert(inputs.size() == 3);
M
Megvii Engine Team 已提交
359
        return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
360
    }
361
}
362

M
Megvii Engine Team 已提交
363
OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
364 365
}  // namespace assert_equal
}  // namespace
366

M
Megvii Engine Team 已提交
367 368 369
namespace {
namespace roi_align {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
370 371
    auto&& op = static_cast<const ROIAlign&>(def);
    mgb_assert(inputs.size() == 2);
372
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
373 374 375
    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
                        .node()
                        ->owner_opr();
376
    return {opr->output(0), opr->output(1)};
377
}
M
Megvii Engine Team 已提交
378
OP_TRAIT_REG(ROIAlign, ROIAlign).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
379 380
}  // namespace roi_align
}  // namespace
381

M
Megvii Engine Team 已提交
382 383 384
namespace {
namespace correlation {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
385 386 387
    auto&& op = static_cast<const Correlation&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
388
    return opr::Correlation::make(inputs[0], inputs[1], op.param(), config);
389
}
M
Megvii Engine Team 已提交
390
OP_TRAIT_REG(Correlation, Correlation).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
391 392
}  // namespace correlation
}  // namespace
393

394
#if MGB_CUDA
M
Megvii Engine Team 已提交
395 396 397
namespace {
namespace nvof {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
398 399
    auto&& op = static_cast<const NvOf&>(def);
    mgb_assert(inputs.size() == 1);
400 401
    OperatorNodeConfig config{op.make_name()};
    return opr::NvOf::make(inputs[0], op.param(), config);
402
}
M
Megvii Engine Team 已提交
403 404 405
OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace nvof
}  // namespace
406 407
#endif

M
Megvii Engine Team 已提交
408 409 410
namespace {
namespace linspace {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
411 412 413
    auto&& op = static_cast<const Linspace&>(def);
    mgb_assert(inputs.size() == 3);
    cg::OperatorNodeConfig config{op.comp_node};
414
    config.name(op.make_name());
M
Megvii Engine Team 已提交
415
    return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
416
}
M
Megvii Engine Team 已提交
417
OP_TRAIT_REG(Linspace, Linspace).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
418 419
}  // namespace linspace
}  // namespace
420

M
Megvii Engine Team 已提交
421 422 423
namespace {
namespace eye {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
424 425 426
    auto&& op = static_cast<const Eye&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
427
    config.name(op.make_name());
428 429 430
    opr::Eye::Param param{op.k, op.dtype.enumv()};
    return opr::Eye::make(inputs[0], param, config);
}
M
Megvii Engine Team 已提交
431 432 433
OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace eye
}  // namespace
434

M
Megvii Engine Team 已提交
435 436 437
namespace {
namespace roi_pooling {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
438 439
    auto&& op = static_cast<const ROIPooling&>(def);
    mgb_assert(inputs.size() == 3);
440
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
441 442 443 444
    auto* opr =
            opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config)
                    .node()
                    ->owner_opr();
445
    return {opr->output(0), opr->output(1)};
446
}
M
Megvii Engine Team 已提交
447
OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
448 449
}  // namespace roi_pooling
}  // namespace
450

M
Megvii Engine Team 已提交
451 452 453
namespace {
namespace remap {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
454 455
    auto&& op = static_cast<const Remap&>(def);
    mgb_assert(inputs.size() == 2);
456 457
    OperatorNodeConfig config{op.make_name()};
    return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
458
}
M
Megvii Engine Team 已提交
459 460 461
OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace remap
}  // namespace
462 463 464

namespace {
auto get_index(
M
Megvii Engine Team 已提交
465 466
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
467 468
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
M
Megvii Engine Team 已提交
469
    for (size_t i = 0; i < length; ++i) {
470 471 472 473 474 475
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = inputs[vidx++];
        } else {
            mgb_assert(begin || end || step);
M
Megvii Engine Team 已提交
476 477 478 479 480 481
            if (begin)
                ret[i].begin = inputs[vidx++];
            if (end)
                ret[i].end = inputs[vidx++];
            if (step)
                ret[i].step = inputs[vidx++];
482 483 484 485 486 487 488 489
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]

M
Megvii Engine Team 已提交
490 491 492 493 494 495 496 497 498
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT)                                       \
    namespace NAME##_impl {                                                       \
        auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {    \
            auto&& op = static_cast<const NAME&>(def);                            \
            OperatorNodeConfig config{op.make_name()};                            \
            return opr::NAME::make(                                               \
                    IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
        }                                                                         \
        OP_TRAIT_REG(NAME, NAME).apply_on_var_node(apply_on_var_node).fallback(); \
M
Megvii Engine Team 已提交
499
    }
500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516

FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)

#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
M
Megvii Engine Team 已提交
517
}  // anonymous namespace
518

M
Megvii Engine Team 已提交
519 520 521
namespace {
namespace fake_quant {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
522 523
    auto&& op = static_cast<const FakeQuant&>(def);
    mgb_assert(inputs.size() == 3);
524
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
525
    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
526
}
M
Megvii Engine Team 已提交
527
OP_TRAIT_REG(FakeQuant, FakeQuant).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
528 529
}  // namespace fake_quant
}  // namespace
530

M
Megvii Engine Team 已提交
531 532 533
namespace {
namespace tqt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
M
Megvii Engine Team 已提交
534 535
    auto&& op = static_cast<const TQT&>(def);
    mgb_assert(inputs.size() == 2);
536 537
    OperatorNodeConfig config{op.make_name()};
    return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
M
Megvii Engine Team 已提交
538
}
M
Megvii Engine Team 已提交
539 540 541
OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace tqt
}  // namespace
542

M
Megvii Engine Team 已提交
543 544 545
namespace {
namespace elemwise_multi_type {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
546 547
    auto&& op = static_cast<const ElemwiseMultiType&>(def);
    OperatorNodeConfig config{op.dtype};
548
    config.name(op.make_name());
549 550 551
    return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
M
Megvii Engine Team 已提交
552 553 554 555
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace elemwise_multi_type
}  // namespace
556

M
Megvii Engine Team 已提交
557 558 559
namespace {
namespace svd {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
560 561
    auto&& op = static_cast<const SVD&>(def);
    mgb_assert(inputs.size() == 1);
562 563
    OperatorNodeConfig config{op.make_name()};
    return opr::SVD::make(inputs[0], op.param(), config)[0]
M
Megvii Engine Team 已提交
564 565 566
            .node()
            ->owner_opr()
            ->usable_output();
567
}
M
Megvii Engine Team 已提交
568 569 570
OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace svd
}  // namespace
571

M
Megvii Engine Team 已提交
572 573 574
namespace {
namespace images2neibs {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
575 576 577 578 579
    auto&& op = static_cast<const Images2Neibs&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::Images2Neibs::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(Images2Neibs, Images2Neibs)
M
Megvii Engine Team 已提交
580 581 582 583 584 585 586 587 588 589 590
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace images2neibs
}  // namespace

namespace {
namespace lsq {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LSQ&>(def);
    mgb_assert(inputs.size() == 4);
    OperatorNodeConfig config{op.make_name()};
M
Megvii Engine Team 已提交
591 592
    return opr::LSQ::make(
            inputs[0], inputs[1], inputs[2], inputs[3], op.param(), config);
M
Megvii Engine Team 已提交
593 594 595 596
}
OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace lsq
}  // namespace
597

M
Megvii Engine Team 已提交
598 599 600
namespace {
namespace sliding_window_transpose {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
601 602 603 604 605
    auto&& op = static_cast<const SlidingWindowTranspose&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::SlidingWindowTranspose::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
M
Megvii Engine Team 已提交
606 607 608 609
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace sliding_window_transpose
}  // namespace
610

M
Megvii Engine Team 已提交
611 612 613 614 615 616 617 618 619 620 621 622
namespace {
namespace cumsum {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Cumsum&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::Cumsum::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace cumsum
}  // namespace

623 624 625 626 627 628 629
namespace padding {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Padding&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::Padding::make(inputs[0], op.param());
}
OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
630
}  // namespace padding
631 632 633 634 635 636 637 638

namespace lrn {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LRN&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::LRN::make(inputs[0], op.param());
}
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
M
Megvii Engine Team 已提交
639
}  // namespace lrn
640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664

namespace layer_norm {

cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LayerNorm&>(def);
    size_t nr_inp = inputs.size();
    auto p = op.param();
    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
    OperatorNodeConfig config{op.make_name()};
    if (nr_inp == 3) {
        return opr::LayerNorm::make(
                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
                .node()
                ->owner_opr();
    } else {
        return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
                .node()
                ->owner_opr();
    }
}

OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback();

}  // namespace layer_norm

M
Megvii Engine Team 已提交
665
}  // namespace mgb::imperative