/**
 * \file imperative/src/impl/ops/specializations.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

// FIXME: split this file into separate files for each specialized op

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include "../blob_manager_impl.h"
#include "../op_trait.h"

namespace mgb::imperative {

namespace {
namespace dimshuffle {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
    std::vector<int> pattern(node->param().pattern_len);
    for (size_t i = 0; i < node->param().pattern_len; ++i) {
        pattern[i] = node->param().pattern[i];
    }
    return Dimshuffle::make(pattern);
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    OperatorNodeConfig config{ds.make_name()};
    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    mgb_assert(
            ds.pattern.size() <= TensorShape::MAX_NDIM,
            "Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "Dimshuffle expects 1 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto inp_layout = src->layout();
    size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
    mgb_assert(
            inp_layout.ndim == pattern_ndim,
            "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
            inp_layout.ndim);
    TensorLayout out_layout{inp_layout.dtype};
    out_layout.ndim = ds.pattern.size();

    size_t idx = 0;
    bool input_used[TensorLayout::MAX_NDIM] = {0};
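    // A negative pattern entry inserts a new axis of size 1; a non-negative entry
    // forwards the shape and stride of the corresponding input dimension, so the
    // result below is a zero-copy view of the source blob.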
    for (auto i : ds.pattern) {
        if (i < 0) {
            out_layout.shape[idx] = 1;
            out_layout.stride[idx] = 1;
        } else {
            input_used[i] = true;
            out_layout.shape[idx] = inp_layout.shape[i];
            out_layout.stride[idx] = inp_layout.stride[i];
        }
        ++idx;
    }
    if (out_layout.is_contiguous()) {
        out_layout.init_contiguous_stride();
    }
    for (size_t i = 0; i < pattern_ndim; ++i) {
        mgb_assert(
                input_used[i] || inp_layout.shape[i] == 1,
                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
                inp_layout.megdnn::TensorShape::to_string().c_str(), i);
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), out_layout)};
}

OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace dimshuffle
}  // namespace

namespace {
namespace add_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& add_axis = static_cast<const AddAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : add_axis.axis) {
        param.push_back(Desc::make_add(i));
    }
    OperatorNodeConfig config{add_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<AddAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "AddAxis expects 1 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto tlayout = src->layout();
    for (auto&& i : op_def.axis) {
        tlayout.add_axis_cont_inplace(i);
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

OP_TRAIT_REG(AddAxis, AddAxis)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace add_axis
}  // namespace

namespace {
namespace remove_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& remove_axis = static_cast<const RemoveAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : remove_axis.axis) {
        param.push_back(Desc::make_remove(i));
    }
    OperatorNodeConfig config{remove_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<RemoveAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "RemoveAxis expects 1 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto tlayout = src->layout();
    for (auto&& i : op_def.axis) {
        if (tlayout.ndim == 1) {
            mgb_assert(
                    tlayout.shape[0] == 1 && i == 0,
                    "can not remove axis %u from tensor of shape=%s", i,
                    tlayout.megdnn::TensorShape::to_string().c_str());
        } else {
            mgb_assert(
                    i < tlayout.ndim && tlayout.shape[i] == 1,
                    "can not remove axis %u from tensor of shape=%s", i,
                    tlayout.megdnn::TensorShape::to_string().c_str());
            tlayout.remove_axis_inplace(i);
        }
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

OP_TRAIT_REG(RemoveAxis, RemoveAxis)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace remove_axis
}  // namespace

namespace {
namespace top_k {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& topk = static_cast<const TopK&>(def);
    OperatorNodeConfig config{topk.make_name()};
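    // return the owner operator so that both outputs of TopK (values and
    // indices) stay reachable, not just the first var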
    return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
            .node()
            ->owner_opr();
}

OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace top_k
}  // namespace

namespace {
namespace adaptive_pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& pool = static_cast<const AdaptivePooling&>(def);
    OperatorNodeConfig config{pool.make_name()};
    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
}

OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace adaptive_pooling
}  // namespace

namespace {
namespace batch_conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const BatchConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
    config.name(conv.make_name());
    if (inputs.size() == 2) {
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 3) {
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(),
                config);
    }
    mgb_assert(0);
}

OP_TRAIT_REG(BatchConvBias, BatchConvBias)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace batch_conv_bias
}  // namespace

namespace {
namespace pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& pool = static_cast<const Pooling&>(def);
    OperatorNodeConfig config{pool.make_name()};
    return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config);
}
OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace pooling
}  // namespace

namespace {
namespace matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& matmul = static_cast<const MatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{matmul.make_name()};
    return opr::MatrixMul::make(
            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}
OP_TRAIT_REG(MatrixMul, MatrixMul).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace matrix_mul
}  // namespace

namespace {
namespace batched_matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{matmul.make_name()};
    return opr::BatchedMatrixMul::make(
            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace batched_matrix_mul
}  // namespace

namespace {
namespace dot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Dot>();
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Dot::make(inputs[0], inputs[1], config);
}

// std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
//     auto* node = &node_->cast_final_safe<opr::Dot>();
//     return Dot::make(node->param());
// }

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto a = inputs[0]->layout();
    auto comp_node = inputs[0]->comp_node();
    using TensorND = megdnn::TensorND;
    SmallVector<TensorND> inp_tensornds;
    inp_tensornds.reserve(inputs.size());
    auto dnn_opr = opr::intl::create_megdnn_opr<megdnn::Dot>(comp_node);
    for (unsigned i = 0; i < inputs.size(); ++i) {
        auto dnn_ten = inputs[i]->dnn_tensor();
        inp_tensornds.push_back(dnn_ten);
    }
    TensorLayout oup_layout{inputs[0]->dtype()};
    auto inp1_tensor = inputs[0]->dnn_tensor();
    auto inp2_tensor = inputs[1]->dnn_tensor();
    dnn_opr->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout);

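    // if either input is empty there is nothing to reduce: allocate the output
    // and fill it with zero instead of invoking the megdnn kernel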
    if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) {
        auto fill_opr = opr::intl::create_megdnn_opr<megdnn::Fill>(comp_node);
        DeviceTensorND out =
                BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
        fill_opr->param() = 0;
        fill_opr->exec(out.as_megdnn(), {});
        return {Tensor::make(out)};
    }

    auto wk_size = dnn_opr->get_workspace_in_bytes(
            inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout);

    DeviceTensorND out_devtensor =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
    TensorLayout wk_layout{TensorShape{wk_size}, inputs[0]->dtype()};
    DeviceTensorND workspace =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, wk_layout);
    megdnn::Workspace dnn_wk(workspace.raw_ptr(), wk_size);

    dnn_opr->exec(
            inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk);

    return {Tensor::make(out_devtensor)};
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Dot>();
    SmallVector<LogicalTensorDesc> dests(1);
    dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype);
    dests[0].comp_node = inputs[0].comp_node;
    return {dests, true};
}

OP_TRAIT_REG(Dot, Dot, opr::Dot)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace dot
}  // namespace

namespace {
namespace argsort {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& argsort = static_cast<const Argsort&>(def);
    OperatorNodeConfig config{argsort.make_name()};
    return opr::Argsort::make(inputs[0], argsort.param(), config);
}
OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argsort
}  // namespace

namespace {
namespace argmax {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& argmax = static_cast<const Argmax&>(def);
    OperatorNodeConfig config{argmax.make_name()};
    return opr::Argmax::make(inputs[0], argmax.param(), config);
}
OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmax
}  // namespace

namespace {
namespace argmin {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& argmin = static_cast<const Argmin&>(def);
    OperatorNodeConfig config{argmin.make_name()};
    return opr::Argmin::make(inputs[0], argmin.param(), config);
}
OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmin
}  // namespace

namespace {
namespace warp_perspective {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& warp = static_cast<const WarpPerspective&>(def);
    OperatorNodeConfig config{warp.make_name()};
    if (inputs.size() == 3) {
        return opr::WarpPerspective::make(
                inputs[0], inputs[1], inputs[2], warp.param(), config);
    } else {
        mgb_assert(inputs.size() == 4);
        return opr::WarpPerspective::make(
                inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
    }
}
OP_TRAIT_REG(WarpPerspective, WarpPerspective)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace warp_perspective
}  // namespace

namespace {
namespace group_local {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& local = static_cast<const GroupLocal&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{local.make_name()};
    return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
}
OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace group_local
}  // namespace

namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const IndexingSetOneHot&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingSetOneHot::make(
            inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace indexing_set_one_hot
}  // namespace

namespace {
namespace typecvt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const TypeCvt&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::TypeCvt::make(inputs[0], op.dtype, config);
}
OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace typecvt
}  // namespace

namespace {
namespace concat {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Concat&>(def);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Concat::make(inputs, op.axis, config);
}
OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace concat
}  // namespace

namespace {
namespace copy {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Copy&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Copy::make(inputs[0], config);
}
OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace copy
}  // namespace

namespace {
namespace assert_equal {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final<AssertEqual>();
    if (inputs.size() == 2) {
        return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
    } else {
        // workaround for MiniGraph, which only allows one opr in the graph
        mgb_assert(inputs.size() == 3);
        return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
    }
}

OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace assert_equal
}  // namespace

namespace {
namespace roi_align {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const ROIAlign&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
                        .node()
                        ->owner_opr();
    return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIAlign, ROIAlign).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace roi_align
}  // namespace

namespace {
namespace correlation {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Correlation&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Correlation::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Correlation, Correlation).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace correlation
}  // namespace

#if MGB_CUDA
namespace {
namespace nvof {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const NvOf&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::NvOf::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace nvof
}  // namespace
#endif

namespace {
namespace linspace {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Linspace&>(def);
    mgb_assert(inputs.size() == 3);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(Linspace, Linspace).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace linspace
}  // namespace

namespace {
namespace eye {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Eye&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    opr::Eye::Param param{op.k, op.dtype.enumv()};
    return opr::Eye::make(inputs[0], param, config);
}
OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace eye
}  // namespace

namespace {
namespace diag {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Diag&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.make_name()};
    opr::Diag::Param param{op.k};
    return opr::Diag::make(inputs[0], param, config);
}
OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace diag
}  // namespace

namespace {
namespace roi_pooling {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const ROIPooling&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    auto* opr =
            opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config)
                    .node()
                    ->owner_opr();
    return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace roi_pooling
}  // namespace

namespace {
namespace remap {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Remap&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace remap
}  // namespace

namespace {
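// Build a Subtensor::IndexDesc from the static `mask`: each tuple is
// (axis, has_begin, has_end, has_step, is_index); the matching dynamic
// index/begin/end/step vars are consumed from `inputs` starting at `vidx`.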
auto get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    for (size_t i = 0; i < length; ++i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = inputs[vidx++];
        } else {
            mgb_assert(begin || end || step);
            if (begin)
                ret[i].begin = inputs[vidx++];
            if (end)
                ret[i].end = inputs[vidx++];
            if (step)
                ret[i].step = inputs[vidx++];
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]

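// Registers apply_on_var_node for the fancy-indexing ops: NR_INPUT leading data
// inputs, followed by the dynamic index operands described by op.items.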
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT)                                       \
    namespace NAME##_impl {                                                       \
        auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {    \
            auto&& op = static_cast<const NAME&>(def);                            \
            OperatorNodeConfig config{op.make_name()};                            \
            return opr::NAME::make(                                               \
                    IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
        }                                                                         \
        OP_TRAIT_REG(NAME, NAME).apply_on_var_node(apply_on_var_node).fallback(); \
    }

FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)

#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
}  // anonymous namespace

namespace {
namespace fake_quant {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const FakeQuant&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(FakeQuant, FakeQuant).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace fake_quant
}  // namespace

namespace {
namespace tqt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const TQT&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace tqt
}  // namespace

namespace {
namespace elemwise_multi_type {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const ElemwiseMultiType&>(def);
    OperatorNodeConfig config{op.dtype};
    config.name(op.make_name());
    return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace elemwise_multi_type
}  // namespace

namespace {
namespace svd {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const SVD&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::SVD::make(inputs[0], op.param(), config)[0]
            .node()
            ->owner_opr()
            ->usable_output();
}
OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace svd
}  // namespace

namespace {
namespace images2neibs {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Images2Neibs&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::Images2Neibs::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(Images2Neibs, Images2Neibs)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace images2neibs
}  // namespace

namespace {
namespace lsq {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LSQ&>(def);
    mgb_assert(inputs.size() == 4);
    OperatorNodeConfig config{op.make_name()};
    return opr::LSQ::make(
            inputs[0], inputs[1], inputs[2], inputs[3], op.param(), config);
}
OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace lsq
}  // namespace

namespace {
namespace sliding_window_transpose {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const SlidingWindowTranspose&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::SlidingWindowTranspose::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace sliding_window_transpose
}  // namespace

namespace {
namespace cumsum {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Cumsum&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::Cumsum::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace cumsum
}  // namespace

namespace padding {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Padding&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::Padding::make(inputs[0], op.param());
}
OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace padding

namespace lrn {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LRN&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::LRN::make(inputs[0], op.param());
}
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace lrn

namespace layer_norm {

cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LayerNorm&>(def);
    size_t nr_inp = inputs.size();
    auto p = op.param();
    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
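    // the affine form takes (input, weight, bias); the non-affine form takes input only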
    OperatorNodeConfig config{op.make_name()};
    if (nr_inp == 3) {
        return opr::LayerNorm::make(
                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
                .node()
                ->owner_opr();
    } else {
        return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
                .node()
                ->owner_opr();
    }
}

OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback();

}  // namespace layer_norm

}  // namespace mgb::imperative