/**
 * \file imperative/src/impl/ops/batch_norm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/batch_norm.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace {

//! Rebuild an imperative BatchNorm OpDef from an existing graph operator node.
//! \param node_ graph operator node; must actually be an opr::BatchNorm
//!        (cast_final_safe asserts on mismatch).
//! \return a BatchNorm OpDef carrying the same parameter set as the node.
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto&& bn = node_->cast_final_safe<opr::BatchNorm>();
    const auto& p = bn.param();
    return BatchNorm::make(
            p.param_dim, p.fwd_mode, p.epsilon, p.avg_factor, p.scale, p.bias);
}

//! Lower a BatchNorm OpDef onto graph var nodes.
//! \param def the OpDef; must be a BatchNorm (cast_final_safe asserts).
//! \param inputs either 3 vars (input, scale, bias) or 5 vars
//!        (input, scale, bias, running mean, running variance).
//! \return the owner operator of the first output var.
cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& bn_opr = def.cast_final_safe<BatchNorm>();
    size_t nr_inp = inputs.size();
    // %zu is the portable printf specifier for size_t (%lu was UB on
    // platforms where size_t != unsigned long, e.g. 64-bit Windows).
    mgb_assert(nr_inp == 3 || nr_inp == 5,
              "BatchNorm expects 3 or 5 inputs; got %zu actually", nr_inp);
    if (nr_inp == 3) {
        return opr::BatchNorm::make(
            inputs[0], inputs[1], inputs[2],
            {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
            .node()->owner_opr();
    } else {
        // 5-input form additionally feeds running mean / running variance.
        return opr::BatchNorm::make(
            inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
            {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
            .node()->owner_opr();
    }
}

//! Infer output descriptors for BatchNorm without executing it.
//! \param def the OpDef; must be a BatchNorm (cast_final_safe asserts).
//! \param inputs 3 or 5 input descriptors; inputs[0] is the data tensor,
//!        inputs[1] carries the per-channel (scale/bias) layout.
//! \return (descriptors, true) — the bool marks the inference as certain.
//! Output order: [saved mean, saved variance,] per-channel stats..., data-shaped
//! result; when statistics are not needed the first two slots get empty
//! placeholder layouts.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<BatchNorm>();
    size_t nr_inp = inputs.size();
    // %zu is the portable printf specifier for size_t.
    mgb_assert(nr_inp == 3 || nr_inp == 5,
              "BatchNorm expects 3 or 5 inputs; got %zu actually", nr_inp);
    // Running mean/variance outputs exist only when training with the
    // 5-input (stateful) form.
    bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING;
    size_t nr_out = need_stat? 5 : 3;
    SmallVector<LogicalTensorDesc> out_shapes(nr_out);
    auto&& i0 = inputs[0];
    auto&& i1 = inputs[1];
    size_t i = 0;
    if (!need_stat) {
        // No statistics: emit zero-sized placeholders for the first two slots.
        out_shapes[0] = out_shapes[1] = {TensorLayout({0}, i0.layout.dtype, i0.layout.format), i0.comp_node};
        i = 2;
    }
    // Intermediate outputs share the per-channel layout of input[1].
    for (; i < nr_out-1; ++ i) {
        out_shapes[i] = {i1.layout, i1.comp_node};
    }
    // The final output has the same layout as the data input.
    out_shapes[nr_out-1] = {i0.layout, i0.comp_node};
    return {out_shapes, true};
}

// Register the BatchNorm op trait, wiring the imperative OpDef to its graph
// operator: node->OpDef conversion, OpDef->node lowering, and fallible shape
// inference; remaining trait entries use the framework fallbacks.
OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .infer_output_attrs_fallible(infer_output_attrs_fallible)
    .fallback();
} // anonymous namespace

// Emit the dynamic-type (RTTI-like) implementation for BatchNorm, required by
// cast_final_safe<> used above.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNorm);

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}