#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"

#include "../internal/megdnn_opr_wrapper.inl"

using namespace mgb;
using namespace opr;

namespace mgb {
namespace opr {
namespace intl {
template <>
struct AutoAddWorkspaceNeedLimitGetter<megdnn::BNForward> {
    static constexpr bool val = true;
};

template <>
struct AutoAddWorkspaceNeedLimitGetter<megdnn::BNBackward> {
    static constexpr bool val = true;
};
}  // namespace intl
}  // namespace opr
}  // namespace mgb

MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNormForward);

BatchNormForward::BatchNormForward(
        VarNode* x, VarNode* scale, VarNode* bias, VarNode* mean, VarNode* variance,
        const Param& param, const OperatorNodeConfig& config)
        : Super{x->owner_graph(),
                config,
                "batch_norm",
                {x, scale, bias, mean, variance}} {
    if (owner_graph()->options().no_force_inplace) {
        m_force_inplace = false;
    }

    if (m_force_inplace && param.fwd_mode == Param::FwdMode::TRAINING) {
        auto check_dest = [&](VarNode* dest) {
            auto dest_opr = dest->owner_opr();
            mgb_throw_if(
                    !(dest_opr->same_type<SharedDeviceTensor>() ||
                      dest_opr->same_type<VolatileSharedDeviceTensor>()),
                    GraphError,
                    "mean and variance in training mode BatchNorm must be"
                    "SharedDeviceTensor or VolatileSharedDeviceTensor;"
50
                    "got %s{%s} actually",
                    dest_opr->cname(), dest_opr->dyn_typeinfo()->name);
        };
        check_dest(mean);
        check_dest(variance);
    }

    init_megdnn_opr(*this, param);

    add_input({x, scale, bias, mean, variance});

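    // Output layout (as used below and in scn_do_execute): 0/1 = updated running
    // mean/variance, 2/3 = saved mean/variance for backward, 4 = reserve, 5 = y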
    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);  // reserve
    output(5)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    // running mean/var
    if (param.fwd_mode == Param::FwdMode::INFERENCE) {
        auto mark_empty_var = [&](VarNode* var) {
            var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                    .add_flag(VarNode::Flag::VOLATILE_CONTENT);
        };
        mark_empty_var(output(0));
        mark_empty_var(output(1));
    } else if (m_force_inplace) {
        output(0)->set_fwd_in2out_writable_force(input(3)).add_flag(
                VarNode::Flag::NO_MEM_RECLAIM);

        output(1)->set_fwd_in2out_writable_force(input(4)).add_flag(
                VarNode::Flag::NO_MEM_RECLAIM);
    }
}

BatchNormForward::BatchNormForward(
        VarNode* x, VarNode* scale, VarNode* bias, const Param& param,
        const OperatorNodeConfig& config)
        : Super{x->owner_graph(), config, "batch_norm", {x, scale, bias}} {
    init_megdnn_opr(*this, param);

    add_input({x, scale, bias});
    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);  // reserve
    output(5)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    auto mark_empty_var = [&](VarNode* var) {
        var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
    };
    mark_empty_var(output(0));
    mark_empty_var(output(1));
}

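// Usage sketch (assuming SymbolVars x, scale, bias, mean and variance already
// exist in the graph):
//   auto outs = BatchNormForward::make(x, scale, bias, mean, variance, param);
//   auto y = outs[5];  // outs[0]/outs[1] are the updated running mean/variance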
SymbolVarArray BatchNormForward::make(
        SymbolVar x, SymbolVar scale, SymbolVar bias, SymbolVar mean,
        SymbolVar variance, const Param& param, const OperatorNodeConfig& config) {
    auto&& out = x.node()->owner_graph()
                         ->insert_opr(std::make_unique<BatchNormForward>(
                                 x.node(), scale.node(), bias.node(), mean.node(),
                                 variance.node(), param, config))
                         ->output();
    SymbolVarArray ret(out.size());
    for (size_t i = 0; i < ret.size(); i++) {
        ret[i] = out[i];
    }
    return ret;
}

SymbolVarArray BatchNormForward::make(
        SymbolVar x, SymbolVar scale, SymbolVar bias, const Param& param,
        const OperatorNodeConfig& config) {
    auto&& out = x.node()->owner_graph()
                         ->insert_opr(std::make_unique<BatchNormForward>(
                                 x.node(), scale.node(), bias.node(), param, config))
                         ->output();
    SymbolVarArray ret(out.size());
    for (size_t i = 0; i < ret.size(); i++) {
        ret[i] = out[i];
    }
    return ret;
}

cg::OperatorNodeBase::NodeProp* BatchNormForward::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    if (need_stats() && m_force_inplace) {
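        // in this mode the running mean/var inputs (3/4) are updated in place,
        // so the graph must treat them as force-updated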
        ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
    }
    return ret;
}

void BatchNormForward::scn_do_execute() {
    auto&& x = input(0)->dev_tensor();
    auto&& y = output(5)->dev_tensor();
    if (need_stats()) {
        auto &&o0 = output(0)->dev_tensor(), &&o1 = output(1)->dev_tensor(),
             &&i0 = input(3)->dev_tensor(), &&i1 = input(4)->dev_tensor();
        mgb_assert(o0.raw_ptr() && o1.raw_ptr());  // non-empty tensor
        mgb_assert(
                o0.comp_node() == i0.comp_node() && o1.comp_node() == i1.comp_node() &&
                o0.layout().eq_layout(i0.layout()) &&
                o1.layout().eq_layout(i1.layout()));
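        // seed outputs 0/1 with the current running stats; they are passed to
        // megdnn as mean/variance below and updated in place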
        if (!m_force_inplace) {
            if (o0.raw_ptr() != i0.raw_ptr()) {
                o0.copy_from_fixlayout(i0);
            }
            if (o1.raw_ptr() != i1.raw_ptr()) {
                o1.copy_from_fixlayout(i1);
            }
        } else {
            mgb_assert(o0.raw_ptr() == i0.raw_ptr() && o1.raw_ptr() == i1.raw_ptr());
        }
    }
    mgb_assert(x.layout().eq_layout(y.layout()));
    if (x.layout().is_empty()) {
        return;
    }
    mgb_assert(x.layout().is_contiguous() && y.layout().is_contiguous());
    auto scale = input(1)->dev_tensor().as_megdnn();
    auto bias = input(2)->dev_tensor().as_megdnn();
    megdnn::TensorND mean, variance;
    if (param().fwd_mode == Param::FwdMode::INFERENCE) {
        mean = input(3)->dev_tensor().as_megdnn();
        variance = input(4)->dev_tensor().as_megdnn();
    } else {
        mean = output(0)->dev_tensor().as_megdnn();
        variance = output(1)->dev_tensor().as_megdnn();
    }
    auto save_mean = output(2)->dev_tensor().as_megdnn();
    auto save_variance = output(3)->dev_tensor().as_megdnn();
    auto reserve = output(4)->dev_tensor().as_megdnn();
    auto workspace = intl::get_megdnn_workspace_from_var(output().back());
    megdnn_opr()->exec(
            x.as_megdnn(), scale, bias, mean, variance, save_mean, save_variance,
            reserve, y.as_megdnn(), workspace);
}

void BatchNormForward::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void BatchNormForward::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    mgb_assert(
            inp_shape[0].ndim == 4 && inp_shape[1].ndim == 4 && inp_shape[2].ndim == 4,
            "expect input, scale and bias to be 4 dim tensors, but "
            "got input dim: %zu, scale dim: %zu, bias dim: %zu",
            inp_shape[0].ndim, inp_shape[1].ndim, inp_shape[2].ndim);

    size_t channel_idx;
    if (param().param_dim == Param::ParamDim::DIM_111C) {
        channel_idx = 3;
    } else {
        channel_idx = 1;
    }
    size_t inp_c = inp_shape[0][channel_idx], scale_c = inp_shape[1][channel_idx],
           bias_c = inp_shape[2][channel_idx];
    mgb_assert(
            inp_c == scale_c && inp_c == bias_c,
            "inconsistent channel size, input channel: %zu, scale channel: %zu, bias "
            "channel: %zu",
            inp_c, scale_c, bias_c);

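    // y (output 5) follows the input shape; outputs 0-3 follow the scale shape;
    // output 4 is the reserve buffer whose size in bytes comes from megdnn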
    out_shape[5] = inp_shape[0];
    for (size_t i = 0; i < 4; ++i) {
        out_shape[i] = inp_shape[1];
    }
    if (!need_stats()) {
        out_shape[0] = out_shape[1] = {0};
    }
    if (inp_shape[0].is_empty()) {
        out_shape[4] = {0};
    } else {
        out_shape[4] = {
                megdnn_opr()->get_reserve_in_bytes({inp_shape[0], input(0)->dtype()})};
    }
}

size_t BatchNormForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    if (input_shapes[0].is_empty())
        return 0;
#define in(x) \
    { input_shapes[x], input(x)->dtype() }
#define out(x) \
    { output_shapes[x], output(x)->dtype() }
    return megdnn_opr()->get_workspace_in_bytes(
            in(0), in(1), in(2), out(0), out(1), out(2), out(3), out(4), out(5));
#undef in
#undef out
}

void BatchNormForward::init_output_static_infer_desc() {
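    // the last output is the megdnn workspace; it is excluded from the managed
    // outputs and its shape is inferred separately below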
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::init_output_static_infer_desc();
    this->init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<megdnn::BNForward>::val);
}

void BatchNormForward::init_output_dtype() {
    size_t nr_inp = input().size();
    mgb_assert(input(0)->dtype().category() == input(1)->dtype().category());
    for (size_t i = 2; i < nr_inp; ++i) {
        mgb_assert(input(1)->dtype() == input(i)->dtype());
    }
    output(4)->dtype(dtype::Byte());      // reserve
    output(5)->dtype(input(0)->dtype());  // output
    for (size_t i = 0; i < 4; ++i) {
        output(i)->dtype(input(1)->dtype());
    }
}

void BatchNormForward::mem_plan_fwd_in2out_writable() {
    if (need_stats() && !m_force_inplace) {
        // TODO: testing
        output(0)->set_fwd_in2out_writable(input(3));
        output(1)->set_fwd_in2out_writable(input(4));
    }
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchNormForward) {
    mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx);
    VarNodeArray ret(opr.input().size(), nullptr);
    SymbolVarArray grad;
    switch (opr.param().fwd_mode) {
        case BatchNorm::Param::FwdMode::TRAINING:
            grad = BatchNormBackward::make(
                    opr.input(0), out_grad[5], opr.output(2), opr.output(3),
                    opr.input(1), opr.output(4),  // reserve
                    opr.param());
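            // BatchNormBackward returns (d_scale, d_bias, d_x) (see its shape
            // inference); remap them back to the forward inputs (x, scale, bias)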
            for (size_t i = 0; i < 3; ++i) {
                ret[i] = grad[(i + 2) % 3].node();
            }
            return ret;
        case BatchNorm::Param::FwdMode::INFERENCE:
            auto sqrt_var = PowC::make(
                    (SymbolVar{opr.input(4)} +
                     static_cast<dt_float32>(opr.param().epsilon)),
                    0.5, opr.config());
            auto d_bn_scale_unreduced =
                    SymbolVar{out_grad[5]} *
                    (SymbolVar{opr.input(0)} - SymbolVar{opr.input(3)}) / sqrt_var;
            auto d_bn_scale = Reduce::make(
                    d_bn_scale_unreduced, Reduce::Param::Mode::SUM,
                    GetVarShape::make(opr.input(1)));
            auto d_bn_bias = Reduce::make(
                    out_grad[5], Reduce::Param::Mode::SUM,
                    GetVarShape::make(opr.input(2)));
            auto dx = SymbolVar{out_grad[5]} * SymbolVar{opr.input(1)} / sqrt_var;

            ret[0] = dx.node();
            ret[1] = d_bn_scale.node();
            ret[2] = d_bn_bias.node();
            return ret;
    }
    return ret;
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNormBackward);

BatchNormBackward::BatchNormBackward(
        VarNode* x, VarNode* y_grad, VarNode* save_mean, VarNode* save_variance,
        VarNode* scale, VarNode* reserve, const Param& param,
        const OperatorNodeConfig& config)
        : Super({x->owner_graph(),
                 config,
                 "batch_norm_bwd",
                 {x, y_grad, save_mean, save_variance, scale, reserve}},
                0, true) {
    init_megdnn_opr(*this, param);
    add_input({x, y_grad, save_mean, save_variance, scale, reserve});
}

SymbolVarArray BatchNormBackward::make(
        SymbolVar x, SymbolVar y_grad, SymbolVar save_mean, SymbolVar save_variance,
        SymbolVar scale, SymbolVar reserve, const Param& param,
        const OperatorNodeConfig& config) {
    auto&& out = x.node()->owner_graph()
                         ->insert_opr(std::make_unique<BatchNormBackward>(
                                 x.node(), y_grad.node(), save_mean.node(),
                                 save_variance.node(), scale.node(), reserve.node(),
                                 param, config))
                         ->output();
    SymbolVarArray ret(out.size());
    for (size_t i = 0; i < ret.size(); i++) {
        ret[i] = out[i];
    }
    return ret;
}

void BatchNormBackward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();

    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(4)));
    mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(4)));
    mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(0)));
    this->init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<megdnn::BNBackward>::val);
}

void BatchNormBackward::init_output_dtype() {
    mgb_assert(input(0)->dtype().category() == input(2)->dtype().category());
    mgb_assert(input(0)->dtype() == input(1)->dtype());
    mgb_assert(input(2)->dtype() == input(3)->dtype());
    mgb_assert(input(2)->dtype() == input(4)->dtype());
    output(0)->dtype(input(2)->dtype());
    output(1)->dtype(input(2)->dtype());
    output(2)->dtype(input(0)->dtype());
}

cg::OperatorNodeBase::NodeProp* BatchNormBackward::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(5), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}