#include "../dnn_op_helper.h"
2 3 4 5 6
#include "megbrain/imperative/ops/autogen.h"

#include "../op_trait.h"

#include "megbrain/opr/indexing.h"
7
#include "megdnn/oprs/general.h"
8 9 10 11 12 13 14 15

namespace mgb {
namespace imperative {

namespace {
namespace indexing_one_hot {
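// IndexingOneHot: along `axis`, dst[..., 0, ...] = src[..., index[...], ...],
// i.e. dst has src's shape with the indexed axis collapsed to length 1.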

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
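    // Deduce the output layout from (src, index). The returned bool is true
    // only once every shape is known and all checks below have run.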
    auto&& op = def.cast_final_safe<IndexingOneHot>();
    mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");
    auto comp_node = input_descs[0].comp_node;
    TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;

    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");

    if (!src.ndim) {
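        // src's rank is not known yet (e.g. under symbolic tracing): return
        // a placeholder desc and report the inference as incomplete.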
        return {{{{{}, src.dtype}, comp_node}}, false};
    }

    mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
    mgb_assert(src.is_contiguous(), "src should be contiguous");
    mgb_assert(
            -static_cast<int>(src.ndim) <= op.axis &&
                    op.axis < static_cast<int>(src.ndim),
            "axis %d does not exist in src", op.axis);
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(src.ndim);
    }
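    // The output keeps src's rank, with the one-hot axis collapsed to 1.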
    TensorLayout dst = src;
    dst.shape[real_axis] = 1;
    dst.init_contiguous_stride();

    if (!index.ndim) {
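        // index's shape is unknown: the output layout is already fixed, but
        // the shape match below cannot be verified yet.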
        return {{{dst, comp_node}}, false};
    }

    mgb_assert(index.is_contiguous(), "index should be contiguous");
    mgb_assert(
            index.eq_shape(src.remove_axis(real_axis)),
            "index shape doesn't match src");
    return {{{dst, comp_node}}, true};
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
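    // Graph-mode lowering to opr::IndexingOneHot. A negative axis is
    // normalized with the rank recorded on the op (op.ndim), since input
    // shapes may not be known when the graph is built.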
    auto&& op = def.cast_final_safe<IndexingOneHot>();
    mgb_assert(inputs.size() == 2);
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(op.ndim);
    }
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingOneHot::make(inputs[0], inputs[1], real_axis, config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
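    // Eager-mode execution: run the MegDNN kernel on concrete tensors.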
    auto&& op = def.cast_final_safe<IndexingOneHot>();
    auto&& inp = inputs[0];
    auto&& index = inputs[1];
    auto&& layout = inp->layout();
    auto&& index_layout = index->layout();
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(layout.ndim);
    }
    mgb_assert(
            0 <= real_axis && real_axis < static_cast<int>(layout.ndim),
            "Dimension out of range (expected to be in range of [%d, %d], but got %d)",
            0, static_cast<int>(layout.ndim) - 1, op.axis);
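    // Let MegDNN deduce the output layout, allocate the result, and run the
    // kernel; exec_with_ws also sets up any workspace the kernel needs.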
    DnnOprCaller<megdnn::IndexingOneHot> dnn_op(inp->comp_node(), real_axis);
    auto tlayout = dnn_op.deduce_layout(layout, index_layout);
    auto out = Tensor::make(tlayout, inp->comp_node());
    dnn_op.exec_with_ws(inp, index, out);
    return {out};
}

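// Register the handlers with the op-trait table; fallback() fills the
// remaining trait entries with default implementations.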
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace indexing_one_hot

namespace indexing_set_one_hot {
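// IndexingSetOneHot: the in-place counterpart, writing sub[..., 0, ...] into
// src[..., index[...], ...] along `axis`.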

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
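    // The output always has src's layout: the op only overwrites entries.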
    mgb_assert(input_descs.size() == 3, "IndexingSetOneHot expects three inputs");
    auto comp_node = input_descs[0].comp_node;
    auto&& src = input_descs[0].layout;
    auto&& index = input_descs[1].layout;
    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");
    if (!src.ndim) {
        return {{{{{}, src.dtype}, comp_node}}, false};
    }
    return {{{{src, src.dtype}, comp_node}}, true};
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
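    // Graph-mode lowering to opr::IndexingSetOneHot, mirroring IndexingOneHot.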
    auto&& op = static_cast<const IndexingSetOneHot&>(def);
    mgb_assert(inputs.size() == 3);
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(op.ndim);
    }
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingSetOneHot::make(
            inputs[0], inputs[1], inputs[2], real_axis, config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
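    // Eager-mode execution on concrete tensors.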
    auto&& op = def.cast_final_safe<IndexingSetOneHot>();
    auto&& inp = inputs[0];
    auto&& index = inputs[1];
    auto&& sub = inputs[2];
    TensorLayout layout = inp->layout();
    mgb_assert(layout.is_contiguous());
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(layout.ndim);
    }
    DnnOprCaller<megdnn::IndexingSetOneHot> dnn_op(inp->comp_node(), real_axis);
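    // The MegDNN kernel updates its first operand in place, so seed the
    // output with a copy of inp and let the kernel write into that copy.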
    TensorPtr out = Tensor::make(layout, inp->comp_node());
    out->dev_tensor().copy_from_fixlayout(inp->dev_tensor());
    dnn_op.exec_with_ws(out, index, sub);
    return {out};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
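    // Require contiguous layouts for all three inputs.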
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = layout_checker[1] = layout_checker[2] =
            [](const TensorLayout& layout) { return layout.is_contiguous(); };
    return layout_checker;
}

OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .get_input_layout_constraint(get_input_layout_constraint)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace indexing_set_one_hot

}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}