/**
 * \file src/opr/impl/indexing.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/indexing.h"

#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"

#include "./internal/megdnn_opr_wrapper.inl"

#include <algorithm>
#include <memory>
#include <utility>

using namespace mgb;
using namespace opr;

namespace {
M
Megvii Engine Team 已提交
23 24 25 26 27 28 29 30 31 32 33
void check_index_dtype(std::initializer_list<SymbolVar*>& inputs) {
    mgb_assert(inputs.size() >= 2);
    auto iter = inputs.begin();
    ++iter;
    SymbolVar& index = **iter;
    if (index.dtype() != dtype::Int32()) {
        mgb_log_warn(
                "dtype of index in IndexingOneHot must be Int32, "
                "got %s for variable %s; convert to Int32 implicitly",
                index.dtype().name(), index.node()->cname());
        index = opr::TypeCvt::make(index, dtype::Int32());
34
    }
M
Megvii Engine Team 已提交
35
}
36

M
Megvii Engine Team 已提交
37
enum IndexingModifyType { SET, INCR };
38

M
Megvii Engine Team 已提交
39 40
template <typename Opr>
struct IndexingModifyTypeGetter {};
41

M
Megvii Engine Team 已提交
42 43 44
#define REG(op, type)                                                         \
    template <>                                                               \
    struct IndexingModifyTypeGetter<megdnn::op> {                             \
45 46 47 48 49 50 51 52 53 54
        static constexpr IndexingModifyType value = IndexingModifyType::type; \
    };
REG(IndexingIncrMultiAxisVec, INCR)
REG(IncrMeshIndexing, INCR)
REG(BatchedIncrMeshIndexing, INCR)
REG(IndexingSetMultiAxisVec, SET)
REG(SetMeshIndexing, SET)
REG(BatchedSetMeshIndexing, SET)
#undef REG

M
Megvii Engine Team 已提交
55
}  // namespace
56 57 58 59 60

namespace mgb {
namespace opr {
namespace intl {

M
Megvii Engine Team 已提交
61 62 63 64 65 66 67 68 69
template <>
struct MegDNNOprInitInputsModifier<IndexingOneHot> {
    static void apply(
            const IndexingOneHot::Param& param,
            std::initializer_list<SymbolVar*> inputs) {
        MGB_MARK_USED_VAR(param);
        check_index_dtype(inputs);
    }
};
70

M
Megvii Engine Team 已提交
71 72 73 74 75 76
template <>
struct MegDNNOprInitInputsModifier<IndexingSetOneHot>
        : public MegDNNOprInitInputsModifier<IndexingOneHot> {};
}  // namespace intl
}  // namespace opr
}  // namespace mgb
77 78 79 80 81 82 83 84 85

/* ==================== IndexingOneHot ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingOneHot);
MEGDNN_OPR_INIT2(IndexingOneHot, "indexing_one_hot")

void IndexingOneHot::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

86
#if MGB_ENABLE_GRAD
87 88 89
MGB_IMPL_OPR_GRAD(IndexingOneHot) {
    if (wrt_idx == 0) {
        return IndexingSetOneHot::make(
M
Megvii Engine Team 已提交
90 91 92
                       SymbolVar{opr.input(0)}.fill_retain_dtype(0), opr.input(1),
                       out_grad[0], opr.param())
                .node();
93 94 95
    }
    return InvalidGrad::make(opr, wrt_idx);
}
96
#endif
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116

/* ==================== IndexingSetOneHot ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingSetOneHot);
MEGDNN_OPR_INIT3(IndexingSetOneHot, "indexing_set_one_hot")

void IndexingSetOneHot::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

void IndexingSetOneHot::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void IndexingSetOneHot::mem_plan_fwd_in2out_writable() {
    cg::request_fwd_in2out_writable_if_no_mem_ovelap(this, 0, 0);
}

void IndexingSetOneHot::init_output_static_infer_desc() {
    using namespace cg::static_infer;
M
Megvii Engine Team 已提交
117 118
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
119 120 121 122 123 124 125 126 127 128 129 130 131 132
    init_output_static_infer_desc_workspace(false);
}

void IndexingSetOneHot::scn_do_execute() {
    auto &&idata = input(0)->dev_tensor(), &&index = input(1)->dev_tensor(),
         &&odata = output(0)->dev_tensor();

    if (idata.raw_ptr() != odata.raw_ptr()) {
        odata.copy_from_fixlayout(idata);
    } else {
        mgb_assert(odata.layout().eq_layout(idata.layout()));
    }
    mgb_assert(odata.layout().is_contiguous());

M
Megvii Engine Team 已提交
133 134
    megdnn_opr()->exec(
            odata.as_megdnn(), index.as_megdnn(), input(2)->dev_tensor().as_megdnn(),
135 136 137
            intl::get_megdnn_workspace_from_var(output(1)));
}

138
#if MGB_ENABLE_GRAD
139 140 141
MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
    SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)};
    if (wrt_idx == 0) {
M
Megvii Engine Team 已提交
142 143
        return IndexingSetOneHot::make(og, index, sub.fill_retain_dtype(0), opr.param())
                .node();
144 145 146 147 148 149
    }
    if (wrt_idx == 2) {
        return IndexingOneHot::make(og, index, opr.param()).node();
    }
    return InvalidGrad::make(opr, wrt_idx);
}
150
#endif
151 152

size_t IndexingSetOneHot::get_workspace_size_bytes(
M
Megvii Engine Team 已提交
153 154
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
155
    return megdnn_opr()->get_workspace_in_bytes(
M
Megvii Engine Team 已提交
156 157
            {input_shapes[0], input(0)->dtype()}, {input_shapes[1], input(1)->dtype()},
            {input_shapes[2], input(2)->dtype()});
158 159 160 161 162 163 164
}

/* ==================== IndexingRemap ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemap);
MEGDNN_OPR_INIT2(IndexingRemap, "indexing_remap")

void IndexingRemap::init_output_dtype() {
M
Megvii Engine Team 已提交
165 166
    mgb_throw_if(
            input(1)->dtype() != dtype::Int32(), GraphError,
167 168 169 170
            "IndexingRemap requires map input to be int32");
    output(0)->dtype(input(0)->dtype());
}

171
#if MGB_ENABLE_GRAD
172 173 174 175 176
MGB_IMPL_OPR_GRAD(IndexingRemap) {
    if (wrt_idx == 1)
        return InvalidGrad::make(opr, wrt_idx);
    mgb_assert(wrt_idx == 0 && out_grad[0]);
    return IndexingRemapBackward::make(
M
Megvii Engine Team 已提交
177 178
                   out_grad[0], opr.input(1), opr.input(0), opr.param())
            .node();
179
}
180
#endif
181 182 183 184 185

MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemapBackward);
MEGDNN_OPR_INIT3(IndexingRemapBackward, "indexing_remap_bwd", 2, false);

/* ================= IndexingMultiAxisVecMegDNNOprHolder ================= */
M
Megvii Engine Team 已提交
186
template <class Opr>
187 188 189
Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr(
        cg::SingleCNOperatorNodeBase& self) {
    auto comp_node = self.comp_node();
190 191
    if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node) {
        m_dnn_opr = intl::create_megdnn_opr<Opr>(comp_node);
M
Megvii Engine Team 已提交
192
        m_dnn_opr->set_error_tracker(static_cast<cg::OperatorNodeBase*>(&self));
193
    }
