/**
 * \file imperative/src/impl/ops/broadcast.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace broadcast {

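// Reconstruct an imperative Broadcast OpDef from a graph operator node.
// cast_final_safe only asserts that the node really is an opr::Broadcast;
// the op carries no parameters, so a fresh instance is returned.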
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    node_->cast_final_safe<opr::Broadcast>();
    return Broadcast::make();
}

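// Lower the imperative op back onto the computing graph: inputs[0] is the
// tensor to broadcast and inputs[1] is the target shape as a 1-d tensor.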
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    OperatorNodeConfig config{op.make_name()};
    return opr::Broadcast::make(inputs[0], inputs[1], config);
}

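// Check numpy-style broadcasting rules with trailing dimensions aligned:
// the source may not have more dimensions than the target, and every source
// dimension must be 1 or equal to the corresponding target dimension.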
bool valid_broadcast(const TensorShape& src_shape,
                     const TensorShape& tar_shape) {
    size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim;
    if (src_ndim > tar_ndim) {
        return false;
    }
    size_t min_ndim = src_ndim;
    for (size_t i = 0; i < min_ndim; ++i) {
        if (src_shape[src_ndim - i - 1] != 1 &&
            src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
            return false;
        }
    }
    return true;
}

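// Infer the output layout without executing the op. The bool in the returned
// tuple reports whether the inferred layout is final (true) or has to be
// recomputed once the target shape value becomes known (false).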
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp = inputs[1];

    TensorShape out_shape;
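    // Target shape is not statically known yet: return an empty layout and
    // flag the inference as incomplete.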
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_shape.ndim = 0;
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    mgb_assert(
        tshp.layout.ndim == 1,
        "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
        tshp.layout.ndim);

    size_t target_ndim = tshp.layout.shape[0];
    out_shape.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_shape[i] = ptr[i];
    }
    mgb_assert(valid_broadcast(src.layout, out_shape),
               "the input shape %s can not be broadcasted to target shape %s", 
               src.layout.to_string().c_str(),
               out_shape.to_string().c_str());

    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

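// Register the traits above for Broadcast; entries not provided here are
// filled in by the default fallback implementations.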
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .infer_output_attrs_fallible(infer_output_attrs_fallible)
    .fallback();
} // namespace broadcast

namespace reshape {

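// Same lowering scheme as Broadcast above; Reshape additionally forwards its
// param, which may designate the axis of a -1 placeholder dimension.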
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Reshape&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Reshape::make(inputs[0], inputs[1], op.param(), config);
}

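// Shape inference for Reshape mirrors the Broadcast version above, plus two
// extra concerns: resolving a possible -1 placeholder dimension (op.axis)
// and checking that the total element count is preserved.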
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp = inputs[1];

    TensorShape out_shape;
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_shape.ndim = 0;
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    mgb_assert(
        tshp.layout.ndim == 1,
        "target shape of Reshape expects ndim=1; got ndim=%lu actually",
        tshp.layout.ndim);

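    // With an unknown source shape a -1 placeholder dimension cannot be
    // resolved, so static inference has to give up here.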
    if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }

    size_t target_ndim = tshp.layout.shape[0];
    out_shape.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_shape[i] = ptr[i];
    }

    if (src.layout.ndim == 0) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }

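    // op.axis marks the -1 placeholder dimension: set it to 1 temporarily,
    // check divisibility, then fill in the quotient so the element count
    // matches the source layout.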
    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(out_shape[op.axis] == -1);
        out_shape[op.axis] = 1;
        mgb_assert(src.layout.total_nr_elems() % out_shape.total_nr_elems() == 0,
            "can not reshape from %s to %s",
            src.layout.to_string().c_str(),
            out_shape.to_string().c_str());
        out_shape[op.axis] = src.layout.total_nr_elems() / out_shape.total_nr_elems();
    } else {
        mgb_assert(src.layout.total_nr_elems() == out_shape.total_nr_elems(),
            "can not reshape from %s to %s",
            src.layout.to_string().c_str(),
            out_shape.to_string().c_str());
    }
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

OP_TRAIT_REG(Reshape, Reshape)
    .apply_on_var_node(apply_on_var_node)
    .infer_output_attrs_fallible(infer_output_attrs_fallible)
    .fallback();
} // namespace reshape

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}