/**
 * \file imperative/src/impl/ops/broadcast.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"

#include "megbrain/graph/helper.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace broadcast {

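// Recover an imperative Broadcast OpDef from a graph-level opr::Broadcast node
// (the op carries no parameters, so only the node type needs to be checked).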
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    node_->cast_final_safe<opr::Broadcast>();
    return Broadcast::make();
}

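// Graph-mode lowering: inputs are (src, target_shape) vars, emitted as an opr::Broadcast.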
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    OperatorNodeConfig config{op.make_name()};
    return opr::Broadcast::make(inputs[0], inputs[1], config);
}

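// Numpy-style broadcasting check: aligned from the trailing dimension, every
// source dim must either be 1 or match the corresponding target dim.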
bool valid_broadcast(const TensorShape& src_shape, const TensorShape& tar_shape) {
    size_t src_ndim = src_shape.ndim, tar_ndim = tar_shape.ndim;
    if (src_ndim > tar_ndim) {
        return false;
    }
    size_t min_ndim = src_ndim;
    for (size_t i = 0; i < min_ndim; ++i) {
        if (src_shape[src_ndim - i - 1] != 1 &&
            src_shape[src_ndim - i - 1] != tar_shape[tar_ndim - i - 1]) {
            return false;
        }
    }
    return true;
}

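// Shape/dtype inference without execution. If the target shape is not known yet
// (empty host value), an unknown layout is returned with validated=false.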
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp = inputs[1];

    TensorShape out_shape;
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_shape.ndim = 0;
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    mgb_assert(
            tshp.layout.ndim == 1,
            "target shape of Broadcast expects ndim=1; got ndim=%lu actually",
            tshp.layout.ndim);

    size_t target_ndim = tshp.layout.shape[0];
    out_shape.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_shape[i] = ptr[i];
    }
    mgb_assert(
            valid_broadcast(src.layout, out_shape),
            "the input shape %s cannot be broadcast to target shape %s",
            src.layout.to_string().c_str(), out_shape.to_string().c_str());

    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

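// Eager-mode execution. Broadcasting never copies data: the result is a new
// layout viewing the same underlying blob.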
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    def.cast_final_safe<Broadcast>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp_nd = inputs[1];
    auto slayout = src->layout();

    TensorShape tshp;
    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
    TensorLayout tlayout = slayout.broadcast(tshp);
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

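// Register the Broadcast implementations above; unspecified traits use the fallback.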
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace broadcast

namespace reshape {

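// Graph-mode lowering: inputs are (src, target_shape) vars plus the op param
// (which may carry a wildcard axis), emitted as an opr::Reshape.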
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Reshape&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Reshape::make(inputs[0], inputs[1], op.param(), config);
}

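// Shape/dtype inference without execution. The target shape may hold -1 at
// op.axis, which is deduced from the total number of elements; inference is
// deferred (validated=false) while the source or target shape is still unknown.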
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp = inputs[1];

    TensorShape out_shape;
    if (tshp.layout.ndim == 0 || tshp.value.empty()) {
        out_shape.ndim = 0;
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    mgb_assert(
            tshp.layout.ndim == 1,
            "target shape of Reshape expects ndim=1; got ndim=%lu actually",
            tshp.layout.ndim);

    if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }

    size_t target_ndim = tshp.layout.shape[0];
    out_shape.ndim = target_ndim;
    auto* ptr = tshp.value.ptr<dt_int32>();
    for (size_t i = 0; i < target_ndim; ++i) {
        out_shape[i] = ptr[i];
    }

    if (src.layout.ndim == 0) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }

    if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(out_shape[op.axis] == -1);
        out_shape[op.axis] = 1;
        mgb_assert(
                src.layout.total_nr_elems() % out_shape.total_nr_elems() == 0,
                "cannot reshape from %s to %s", src.layout.to_string().c_str(),
                out_shape.to_string().c_str());
        out_shape[op.axis] = src.layout.total_nr_elems() / out_shape.total_nr_elems();
    } else {
        mgb_assert(
                src.layout.total_nr_elems() == out_shape.total_nr_elems(),
                "cannot reshape from %s to %s", src.layout.to_string().c_str(),
                out_shape.to_string().c_str());
    }
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

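// Eager-mode execution: resolve the -1 wildcard at op.axis (if any) from the
// element count, then reshape as a view on the source blob.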
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<Reshape>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto&& tshp_nd = inputs[1];
    auto slayout = src->layout();

    TensorShape tshp;
    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
    if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
        mgb_assert(tshp[op_def.axis] == -1);
        tshp[op_def.axis] = 1;
        tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
    }
    TensorLayout tlayout;
    mgb_assert(slayout.try_reshape(tlayout, tshp));
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

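// Layout constraint on the source tensor: accept the incoming layout only if it
// can be reshaped to the requested shape in place (try_reshape succeeds).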
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto&& op_def = def.cast_final_safe<Reshape>();
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = [&](const TensorLayout& layout) {
        TensorShape tshp;
        TensorLayout ret;
        cg::copy_tensor_value_to_shape(
                tshp, inputs[1]->get_value().proxy_to_default_cpu());
        if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
            mgb_assert(tshp[op_def.axis] == -1);
            tshp[op_def.axis] = 1;
            tshp[op_def.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
        }
        return layout.try_reshape(ret, tshp);
    };
    return layout_checker;
}

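// Register the Reshape implementations above; unspecified traits use the fallback.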
OP_TRAIT_REG(Reshape, Reshape)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
}  // namespace reshape

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}