194
    return *m_dnn_opr;
195 196
}

M
Megvii Engine Team 已提交
197
template <class Opr>
198
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer(
M
Megvii Engine Team 已提交
199
        const indexing::IndexDesc& index_desc, cg::SingleCNOperatorNodeBase& opr,
200
        VarNode* data, VarNode* value, VarNodeArray idx_arr) {
201
    using namespace cg::static_infer;
202 203 204 205 206 207 208
    DepVal deps = {{data, DepType::SHAPE}, {value, DepType::SHAPE}};

    for (auto&& idx : idx_arr) {
        deps.push_back({idx, DepType::SHAPE});
    }
    auto infer_shape = [this, &index_desc, &opr, nr_idx = idx_arr.size()](
                               TensorShape& dest, const InpVal& inp) {
209 210
        size_t axes[TensorShape::MAX_NDIM], nr_axes = 0;
        auto ndim = inp.val[0].shape().ndim;
M
Megvii Engine Team 已提交
211
        for (auto&& i : reverse_adaptor(index_desc)) {
212
            if (i.idx.node()) {
M
Megvii Engine Team 已提交
213
                axes[nr_axes++] = i.axis.get(ndim);
214 215
            }
        }
216
        mgb_assert(nr_axes == nr_idx);
217 218 219
        if (!nr_axes) {
            dest = {0};
        } else {
220 221 222 223 224
            size_t idx_ndim = 0;
            for (size_t i = 0; i < nr_idx; ++i) {
                idx_ndim = std::max(idx_ndim, inp.val[2 + i].shape().ndim);
            }
            mgb_assert(idx_ndim > 0);
225
            dest = {megdnn_opr(opr).get_workspace_in_bytes(
226
                    inp.val[1].shape(), axes, nr_axes, idx_ndim)};
227 228 229 230
        }
        return true;
    };
    opr.owner_graph()->static_infer_manager().register_shape_infer(
231
            opr.output(1), {SourceType::DEP, deps, infer_shape});
232 233 234 235 236
}

