/**
 * \file imperative/src/impl/ops/batch_norm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/dnn/batch_norm.h"
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb {
namespace imperative {

namespace {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::BatchNorm>();
23
    return BatchNorm::make(node->param());
24 25
}

M
Megvii Engine Team 已提交
26
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
27 28
    auto&& bn_opr = def.cast_final_safe<BatchNorm>();
    size_t nr_inp = inputs.size();
M
Megvii Engine Team 已提交
29 30 31
    mgb_assert(
            nr_inp == 3 || nr_inp == 5,
            "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
32
    OperatorNodeConfig config{bn_opr.make_name()};
33 34
    if (nr_inp == 3) {
        return opr::BatchNorm::make(
M
Megvii Engine Team 已提交
35 36 37
                       inputs[0], inputs[1], inputs[2], bn_opr.param(), config)[0]
                .node()
                ->owner_opr();
38 39
    } else {
        return opr::BatchNorm::make(
M
Megvii Engine Team 已提交
40 41 42 43
                       inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
                       bn_opr.param(), config)[0]
                .node()
                ->owner_opr();
44 45 46
    }
}

47
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
M
Megvii Engine Team 已提交
48
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
49 50
    auto&& op_def = def.cast_final_safe<BatchNorm>();
    size_t nr_inp = inputs.size();
M
Megvii Engine Team 已提交
51 52 53
    mgb_assert(
            nr_inp == 3 || nr_inp == 5,
            "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
54
    // need running mean/variance
55
    bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING;
M
Megvii Engine Team 已提交
56
    size_t nr_out = need_stat ? 6 : 4;
57 58 59
    SmallVector<LogicalTensorDesc> out_shapes(nr_out);
    auto&& i0 = inputs[0];
    auto&& i1 = inputs[1];
60
    // [running_mean, running_var,] save_mean, save_var
M
Megvii Engine Team 已提交
61
    for (size_t i = 0; i < nr_out - 2; ++i) {
62 63
        out_shapes[i] = {i1.layout, i1.comp_node};
    }
M
Megvii Engine Team 已提交
64 65 66 67
    out_shapes[nr_out - 2] = {
            TensorLayout({0}, dtype::Byte()), i0.comp_node};  // reserve
    out_shapes[nr_out - 1] = {i0.layout, i0.comp_node};       // output
    return {out_shapes, out_shapes[nr_out - 1].layout.ndim != 0};
68 69 70
}

OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
M
Megvii Engine Team 已提交
71 72 73 74 75
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
}  // anonymous namespace
76 77 78 79 80

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}