// FIXME: split this file into separate files for each specialized op

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include "../blob_manager_impl.h"
#include "../op_trait.h"

namespace mgb::imperative {
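// Each anonymous namespace below registers the imperative traits for one op:
// apply_on_var_node lowers the OpDef onto computing-graph VarNodes,
// infer_output_attrs_fallible deduces output descriptors without running any
// kernel, and apply_on_physical_tensor (where present) executes eagerly,
// typically by forwarding the input blob under a reinterpreted layout.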

namespace {
namespace dimshuffle {
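// Rebuild a Dimshuffle OpDef from an existing graph operator node by copying
// over its permutation pattern.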
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
    std::vector<int> pattern(node->param().pattern_len);
    for (size_t i = 0; i < node->param().pattern_len; ++i) {
        pattern[i] = node->param().pattern[i];
    }
    return Dimshuffle::make(pattern);
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    OperatorNodeConfig config{ds.make_name()};
    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
}

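// Shape inference: pattern[i] names the input axis placed at output axis i,
// and a negative entry inserts a new axis of size 1. Any input axis not
// referenced by the pattern must have extent 1, otherwise data would be
// silently discarded (checked below).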
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    mgb_assert(
            ds.pattern.size() <= TensorShape::MAX_NDIM,
            "Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "Dimshuffle expects 1 input, got %zu", nr_inp);
    auto&& src = inputs[0];
    TensorShape out_shape;
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
    mgb_assert(
            src.layout.ndim == pattern_ndim,
            "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
            src.layout.ndim);
    size_t idx = 0;
    bool input_used[TensorLayout::MAX_NDIM] = {0};
    out_shape.ndim = ds.pattern.size();
    for (auto i : ds.pattern) {
        if (i < 0) {
            out_shape[idx] = 1;
        } else {
            input_used[i] = true;
            out_shape[idx] = src.layout.shape[i];
        }
        ++idx;
    }
    for (size_t i = 0; i < pattern_ndim; ++i) {
        mgb_assert(
                input_used[i] || src.layout.shape[i] == 1,
                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
                src.layout.megdnn::TensorShape::to_string().c_str(), i);
    }
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

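// Eager execution: build the permuted layout directly on top of the input
// blob (shape/stride taken from the corresponding input axes), so the result
// aliases the input and no copy is performed.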
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    mgb_assert(
            ds.pattern.size() <= TensorShape::MAX_NDIM,
            "Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "Dimshuffle expects 1 input, got %zu", nr_inp);
    auto&& src = inputs[0];
    auto inp_layout = src->layout();
    size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
    mgb_assert(
            inp_layout.ndim == pattern_ndim,
            "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
            inp_layout.ndim);
    TensorLayout out_layout{inp_layout.dtype};
    out_layout.ndim = ds.pattern.size();

    size_t idx = 0;
    bool input_used[TensorLayout::MAX_NDIM] = {0};
    for (auto i : ds.pattern) {
        if (i < 0) {
            out_layout.shape[idx] = 1;
            out_layout.stride[idx] = 1;
        } else {
            input_used[i] = true;
            out_layout.shape[idx] = inp_layout.shape[i];
            out_layout.stride[idx] = inp_layout.stride[i];
        }
        ++idx;
    }
    if (out_layout.is_contiguous()) {
        out_layout.init_contiguous_stride();
    }
    for (size_t i = 0; i < pattern_ndim; ++i) {
        mgb_assert(
                input_used[i] || inp_layout.shape[i] == 1,
                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
                inp_layout.megdnn::TensorShape::to_string().c_str(), i);
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), out_layout)};
}

OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
}  // namespace dimshuffle
}  // namespace

namespace {
namespace add_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& add_axis = static_cast<const AddAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : add_axis.axis) {
        param.push_back(Desc::make_add(i));
    }
    OperatorNodeConfig config{add_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

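// AddAxis only inserts size-1 dimensions, so the output layout is derived in
// place via add_axis_cont_inplace and, on the physical path below, the result
// simply forwards the input blob.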
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<AddAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "AddAxis expects 1 input, got %zu", nr_inp);
    auto&& src = inputs[0];
    auto olayout = src.layout;
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
    }
    for (auto&& i : op_def.axis) {
        olayout.add_axis_cont_inplace(i);
    }
    return {{{olayout, src.comp_node}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<AddAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "AddAxis expects 1 input, got %zu", nr_inp);
    auto&& src = inputs[0];
    auto tlayout = src->layout();
    for (auto&& i : op_def.axis) {
        tlayout.add_axis_cont_inplace(i);
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

OP_TRAIT_REG(AddAxis, AddAxis)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
}  // namespace add_axis
}  // namespace

namespace {
namespace remove_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& remove_axis = static_cast<const RemoveAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : remove_axis.axis) {
        param.push_back(Desc::make_remove(i));
    }
    OperatorNodeConfig config{remove_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

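// RemoveAxis may only drop dimensions whose extent is 1; removing anything
// else would lose data, hence the asserts below. The result again aliases the
// input blob.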
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<RemoveAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "RemoveAxis expects 1 input, got %zu", nr_inp);
    auto&& src = inputs[0];
    auto tlayout = src->layout();
    for (auto&& i : op_def.axis) {
        if (tlayout.ndim == 1) {
            mgb_assert(
                    tlayout.shape[0] == 1 && i == 0,
                    "can not remove axis %u from tensor of shape=%s", i,
                    tlayout.megdnn::TensorShape::to_string().c_str());
        } else {
            mgb_assert(
                    i < tlayout.ndim && tlayout.shape[i] == 1,
                    "can not remove axis %u from tensor of shape=%s", i,
                    tlayout.megdnn::TensorShape::to_string().c_str());
            tlayout.remove_axis_inplace(i);
        }
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<RemoveAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "RemoveAxis expects 1 input, got %zu", nr_inp);
    auto&& src = inputs[0];
    auto olayout = src.layout;
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
    }
    for (auto&& i : op_def.axis) {
        if (olayout.ndim == 1) {
            mgb_assert(
                    olayout.shape[0] == 1 && i == 0,
                    "can not remove axis %u from tensor of shape=%s", i,
                    olayout.megdnn::TensorShape::to_string().c_str());
        } else {
            mgb_assert(
                    i < olayout.ndim && olayout.shape[i] == 1,
                    "can not remove axis %u from tensor of shape=%s", i,
                    olayout.megdnn::TensorShape::to_string().c_str());
            olayout.remove_axis_inplace(i);
        }
    }
    return {{{olayout, src.comp_node}}, true};
}

OP_TRAIT_REG(RemoveAxis, RemoveAxis)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
}  // namespace remove_axis
}  // namespace

namespace {
namespace top_k {
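// TopK has multiple outputs, so the owner operator node is returned rather
// than a single VarNode, exposing all of its output vars.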
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& topk = static_cast<const TopK&>(def);
    OperatorNodeConfig config{topk.make_name()};
    return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
            .node()
            ->owner_opr();
}

OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace top_k
}  // namespace

namespace {
namespace batch_conv_bias {
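// Forward to the opr::BatchConvBias::make overload matching the number of
// inputs (2, 3 or 4 tensors); any other arity is rejected by the final assert.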
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const BatchConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
    config.name(conv.make_name());
    if (inputs.size() == 2) {
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 3) {
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(),
                config);
    }
    mgb_assert(0);
}

OP_TRAIT_REG(BatchConvBias, BatchConvBias)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace batch_conv_bias
}  // namespace

namespace {
namespace argsort {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& argsort = static_cast<const Argsort&>(def);
    OperatorNodeConfig config{argsort.make_name()};
    return opr::Argsort::make(inputs[0], argsort.param(), config);
}
OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argsort
}  // namespace

namespace {
namespace argmax {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& argmax = static_cast<const Argmax&>(def);
    OperatorNodeConfig config{argmax.make_name()};
    return opr::Argmax::make(inputs[0], argmax.param(), config);
}
OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmax
}  // namespace

namespace {
namespace argmin {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& argmin = static_cast<const Argmin&>(def);
    OperatorNodeConfig config{argmin.make_name()};
    return opr::Argmin::make(inputs[0], argmin.param(), config);
}
OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmin
}  // namespace

namespace {
namespace warp_perspective {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& warp = static_cast<const WarpPerspective&>(def);
    OperatorNodeConfig config{warp.make_name()};
    if (inputs.size() == 3) {
        return opr::WarpPerspective::make(
                inputs[0], inputs[1], inputs[2], warp.param(), config);
    } else {
        mgb_assert(inputs.size() == 4);
        return opr::WarpPerspective::make(
                inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
    }
}
OP_TRAIT_REG(WarpPerspective, WarpPerspective)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace warp_perspective
}  // namespace

namespace {
namespace group_local {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& local = static_cast<const GroupLocal&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{local.make_name()};
    return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
}
OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace group_local
}  // namespace

namespace {
namespace typecvt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const TypeCvt&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::TypeCvt::make(inputs[0], op.dtype, config);
}
OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace typecvt
}  // namespace

namespace {
namespace concat {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Concat&>(def);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Concat::make(inputs, op.axis, config);
}
OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace concat
}  // namespace

namespace {
namespace copy {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Copy&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Copy::make(inputs[0], config);
}
OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace copy
}  // namespace

namespace {
namespace assert_equal {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final<AssertEqual>();
    if (inputs.size() == 2) {
        return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
    } else {
        // workaround for MiniGraph, which only allows one opr in the graph
        mgb_assert(inputs.size() == 3);
        return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
    }
}

OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace assert_equal
}  // namespace

namespace {
namespace correlation {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Correlation&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Correlation::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Correlation, Correlation).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace correlation
}  // namespace

#if MGB_CUDA
namespace {
namespace nvof {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const NvOf&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::NvOf::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace nvof
}  // namespace
#endif

namespace {
namespace linspace {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Linspace&>(def);
    mgb_assert(inputs.size() == 3);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(Linspace, Linspace).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace linspace
}  // namespace

namespace {
namespace eye {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Eye&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    opr::Eye::Param param{op.k, op.dtype.enumv()};
    return opr::Eye::make(inputs[0], param, config);
}
OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace eye
}  // namespace

namespace {
namespace diag {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Diag&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.make_name()};
    opr::Diag::Param param{op.k};
    return opr::Diag::make(inputs[0], param, config);
}
OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace diag
}  // namespace

namespace {
namespace remap {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Remap&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace remap
}  // namespace

namespace {
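// Translate a per-axis mask of (axis, has_begin, has_end, has_step, has_idx)
// flags into an opr::Subtensor::IndexDesc, consuming the remaining inputs
// (starting at vidx) in the order the flags appear.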
auto get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    for (size_t i = 0; i < length; ++i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = inputs[vidx++];
        } else {
            mgb_assert(begin || end || step);
            if (begin)
                ret[i].begin = inputs[vidx++];
            if (end)
                ret[i].end = inputs[vidx++];
            if (step)
                ret[i].step = inputs[vidx++];
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]

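// Generates apply_on_var_node and the OP_TRAIT_REG entry for one fancy
// indexing op: the first NR_INPUT inputs are tensor operands, the remaining
// inputs are the index variables decoded by get_index() from op.items.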
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT)                                       \
    namespace NAME##_impl {                                                       \
        auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {    \
            auto&& op = static_cast<const NAME&>(def);                            \
            OperatorNodeConfig config{op.make_name()};                            \
            return opr::NAME::make(                                               \
                    IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
        }                                                                         \
        OP_TRAIT_REG(NAME, NAME).apply_on_var_node(apply_on_var_node).fallback(); \
    }

FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)

#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
}  // anonymous namespace

namespace {
namespace fake_quant {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const FakeQuant&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(FakeQuant, FakeQuant).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace fake_quant
}  // namespace

namespace {
namespace tqt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const TQT&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
}
OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace tqt
}  // namespace

namespace {
namespace elemwise_multi_type {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const ElemwiseMultiType&>(def);
    OperatorNodeConfig config{op.dtype};
    config.name(op.make_name());
    return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace elemwise_multi_type
}  // namespace

namespace {
namespace svd {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const SVD&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::SVD::make(inputs[0], op.param(), config)[0]
            .node()
            ->owner_opr()
            ->usable_output();
}
OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace svd
}  // namespace

namespace {
namespace images2neibs {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Images2Neibs&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::Images2Neibs::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(Images2Neibs, Images2Neibs)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace images2neibs
}  // namespace

namespace {
namespace lsq {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LSQ&>(def);
    mgb_assert(inputs.size() == 4);
    OperatorNodeConfig config{op.make_name()};
    return opr::LSQ::make(
            inputs[0], inputs[1], inputs[2], inputs[3], op.param(), config);
}
OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace lsq
}  // namespace

namespace {
namespace sliding_window_transpose {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const SlidingWindowTranspose&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::SlidingWindowTranspose::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace sliding_window_transpose
}  // namespace

namespace {
namespace cumsum {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Cumsum&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::Cumsum::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace cumsum
}  // namespace

namespace lrn {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LRN&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::LRN::make(inputs[0], op.param());
}
OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace lrn

}  // namespace mgb::imperative