template <class Opr>
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr(
        mgb::cg::GraphExecutable::ExecDependencyArray& deps) {
M
Megvii Engine Team 已提交
237
    deps.emplace_back(std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr)));
238 239 240
}

/* ==================== MultiAxisVecFancyIndexingHelper ==================== */
M
Megvii Engine Team 已提交
241 242 243 244
std::pair<const megdnn::IndexingMultiAxisVec::IndexDesc&, bool> intl::
        MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
                size_t inp_ndim, bool warn_all_scalar) {
    auto&& index = m_megdnn_index_cache;
245
    index.clear();
246
    bool is_empty_shape = false;
M
Megvii Engine Team 已提交
247
    for (auto i : reverse_adaptor(m_input2idxonly_axis_indexer)) {
248
        if (i) {
M
Megvii Engine Team 已提交
249 250
            index.push_back(
                    {i->axis.get(inp_ndim), i->idx.node()->dev_tensor().as_megdnn()});
251
            is_empty_shape |= index.back().vec.layout.is_empty();
252 253 254
        }
    }

255 256
    if (!m_scalar_idx_warn_printed && warn_all_scalar &&
        !this->owner_graph()->options().imperative_proxy_graph) {
257
        bool all_scalar = true;
M
Megvii Engine Team 已提交
258
        for (auto&& i : index) {
259 260 261 262 263 264
            if (!i.vec.layout.is_scalar()) {
                all_scalar = false;
                break;
            }
        }
        if (all_scalar) {
265 266 267
#if MGB_ENABLE_GETENV
            mgb_log_warn(
                    "%s{%s}: no vector indexer; consider using Subtensor "
268 269 270 271
                    "family for better performance; you can set "
                    "MGB_THROW_ON_SCALAR_IDX to throw an exception to help "
                    "tracking the related operator",
                    cname(), dyn_typeinfo()->name);
272 273 274 275 276 277 278
#else
            mgb_log_warn(
                    "%s{%s}: no vector indexer; consider using Subtensor "
                    "family for better performance",
                    cname(), dyn_typeinfo()->name);
#endif
#if MGB_ENABLE_GETENV
M
Megvii Engine Team 已提交
279 280 281 282
            mgb_throw_if(
                    MGB_GETENV("MGB_THROW_ON_SCALAR_IDX"), MegBrainError,
                    "vector-indexing operator used with all "
                    "scalar indices");
283
#endif
284 285 286 287 288 289 290
        }

        // always set m_scalar_idx_warn_printed to be true, so we do not print
        // this warning in the future
        m_scalar_idx_warn_printed = true;
    }

291
    return {index, is_empty_shape};
292 293 294
}

/* ==================== IndexingMultiAxisVecBase ==================== */
M
Megvii Engine Team 已提交
295 296 297
template <class Opr>
cg::OperatorNodeBase::NodeProp* IndexingMultiAxisVecBase<Opr>::do_make_node_prop()
        const {
298
    auto prop = Super::do_make_node_prop();
299 300
    using DT = NodeProp::DepType;
    prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY);
