/**
 * \file imperative/src/impl/ops/batch_norm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/dnn/batch_norm.h"

#include "../op_trait.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/tensor.h"

namespace mgb {
namespace imperative {
namespace {

EncodedSubgraph generate_batchnorm_backward_graph(DType dtype, CompNode device) {
    Subgraph::Builder<LogicalTensorDesc> builder{
            [](std::shared_ptr<OpDef> op, SmallVector<LogicalTensorDesc> inputs,
               size_t nr_outputs) {
                auto [outputs, validated] =
                        OpDef::infer_output_attrs_fallible(*op, inputs);
                mgb_assert(outputs.size() == nr_outputs, "nr_outputs mismatch");
                return outputs;
            }};
    // write an expression node into the subgraph and return its single output
    auto f = [&](auto&& op, auto... args) {
        return builder.write_expr(
                op, Subgraph::vars_t({(Subgraph::var_t)args...}), 1)[0];
    };
    auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0));
    auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM));
    auto sub = Elemwise::make(Elemwise::Mode::SUB);
    auto mul = Elemwise::make(Elemwise::Mode::MUL);
    auto div = Elemwise::make(Elemwise::Mode::TRUE_DIV);
    auto floor_div = Elemwise::make(Elemwise::Mode::FLOOR_DIV);
    auto broadcast = Broadcast::make();
    // write a constant into the subgraph, converting to `dtype` if needed
    auto c = [&](TensorPtr tensor, DType dtype) {
        auto result = builder.write_constant(
                tensor, {TensorLayout{tensor->dtype()}, tensor->comp_node()});
        if (tensor->dtype() != dtype) {
            result = f(TypeCvt::make(dtype), result);
        }
        return result;
    };
    auto ci = [&](megdnn::dt_int32 value) {
        return c(Tensor::make_scalar(DTypeScalar(value), device), dtype::Int32());
    };
    auto cf = [&](megdnn::dt_float32 value) {
        return c(Tensor::make_scalar(DTypeScalar(value), device), dtype);
    };
    auto desc = LogicalTensorDesc{TensorLayout{dtype}, device};
    auto x = builder.write_input(desc);
    auto y_grad = builder.write_input(desc);
    auto save_mean = builder.write_input(desc);
    auto save_invstd = builder.write_input(desc);
    auto weight = builder.write_input(desc);
    auto reserved = builder.write_input(desc);
    MGB_MARK_USED_VAR(reserved);
    // assert x.ndim == 4
    auto input_shape = f(GetVarShape::make(), x);
    auto channels = f(GetVarShape::make(1), x);
    auto reduce_shape = f(Concat::make(0, device), ci(1), channels, ci(1), ci(1));
    auto input_elems = f(prod, input_shape);
    auto reduce_size = f(floor_div, input_elems, channels);
    auto reduce_size_f = f(TypeCvt::make(dtype), reduce_size);
    auto mean = f(broadcast, save_mean, input_shape);
    auto invstd = save_invstd;
    auto norm = f(div, cf(1), reduce_size_f);
    auto output_grad_sum = f(sum, y_grad, reduce_shape);
    auto dot_p = f(sum, f(mul, y_grad, f(sub, x, mean)), reduce_shape);
    auto mean_grad = f(broadcast, f(mul, output_grad_sum, norm), input_shape);
    auto proj_scale = f(
            broadcast, f(mul, f(mul, dot_p, norm), f(mul, invstd, invstd)),
            input_shape);
    auto grad_scale = f(
            mul, f(broadcast, invstd, input_shape),
            f(broadcast, weight, input_shape));
    auto proj = f(mul, f(sub, x, mean), proj_scale);
    auto x_grad = f(mul, f(sub, f(sub, y_grad, proj), mean_grad), grad_scale);
    auto weight_grad = f(mul, dot_p, invstd);
    auto bias_grad = output_grad_sum;
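    // The expressions above encode the standard training-mode batch-norm
    // backward pass; with per-channel mean/invstd and m = reduce_size:
    //   bias_grad   = sum(dy)
    //   weight_grad = sum(dy * (x - mean)) * invstd
    //   x_grad      = (dy - (x - mean) * dot_p/m * invstd^2 - sum(dy)/m)
    //                 * invstd * weight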
    builder.add_outputs({weight_grad, bias_grad, x_grad});
    auto bn_backward = builder.encode();
    return bn_backward;
}

namespace bn {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::BatchNorm>();
    return BatchNorm::make(node->param());
}

cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& bn_opr = def.cast_final_safe<BatchNorm>();
    size_t nr_inp = inputs.size();
    mgb_assert(
            nr_inp == 3 || nr_inp == 5,
            "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
    OperatorNodeConfig config{bn_opr.make_name()};
    if (nr_inp == 3) {
        return opr::BatchNorm::make(
                       inputs[0], inputs[1], inputs[2], bn_opr.param(), config)[0]
                .node()
                ->owner_opr();
    } else {
        return opr::BatchNorm::make(
                       inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
                       bn_opr.param(), config)[0]
                .node()
                ->owner_opr();
    }
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<BatchNorm>();
    size_t nr_inp = inputs.size();
    mgb_assert(
            nr_inp == 3 || nr_inp == 5,
            "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
    // need running mean/variance
    bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING;
    size_t nr_out = need_stat ? 6 : 4;
    SmallVector<LogicalTensorDesc> out_shapes(nr_out);
    auto&& i0 = inputs[0];
    auto&& i1 = inputs[1];
    // [running_mean, running_var,] save_mean, save_var
    for (size_t i = 0; i < nr_out - 2; ++i) {
        out_shapes[i] = {i1.layout, i1.comp_node};
    }
    out_shapes[nr_out - 2] = {
            TensorLayout({0}, dtype::Byte()), i0.comp_node};  // reserve
    out_shapes[nr_out - 1] = {i0.layout, i0.comp_node};       // output
    return {out_shapes, out_shapes[nr_out - 1].layout.ndim != 0};
}

OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();

}  // namespace bn

namespace bn_backward {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::BatchNormBackward>();
    return BatchNormBackward::make(node->param());
}

VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto& op = def.cast_final_safe<BatchNormBackward>();
    cg::SymbolVar x, y_grad, save_mean, save_variance, weight, reserve;
    x = inputs[0];
    y_grad = inputs[1];
    save_mean = inputs[2];
    save_variance = inputs[3];
    weight = inputs[4];
    if (inputs.size() == 6) {
        reserve = inputs[5];
    }
    return opr::BatchNormBackward::make(
                   x, y_grad, save_mean, save_variance, weight, reserve,
                   op.param())[0]
            .node()
            ->owner_opr()
            ->usable_output();
}

EncodedSubgraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    def.cast_final_safe<BatchNormBackward>();
    size_t nr_inputs = 6;
    size_t nr_outputs = 3;
    mgb_assert(inputs.size() == nr_inputs);
    mgb_assert(input_requires_grad.size() == nr_inputs);
    mgb_assert(output_has_grad.size() == nr_outputs);
    auto dtype = inputs[0].layout.dtype;
    auto device = inputs[0].comp_node;
    // the double-backward is derived as the backward graph of the
    // re-expressed first-order backward subgraph
    auto bn_backward = generate_batchnorm_backward_graph(dtype, device);
    auto bn_double_backward = subgraph_detail::make_backward_graph_from_forward(
            bn_backward, inputs, input_requires_grad, output_has_grad);
    return bn_double_backward;
}

OP_TRAIT_REG(BatchNormBackward, BatchNormBackward, opr::BatchNormBackward)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .make_backward_graph(make_backward_graph)
        .fallback();

}  // namespace bn_backward
}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb
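// Illustrative sketch (not part of the MegEngine sources): a plain-array
// NCHW reference of the computation that generate_batchnorm_backward_graph()
// encodes, handy for cross-checking the subgraph against a direct
// implementation. The helper name and signature are hypothetical, and the
// block is compiled out.
#if 0
static void bn_backward_reference(
        const float* x, const float* dy, const float* mean,
        const float* invstd, const float* weight, size_t N, size_t C,
        size_t H, size_t W, float* dx, float* dweight, float* dbias) {
    const size_t hw = H * W;
    const float m = static_cast<float>(N * hw);  // reduce_size in the subgraph
    for (size_t c = 0; c < C; ++c) {
        float sum_dy = 0.f, dot_p = 0.f;
        for (size_t n = 0; n < N; ++n) {
            const size_t base = (n * C + c) * hw;
            for (size_t i = 0; i < hw; ++i) {
                sum_dy += dy[base + i];
                dot_p += dy[base + i] * (x[base + i] - mean[c]);
            }
        }
        dbias[c] = sum_dy;               // bias_grad = output_grad_sum
        dweight[c] = dot_p * invstd[c];  // weight_grad = dot_p * invstd
        const float mean_grad = sum_dy / m;
        const float proj_scale = dot_p / m * invstd[c] * invstd[c];
        const float grad_scale = invstd[c] * weight[c];
        for (size_t n = 0; n < N; ++n) {
            const size_t base = (n * C + c) * hw;
            for (size_t i = 0; i < hw; ++i) {
                const float proj = (x[base + i] - mean[c]) * proj_scale;
                dx[base + i] = (dy[base + i] - proj - mean_grad) * grad_scale;
            }
        }
    }
}
#endif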
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}