indexing.cpp 2.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
/**
 * \file imperative/src/impl/ops/indexing.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/autogen.h"

#include "../op_trait.h"

#include "megbrain/opr/indexing.h"

namespace mgb {
namespace imperative {

namespace {
namespace indexing_one_hot {

/*!
 * \brief Derive the output descriptor of IndexingOneHot from its inputs.
 *
 * \param def operator definition; checked-cast to IndexingOneHot
 * \param input_descs exactly two descriptors: [0] is src, [1] is index
 * \return a one-element descriptor list placed on src's comp node, plus a
 *      flag that is true only when both the src and the index layouts were
 *      non-empty and passed every check below
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& input_descs) {
    const auto& op = def.cast_final_safe<IndexingOneHot>();

    mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");

    const TensorLayout& src = input_descs[0].layout;
    const TensorLayout& index = input_descs[1].layout;
    const auto out_comp_node = input_descs[0].comp_node;

    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");

    // src has no dimensions yet: only the dtype can be reported.
    if (src.ndim == 0) {
        return {{{{{}, src.dtype}, out_comp_node}}, false};
    }

    mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
    mgb_assert(src.is_contiguous(), "src should be contiguous");
    mgb_assert(op.axis >= 0 && op.axis < src.ndim, "axis %d not exists in src", op.axis);

    // The output has src's shape with the indexed axis collapsed to extent 1.
    TensorLayout out_layout = src;
    out_layout.shape[op.axis] = 1;
    out_layout.init_contiguous_stride();

    // The index layout can only be cross-checked against src once it has
    // dimensions; otherwise the output is returned with the flag left false.
    const bool index_known = index.ndim != 0;
    if (index_known) {
        mgb_assert(index.is_contiguous(), "index should be all contiguous");
        mgb_assert(index.eq_shape(src.remove_axis(op.axis)), "index shape doesn't match src");
    }

    return {{{out_layout, out_comp_node}}, index_known};
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const IndexingOneHot&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
}

// Register the IndexingOneHot op trait with the two hooks defined in this
// namespace: attribute inference and var-node construction.
// NOTE(review): .fallback() presumably supplies defaults for the hooks not
// set here -- confirm against ../op_trait.h.
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
    .infer_output_attrs_fallible(infer_output_attrs_fallible)
    .apply_on_var_node(apply_on_var_node)
    .fallback();

}  // namespace indexing_one_hot
}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}