M
Megvii Engine Team 已提交
301
    for (auto i : m_input2idxonly_axis_indexer) {
302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317
        if (i) {
            prop->add_dep_type_existing_var(
                    i->idx.node(), NodeProp::DepType::VALUE_ALLOW_EMPTY);
        }
    }
    return prop;
}

template <class Opr>
void IndexingMultiAxisVecBase<Opr>::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    DepVal deps;

    // shape inference only needs slices
    deps.push_back({input(0), DepType::SHAPE});
    // loop in reverse order because megdnn opr needs ascending axes
M
Megvii Engine Team 已提交
318
    for (size_t i = m_input2idxonly_axis_indexer.size() - 1; i; --i) {
319 320 321 322 323
        if (m_input2idxonly_axis_indexer[i]) {
            deps.push_back({input(i), DepType::SHAPE});
        }
    }
    size_t inp_interval_start = deps.size();
M
Megvii Engine Team 已提交
324
    for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
325 326 327 328 329
        if (!m_input2idxonly_axis_indexer[i]) {
            deps.push_back({input(i), DepType::VALUE});
        }
    }
    auto infer_shape = [this, inp_interval_start](
M
Megvii Engine Team 已提交
330 331
                               TensorShape& dest, const InpVal& inp) {
        auto&& ishp = inp.val[0].shape();
332 333 334 335 336
        auto subspec = fancy_indexing_make_sub_spec(
                {ishp, input(0)->dtype()}, inp, inp_interval_start);
        dest = subspec.layout();
        typename Opr::IndexDescLayoutOnly index_layout;
        size_t indexer_pos = 1;
M
Megvii Engine Team 已提交
337
        for (auto i : reverse_adaptor(m_input2idxonly_axis_indexer)) {
338
            if (i) {
M
Megvii Engine Team 已提交
339 340 341
                index_layout.push_back(
                        {i->axis.get(dest.ndim),
                         {inp.val.at(indexer_pos++).shape(), dtype::Int32()}});
342 343 344 345 346 347
            }
        }
        mgb_assert(indexer_pos == inp_interval_start);
        if (!index_layout.empty()) {
            // index_layout is empty if all indices are intervals
            TensorLayout tmp;
M
Megvii Engine Team 已提交
348
            Opr::deduce_layout({dest, input(0)->dtype()}, index_layout, tmp);
349 350 351 352 353 354
            dest = tmp;
        }
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(0), {SourceType::DEP, deps, infer_shape});
355 356 357 358 359 360 361
    VarNodeArray idx_arr;
    for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
        if (m_input2idxonly_axis_indexer[i]) {
            idx_arr.push_back(input(i));
        }
    }
    this->register_workspace_infer(index_desc(), *this, input(0), output(0), idx_arr);
362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388
}

template <class Opr>
void IndexingMultiAxisVecBase<Opr>::record_execute_deps(
        mgb::cg::GraphExecutable::ExecDependencyArray& deps) {
    this->record_megdnn_opr(deps);
}

namespace {
template <class Opr>
struct ShouldWarnOnScalarIndexer {
    static constexpr bool val = false;
};

#define WARN(opr)                                   \
    template <>                                     \
    struct ShouldWarnOnScalarIndexer<megdnn::opr> { \
        static constexpr bool val = true;           \
    }
WARN(IndexingMultiAxisVec);
WARN(IndexingSetMultiAxisVec);
WARN(IndexingIncrMultiAxisVec);
#undef WARN
}  // anonymous namespace

template <class Opr>
void IndexingMultiAxisVecBase<Opr>::scn_do_execute() {
389 390 391
    if (output(0)->layout().is_empty()) {
        return;
    }
392 393
    auto inp = input(0)->dev_tensor();
    inp = inp.sub(fancy_indexing_make_sub_spec(inp.layout()));
M
Megvii Engine Team 已提交
394
    auto&& index_desc = make_megdnn_index_desc(
395
            inp.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);
M
Megvii Engine Team 已提交
396
    auto&& odev = output(0)->dev_tensor();
397
    if (index_desc.first.empty()) {
398 399
        odev.copy_from_fixlayout(inp);
    } else {
400
        if (!index_desc.second) {
401 402
            // only call megdnn exec if result is not empty
            this->megdnn_opr(*this).exec(
403
                    inp.as_megdnn(), index_desc.first, odev.as_megdnn(),
404 405 406 407 408 409 410 411 412
                    intl::get_megdnn_workspace_from_var(output(1)));
        } else {
            mgb_assert(odev.empty());
        }
    }
}

