/**
 * \file imperative/src/impl/ops/elemwise.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace {

/*!
 * Rebuild an imperative Elemwise OpDef from an existing graph operator.
 *
 * \param node_ graph operator; expected to be an opr::Elemwise (checked via
 *        cast_final_safe)
 * \return a new Elemwise OpDef carrying the operator's elemwise mode
 */
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto&& elemwise_node = node_->cast_final_safe<opr::Elemwise>();
    auto mode = elemwise_node.param().mode;
    return Elemwise::make(mode);
}

/*!
 * Lower an Elemwise OpDef onto the computing graph.
 *
 * \param def expected to be an Elemwise OpDef (checked via cast_final_safe)
 * \param inputs graph vars fed to the elemwise operator
 * \return the operator node owning the var produced by opr::Elemwise::make
 */
cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    const auto& op = def.cast_final_safe<Elemwise>();
    auto output = opr::Elemwise::make(inputs, op.mode);
    return output.node()->owner_opr();
}

/*!
 * Infer the output tensor descriptor of an Elemwise op from its input
 * descriptors without executing it.
 *
 * \param def expected to be an Elemwise OpDef (checked via cast_final_safe)
 * \param inputs one descriptor per operand; the count must equal the arity
 *        of the elemwise mode (asserted)
 * \return a single output descriptor, plus a flag that is true only when
 *         every input shape was known, so the output shape is complete
 *
 * Every input must be on input 0's comp node and have input 0's dtype. The
 * output inherits both. If any input shape is unknown (ndim == 0), the output
 * is reported with ndim == 0 and the flag is false. Otherwise the output shape
 * comes from opr::Elemwise::get_output_var_shape and the layout format is
 * taken from input 0.
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(inputs.size() == trait.arity,
               "%s expects %u inputs; got %zu actually", trait.name,
               trait.arity, inputs.size());

    TensorShapeArray inp_shapes;
    inp_shapes.reserve(inputs.size());
    DType out_dt;
    CompNode out_cn;
    bool all_shapes_known = true;

    // Validate every input before deciding anything. An unknown shape on an
    // early input must not skip the comp node / dtype checks on later ones.
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& t = inputs[i];
        if (!i) {
            out_cn = t.comp_node;
            out_dt = t.layout.dtype;
        } else {
            mgb_assert(t.comp_node == out_cn,
                       "%s: input %zu is on a different comp node than input 0",
                       trait.name, i);
            mgb_assert(t.layout.dtype == out_dt,
                       "%s: input %zu has a different dtype than input 0",
                       trait.name, i);
        }
        if (t.layout.ndim > 0) {
            inp_shapes.push_back(t.layout);
        } else {
            all_shapes_known = false;
        }
    }

    if (!all_shapes_known) {
        // The output shape cannot be derived yet. Return a shape-less layout
        // and do not report the descriptor as validated.
        TensorLayout out_layout;
        out_layout.ndim = 0;
        out_layout.dtype = out_dt;
        return {{{out_layout, out_cn}}, false};
    }

    auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
    return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}

// Register the Elemwise OpDef and associate it with the graph operator
// opr::Elemwise. The registration wires in the op-node conversion,
// graph-lowering and output-attribute inference hooks defined in this file.
// NOTE(review): fallback() presumably supplies default implementations for
// trait entries not set explicitly here — confirm in op_trait.h.
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .infer_output_attrs_fallible(infer_output_attrs_fallible)
    .fallback();
} // anonymous namespace

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}