/* ==================== IndexingModifyMultiAxisVecHelper ==================== */

M
Megvii Engine Team 已提交
413 414
template <class Opr>
void intl::IndexingModifyMultiAxisVecHelper<Opr>::init_output_static_infer_desc() {
415 416 417 418
    using namespace cg::static_infer;
    this->owner_graph()->static_infer_manager().register_shape_infer(
            this->output(0), ShapeInferDesc::make_identity(this->input(0)));

419 420 421 422 423 424 425
    VarNodeArray idx_arr;
    for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++i) {
        if (m_input2idxonly_axis_indexer[i]) {
            idx_arr.push_back(input(i));
        }
    }
    this->register_workspace_infer(index_desc(), *this, input(0), input(1), idx_arr);
426 427
}

M
Megvii Engine Team 已提交
428
template <class Opr>
429 430 431 432
void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
    auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute();
    auto index_desc = this->make_megdnn_index_desc(
            inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);
M
Megvii Engine Team 已提交
433
    if (inp.first.shape().is_empty() || index_desc.second) {
434 435 436 437
        mgb_assert(inp.second.shape().is_empty());
        return;
    }
    if (index_desc.first.empty()) {
438
        using IMT = IndexingModifyType;
M
Megvii Engine Team 已提交
439
        static constexpr auto modify_type = IndexingModifyTypeGetter<Opr>::value;
440 441 442 443
        switch (modify_type) {
            case IMT::SET: {
                inp.first.copy_from_fixlayout(inp.second);
                break;
M
Megvii Engine Team 已提交
444 445 446 447
            }
            case IMT::INCR: {
                megdnn::AddUpdate* add_update =
                        intl::get_megdnn_global_opr<megdnn::AddUpdate>(comp_node());
448 449
                add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn());
                break;
M
Megvii Engine Team 已提交
450 451
            }
            default:
452
                mgb_throw(MegBrainError, "bad modify type");
453 454 455
        }
    } else {
        this->megdnn_opr(*this).exec(
M
Megvii Engine Team 已提交
456
                inp.first.as_megdnn(), inp.second.as_megdnn(), index_desc.first,
457 458 459 460
                intl::get_megdnn_workspace_from_var(output(1)));
    }
}

M
Megvii Engine Team 已提交
461 462 463
template <class Opr>
cg::OperatorNodeBase::NodeProp* intl::IndexingModifyMultiAxisVecHelper<
        Opr>::do_make_node_prop() const {
464 465 466 467 468
    auto prop = Super::do_make_node_prop();
    using DT = NodeProp::DepType;
    // TODO: should also allow input shape is empty if any
    // indexer's shape is empty
    prop->add_dep_type_existing_var(input(1), DT::VALUE_ALLOW_EMPTY);
M
Megvii Engine Team 已提交
469
    for (auto i : m_input2idxonly_axis_indexer) {
470
        if (i) {
M
Megvii Engine Team 已提交
471
            prop->add_dep_type_existing_var(i->idx.node(), DT::VALUE_ALLOW_EMPTY);
472 473 474 475 476
        }
    }
    return prop;
}

M
Megvii Engine Team 已提交
477 478 479
template <class Opr>
void intl::IndexingModifyMultiAxisVecHelper<Opr>::add_input_layout_constraint() {
    auto check_cont1 = [](const TensorLayout& ly) {
480 481 482 483 484 485 486
        return ly.collapse_contiguous().ndim == 1;
    };
    this->input(1)->add_layout_constraint(check_cont1);
}

/* ==================== MultiAxisVec misc ==================== */

M
Megvii Engine Team 已提交
487 488
MGB_IMPL_FANCY_INDEXING_OPR_GET(IndexingMultiAxisVec, "indexing_multi_axis_vec", false,
                                output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
489
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
490
        IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false,
M
Megvii Engine Team 已提交
491
        output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
492 493 494
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
        IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false);

495 496
IndexingSetMultiAxisVec::NodeProp* IndexingSetMultiAxisVec::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
M
Megvii Engine Team 已提交
497
    prop->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
498 499 500
    return prop;
}

501
#if MGB_ENABLE_GRAD
502 503 504 505 506
MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);

    return IndexingIncrMultiAxisVec::make(
M
Megvii Engine Team 已提交
507 508 509
                   SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0),
                   opr.index_desc())
            .node();
510
}
511
#endif
512

513
#if MGB_ENABLE_GRAD
514 515 516 517
MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
    if (wrt_idx >= 2)
        return InvalidGrad::make(opr, wrt_idx);
    if (wrt_idx == 0) {
M
Megvii Engine Team 已提交
518 519 520 521
        return IndexingSetMultiAxisVec::make(
                       out_grad.at(0), SymbolVar{opr.input(1)}.fill_retain_dtype(0),
                       opr.index_desc())
                .node();
522 523 524
    }
    return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node();
}
525
#endif
526

527
#if MGB_ENABLE_GRAD
528 529 530 531 532 533 534 535
MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
    if (wrt_idx >= 2)
        return InvalidGrad::make(opr, wrt_idx);
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node();
}
536
#endif
537 538 539

/* ============================= Mesh Indexing ============================ */

M
Megvii Engine Team 已提交
540 541 542 543
MGB_IMPL_FANCY_INDEXING_OPR_GET(MeshIndexing, "mesh_indexing", false,
                                output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
MGB_IMPL_FANCY_INDEXING_OPR_GET(BatchedMeshIndexing, "batched_mesh_indexing", false,
                                output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
544

545
#if MGB_ENABLE_GRAD
546 547 548 549 550 551 552 553 554
MGB_IMPL_OPR_GRAD(MeshIndexing) {
    if (wrt_idx != 0) {
        return InvalidGrad::make(opr, wrt_idx);
    }
    return IncrMeshIndexing::make(
                   SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0),
                   opr.index_desc())
            .node();
}
555 556
#endif

557
#if MGB_ENABLE_GRAD
558 559 560 561 562 563 564 565 566
MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
    if (wrt_idx != 0) {
        return InvalidGrad::make(opr, wrt_idx);
    }
    return BatchedIncrMeshIndexing::make(
                   SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0),
                   opr.index_desc())
            .node();
}
567
#endif
568 569 570

/* ========================= IncrMeshIndexing ========================= */

M
Megvii Engine Team 已提交
571
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", false);
572

573
#if MGB_ENABLE_GRAD
574 575 576 577 578 579 580 581 582
MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
    if (wrt_idx > 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
}
583
#endif
584

M
Megvii Engine Team 已提交
585 586
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
        BatchedIncrMeshIndexing, "batched_incr_mesh_indexing", false);
587
#if MGB_ENABLE_GRAD
588 589 590 591 592 593 594 595 596
MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
    if (wrt_idx > 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
}
597
#endif
598 599 600 601

/* ======================== SetMeshIndexing =========================== */
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false);

602
#if MGB_ENABLE_GRAD
603 604 605 606 607 608
MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
    if (wrt_idx >= 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return SetMeshIndexing::make(
M
Megvii Engine Team 已提交
609
                       out_grad.at(0), SymbolVar{opr.input(1)}.fill_retain_dtype(0),
610 611 612 613 614 615
                       opr.index_desc())
                .node();
    } else {
        return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
    }
}
616
#endif
617

M
Megvii Engine Team 已提交
618 619
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
        BatchedSetMeshIndexing, "batched_set_mesh_indexing", false);
620
#if MGB_ENABLE_GRAD
621 622 623 624 625 626
MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
    if (wrt_idx > 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return BatchedSetMeshIndexing::make(
M
Megvii Engine Team 已提交
627
                       out_grad.at(0), SymbolVar{opr.input(1)}.fill_retain_dtype(0),
628 629 630
                       opr.index_desc())
                .node();
    } else {
M
Megvii Engine Team 已提交
631
        return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
632 633
    }
}
634
#endif
635 636

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}