diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index f7b7b095361bfb59635896d5925f245c5171f017..46cf27711d37261d0ffaed30e04bac8385b2ff13 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1122,7 +1122,8 @@ def batch_norm( momentum: float = 0.9, eps: float = 1e-5, inplace: bool = True, - compute_mode="default" + compute_mode="default", + param_dim="dim_1c11" ): r"""Applies batch normalization to the input. @@ -1147,16 +1148,23 @@ def batch_norm( if inp.ndim != 4: raise NotImplementedError("batch_norm for ndim != 4") - C = inp.shape[1] + if param_dim == "dim_1c11": + C = inp.shape[1] + pshape = (1, C, 1, 1) + elif param_dim == "dim_111c": + C = inp.shape[3] + pshape = (1, 1, 1, C) + else: + raise ValueError("Invalid param_dim {}".format(param_dim)) def make_full_if_none(x, value): if x is None: (x,) = Const(value, dtype=inp.dtype, device=inp.device)() - shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device) + shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) (result,) = apply(builtin.Broadcast(), x, shape) return result elif x.ndim == 1: - shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device) + shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) (result,) = apply(builtin.Reshape(), x, shape) return result return x @@ -1183,19 +1191,19 @@ def batch_norm( if not training: op = builtin.BatchNorm( - fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11" + fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim=param_dim ) ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] return ret else: op = builtin.BatchNorm( - avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11" + avg_factor=1 - momentum, epsilon=eps, param_dim=param_dim ) if has_mean or has_var: running_mean = make_full_if_none(running_mean, 0) running_var = make_full_if_none(running_var, 1) - 
new_mean, new_var, _, _, inp = apply( + new_mean, new_var, *_, inp = apply( op, inp, weight, bias, running_mean, running_var ) if not has_mean: @@ -1213,7 +1221,7 @@ def batch_norm( else: return inp, new_mean, new_var else: - (_, _, inp,) = apply(op, inp, weight, bias) + inp = apply(op, inp, weight, bias)[-1] return inp diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index b7100339d0e2edb07b05325de62b51afb7cb432e..a4479077f3505631b818f4cc2245af78e95858fa 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -27,6 +27,7 @@ class _BatchNorm(Module): track_running_stats=True, freeze=False, compute_mode="default", + param_dim="dim_1c11", **kwargs ): super(_BatchNorm, self).__init__(**kwargs) @@ -38,6 +39,7 @@ class _BatchNorm(Module): self._track_running_stats_saved = track_running_stats self.freeze = freeze self.compute_mode = compute_mode + self.param_dim = param_dim if self.freeze: assert ( self._track_running_stats_saved @@ -125,6 +127,7 @@ class _BatchNorm(Module): momentum=exponential_average_factor, eps=self.eps, compute_mode=self.compute_mode, + param_dim=self.param_dim, ) if _ndims != 4: diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index e3e7986debd3b93e9d42dc476e3326d17e29c6b3..bd73dbb81f953b10178ee31bce09daac7e5fde2a 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -811,7 +811,8 @@ def test_batch_conv_bias(): run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True) -def test_conv2d_io16c32(): +def test_conv2d_autocast(): + """check amp's result is equal to manually converted result""" amp.enabled = True inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32) weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32) @@ -918,11 +919,14 @@ def test_layer_norm(): 
assert abs(outvar.mean()) < 1e-7 -def test_batchnorm2d_io16c32(): +def test_batchnorm2d_autocast(): + """check amp's result is equal to manually converted result""" amp.enabled = True - inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32) - weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32) - bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32) + tshape = (1, 224, 224, 3) + pshape = (1, 1, 1, 3) + inp = tensor(np.random.randn(*tshape), dtype=np.float32) + weight = tensor(np.ones(pshape, dtype=np.float32)) + bias = tensor(np.zeros(pshape, dtype=np.float32)) out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False) diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp index cad0b7dad069719ca078aa0ceb95f3f580f54a47..117068d1a668cfbc5118af9a9cb4df980695ac4f 100644 --- a/imperative/src/impl/ops/batch_norm.cpp +++ b/imperative/src/impl/ops/batch_norm.cpp @@ -51,16 +51,16 @@ std::tuple, bool> infer_output_attrs_fallible( "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); // need running mean/variance bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING; - size_t nr_out = need_stat? 5 : 3; + size_t nr_out = need_stat? 
6 : 4; SmallVector out_shapes(nr_out); auto&& i0 = inputs[0]; auto&& i1 = inputs[1]; // [running_mean, running_var,] save_mean, save_var - for (size_t i = 0; i < nr_out-1; ++ i) { + for (size_t i = 0; i < nr_out-2; ++ i) { out_shapes[i] = {i1.layout, i1.comp_node}; } - // output tensor - out_shapes[nr_out-1] = {i0.layout, i0.comp_node}; + out_shapes[nr_out-2] = {TensorLayout({0}, dtype::Byte()), i0.comp_node}; // reserve + out_shapes[nr_out-1] = {i0.layout, i0.comp_node}; // output return {out_shapes, out_shapes[nr_out-1].layout.ndim != 0}; } diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index a0c7aaee41c7942e5a28e44b8f18b06fd0783b50..f5d82ba5d4bae08c22e5875f461bf13e6fa5f5f6 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -689,7 +689,8 @@ ProxyGraph::make_backward_graph( output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()}); } auto output_grads = make_input_place_holders(output_descs); - mgb_assert(output_grads.size() == output_has_grad.size()); + mgb_assert(output_grads.size() == output_has_grad.size(), "%zu vs %zu", + output_grads.size(), output_has_grad.size()); bool any_input_has_grad = false; for (size_t i = 0; i < output_grads.size(); ++ i) { if (!output_has_grad[i]) { diff --git a/imperative/src/test/backward_graph.cpp b/imperative/src/test/backward_graph.cpp index 0dddc73b113b7e329c095afb7f2911529afd96e2..90e53ce7a77f0e5895d593e569859412ce60f873 100644 --- a/imperative/src/test/backward_graph.cpp +++ b/imperative/src/test/backward_graph.cpp @@ -207,7 +207,7 @@ TEST(TestImperative, BatchNormGrad) { attr.param.write_pod(param); OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, {true, true, true, false, false}, - {false, false, false, false, true}); + {false, false, false, false, false, true}); } { auto op = OprAttr::make("BatchNorm"); @@ -216,7 +216,7 @@ TEST(TestImperative, BatchNormGrad) { param.fwd_mode = Param::FwdMode::TRAINING; 
attr.param.write_pod(param); OpDef::make_backward_graph(attr, {inp, stat, stat}, {true, true, true}, - {false, false, true}); + {false, false, false, true}); } } diff --git a/imperative/src/test/helper.cpp b/imperative/src/test/helper.cpp index 685fd05905ce65f1c050b0d350ac39df4648019e..cd5b86045662f325b0506e4c1f7c44f77f287acf 100644 --- a/imperative/src/test/helper.cpp +++ b/imperative/src/test/helper.cpp @@ -99,7 +99,7 @@ UNUSED void print(const char* s) { OprChecker::OprChecker(std::shared_ptr opdef) : m_op(opdef) {} -void OprChecker::run(std::vector inp_keys) { +void OprChecker::run(std::vector inp_keys, std::set bypass) { HostTensorGenerator<> gen; size_t nr_inps = inp_keys.size(); SmallVector host_inp(nr_inps); @@ -151,6 +151,8 @@ void OprChecker::run(std::vector inp_keys) { func->execute().wait(); // run last because it may contain inplace operations for(size_t i = 0; i < nr_oups; ++ i) { + if (bypass.find(i) != bypass.end()) + continue; MGB_ASSERT_TENSOR_EQ(host_sym_oup[i], host_imp_oup[i]); } } diff --git a/imperative/src/test/helper.h b/imperative/src/test/helper.h index 22dbbb9e10c4134dd99626c5b9e19279ebb44e54..f8cd19997d84e273a3221f15f3061dc829c2be50 100644 --- a/imperative/src/test/helper.h +++ b/imperative/src/test/helper.h @@ -23,7 +23,7 @@ class OprChecker { public: using InputSpec = std::variant; OprChecker(std::shared_ptr opdef); - void run(std::vector inp_shapes); + void run(std::vector inp_shapes, std::set bypass={}); private: std::shared_ptr m_op; }; diff --git a/imperative/src/test/imperative.cpp b/imperative/src/test/imperative.cpp index bcf081d4017d669976fdbda1dd63b5367ab19bc2..329370f3c2bdf18458791dd4016f9063c5aada75 100644 --- a/imperative/src/test/imperative.cpp +++ b/imperative/src/test/imperative.cpp @@ -73,7 +73,7 @@ TEST(TestImperative, BatchNorm) { TensorShape{1, C, 1, 1}, TensorShape{1, C, 1, 1}, TensorShape{1, C, 1, 1} - }); + }, {4}); } TEST(TestImperative, Concat) { diff --git a/src/gopt/impl/inference.cpp 
b/src/gopt/impl/inference.cpp index b3d3a78ed22702eb826c1629dca1931d4b51919a..e222aea8bb8466eb462648004a5afde78351de76 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1766,7 +1766,7 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { x.dtype().name(), res.dtype().name()); } rewriter.replace_var( - opr->output(4), res.node(), + opr->output(5), res.node(), mgb_cstr_log( "replace batch_norm(x, scale, bias, mean, " "varience) " diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 1977b5ac53415c5fb08fd5d129ea4c3dca1ee9e8..3b8f59cf07139de670ebf10be98ad90cc8f50c9b 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -35,6 +35,7 @@ #include "megdnn/tensor_format.h" #include +#include #if MGB_CUDA #include @@ -1665,44 +1666,49 @@ TEST(TestGoptInference, concatbypass) { TEST(TestGoptInference, ConvertBatchNormPass) { auto cn = CompNode::load("cpu0"); - HostTensorGenerator<> gen(0, 1, 0); - auto graph = ComputingGraph::make(); - graph->options().graph_opt_level = 0; - auto mkvar = [&](const char* name, const TensorShape& shp) { - return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); - }; - auto mkcvar = [&](const char* name, const TensorShape& shp) { - return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) - .rename(name); - }; - using Param = opr::BatchNorm::Param; - Param param(Param::ParamDim::DIM_1C11, Param::FwdMode::INFERENCE); - TensorShape shp = {1, 3, 1, 1}; - auto x = mkvar("x", {2, 3, 16, 24}), scale = mkcvar("scale", shp), - bias = mkcvar("bias", shp), mean = mkcvar("mean", shp); - auto host_variance = gen(shp, cn); - for (size_t i = 0; i < shp.total_nr_elems(); ++i) { - host_variance->ptr()[i] = - std::abs(host_variance->ptr()[i]); - } - auto variance = opr::SharedDeviceTensor::make(*graph, *host_variance) - .rename("variance"); - auto y = opr::BatchNorm::make(x, scale, bias, mean, variance, param)[4]; - SymbolVar y_opt; - 
unpack_vector(gopt::optimize_for_inference( - {y}, gopt::OptimizeForInferenceOptions{}), - y_opt); - ASSERT_EQ(0u, find_opr_num(y_opt)); - graph->compile({{y_opt, {}}}) - ->to_json() - ->writeto_fpath( - output_file("TestGoptInference.ConvertBatchNormPass.json")); + std::vector shps = {{1, 3, 1, 1}, {1, 1, 1, 3}}, + xshps = {{2, 3, 16, 24}, {2, 16, 24, 3}}; + for (int t = 0; t < 2; t++) { + HostTensorGenerator<> gen(0, 1, 0); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp) { + return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name); + }; + using Param = opr::BatchNorm::Param; + Param::ParamDim param_dim = t == 0 ? Param::ParamDim::DIM_1C11 : Param::ParamDim::DIM_111C; + Param param(param_dim, Param::FwdMode::INFERENCE); + TensorShape shp = shps[t], xshp = xshps[t]; + auto x = mkvar("x", xshp), scale = mkcvar("scale", shp), + bias = mkcvar("bias", shp), mean = mkcvar("mean", shp); + auto host_variance = gen(shp, cn); + for (size_t i = 0; i < shp.total_nr_elems(); ++i) { + host_variance->ptr()[i] = + std::abs(host_variance->ptr()[i]); + } + auto variance = opr::SharedDeviceTensor::make(*graph, *host_variance) + .rename("variance"); + auto y = opr::BatchNorm::make(x, scale, bias, mean, variance, param)[5]; + SymbolVar y_opt; + unpack_vector(gopt::optimize_for_inference( + {y}, gopt::OptimizeForInferenceOptions{}), + y_opt); + ASSERT_EQ(0u, find_opr_num(y_opt)); + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath( + output_file("TestGoptInference.ConvertBatchNormPass.json")); - HostTensorND host_y, host_y_opt; - auto func = graph->compile({make_callback_copy(y, host_y), - make_callback_copy(y_opt, host_y_opt)}); - func->execute(); - MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5); + HostTensorND host_y, 
host_y_opt; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5); + } } TEST(TestGoptInference, ConvBiasNonlinearityFusePass) { diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp index 56363f0fbfcd9169dc5bd4627ee05b36ef95b7ed..72e70ec45093bb72b7afb20c90d9ce7998e61ded 100644 --- a/src/opr/impl/dnn/batch_norm.cpp +++ b/src/opr/impl/dnn/batch_norm.cpp @@ -62,10 +62,12 @@ BatchNormForward::BatchNormForward(VarNode *x, } init_megdnn_opr(*this, param); - output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); add_input({x, scale, bias, mean, variance}); + output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // reserve + output(5)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + // running mean/var if (param.fwd_mode == Param::FwdMode::INFERENCE) { auto mark_empty_var = [&](VarNode *var) { var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) @@ -92,9 +94,10 @@ BatchNormForward::BatchNormForward(VarNode *x, {x, scale, bias}} { init_megdnn_opr(*this, param); - output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); add_input({x, scale, bias}); + output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); // reserve + output(5)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); auto mark_empty_var = [&](VarNode *var) { var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) .add_flag(VarNode::Flag::VOLATILE_CONTENT); @@ -151,7 +154,7 @@ BatchNormForward::do_make_node_prop() const { void BatchNormForward::scn_do_execute() { auto &&x = input(0)->dev_tensor(); - auto &&y = output(4)->dev_tensor(); + auto &&y = output(5)->dev_tensor(); if (need_stats()) { auto &&o0 = output(0)->dev_tensor(), &&o1 = output(1)->dev_tensor(), @@ -192,9 +195,10 @@ void BatchNormForward::scn_do_execute() { } auto save_mean = output(2)->dev_tensor().as_megdnn(); auto save_variance = output(3)->dev_tensor().as_megdnn(); + auto reserve = output(4)->dev_tensor().as_megdnn(); auto 
workspace = intl::get_megdnn_workspace_from_var(output().back()); megdnn_opr()->exec(x.as_megdnn(), scale, bias, mean, variance, - save_mean, save_variance, y.as_megdnn(), workspace); + save_mean, save_variance, reserve, y.as_megdnn(), workspace); } void BatchNormForward::add_input_layout_constraint() { @@ -208,18 +212,25 @@ void BatchNormForward::get_output_var_shape( "expect input, scale and bias to be 4 dim tensor, but " "got input dim: %zu, scale dim: %zu, bias dim: %zu", inp_shape[0].ndim, inp_shape[1].ndim, inp_shape[2].ndim); - - size_t inp_c = inp_shape[0][1], - scale_c = inp_shape[1][1], - bias_c = inp_shape[2][1]; + + size_t channel_idx; + if (param().param_dim == Param::ParamDim::DIM_111C) { + channel_idx = 3; + } else { + channel_idx = 1; + } + size_t inp_c = inp_shape[0][channel_idx], + scale_c = inp_shape[1][channel_idx], + bias_c = inp_shape[2][channel_idx]; mgb_assert(inp_c == scale_c && inp_c == bias_c, "inconsistent channel size, input chennel: %zu, scale channel: %zu, bias channel: %zu", inp_c, scale_c, bias_c); - out_shape[4] = inp_shape[0]; + out_shape[5] = inp_shape[0]; for (size_t i = 0; i < 4; ++ i) { out_shape[i] = inp_shape[1]; } + out_shape[4] = {megdnn_opr()->get_reserve_in_bytes({inp_shape[0], input(0)->dtype()})}; if (!need_stats()) { out_shape[0] = out_shape[1] = {0}; } @@ -231,7 +242,7 @@ size_t BatchNormForward::get_workspace_size_bytes( #define in(x) {input_shapes[x], input(x)->dtype()} #define out(x) {output_shapes[x], output(x)->dtype()} return megdnn_opr()->get_workspace_in_bytes( - in(0), in(1), in(2), out(0), out(1), out(2), out(3), out(4)); + in(0), in(1), in(2), out(0), out(1), out(2), out(3), out(4), out(5)); #undef in #undef out } @@ -249,7 +260,8 @@ void BatchNormForward::init_output_dtype() { for (size_t i = 2; i < nr_inp; ++ i) { mgb_assert(input(1)->dtype() == input(i)->dtype()); } - output(4)->dtype(input(0)->dtype()); + output(4)->dtype(dtype::Byte()); // reserve + output(5)->dtype(input(0)->dtype()); // output for 
(size_t i = 0; i < 4; ++ i) { output(i)->dtype(input(1)->dtype()); } @@ -271,9 +283,10 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) { switch (opr.param().fwd_mode) { case BatchNorm::Param::FwdMode::TRAINING: grad = BatchNormBackward::make( - opr.input(0), out_grad[4], + opr.input(0), out_grad[5], opr.output(2), opr.output(3), - opr.input(1), opr.param()); + opr.input(1), opr.output(4), // reserve + opr.param()); for (size_t i = 0; i < 3; ++ i) { ret[i] = grad[(i + 2) % 3].node(); } @@ -281,13 +294,13 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) { case BatchNorm::Param::FwdMode::INFERENCE: auto sqrt_var = PowC::make((SymbolVar{opr.input(4)} + static_cast(opr.param().epsilon)), 0.5, opr.config()); - auto d_bn_scale_unreduced = SymbolVar{out_grad[4]} * + auto d_bn_scale_unreduced = SymbolVar{out_grad[5]} * (SymbolVar{opr.input(0)} - SymbolVar{opr.input(3)}) / sqrt_var; auto d_bn_scale = Reduce::make(d_bn_scale_unreduced, Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(1))); - auto d_bn_bias = Reduce::make(out_grad[4], + auto d_bn_bias = Reduce::make(out_grad[5], Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(2))); - auto dx = SymbolVar{out_grad[4]} * SymbolVar{opr.input(1)} / sqrt_var; + auto dx = SymbolVar{out_grad[5]} * SymbolVar{opr.input(1)} / sqrt_var; ret[0] = dx.node(); ret[1] = d_bn_scale.node(); @@ -302,26 +315,26 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNormBackward); BatchNormBackward::BatchNormBackward(VarNode *x, VarNode *y_grad, VarNode *save_mean, - VarNode* save_variance, VarNode *scale, + VarNode* save_variance, VarNode *scale, VarNode *reserve, const Param &param, const OperatorNodeConfig &config): Super({x->owner_graph(), config, "batch_norm_bwd", - {x, y_grad, save_mean, save_variance, scale}}, + {x, y_grad, save_mean, save_variance, scale, reserve}}, 0, true) { init_megdnn_opr(*this, param); - add_input({x, y_grad, save_mean, save_variance, scale}); + add_input({x, y_grad, save_mean, save_variance, scale, reserve}); } SymbolVarArray 
BatchNormBackward::make(SymbolVar x, SymbolVar y_grad, SymbolVar save_mean, - SymbolVar save_variance, SymbolVar scale, + SymbolVar save_variance, SymbolVar scale, SymbolVar reserve, const Param &param, const OperatorNodeConfig &config) { auto&& out = x.node() ->owner_graph() ->insert_opr(std::make_unique( x.node(), y_grad.node(), save_mean.node(), - save_variance.node(), scale.node(), param, config)) + save_variance.node(), scale.node(), reserve.node(), param, config)) ->output(); SymbolVarArray ret(out.size()); for (size_t i = 0; i < ret.size(); i++) { @@ -355,4 +368,11 @@ void BatchNormBackward::init_output_dtype() { output(2)->dtype(input(0)->dtype()); } +cg::OperatorNodeBase::NodeProp* +BatchNormBackward::do_make_node_prop() const { + auto ret = Super::do_make_node_prop(); + ret->add_dep_type_existing_var(input(5), + NodeProp::DepType::VALUE_ALLOW_EMPTY); + return ret; +} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index bdaa9b06789ebd05114a23bcc5dd0e4d316db41c..c0fb6b645839808851670ad7fe8094726c33a2f2 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -391,14 +391,14 @@ struct OprMaker { }; template <> -struct OprMaker { +struct OprMaker { using Param = opr::BatchNormBackward::Param; static cg::OperatorNodeBase* make(const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, const OperatorNodeConfig& config) { MGB_MARK_USED_VAR(graph); - return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], param, + return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0] .node() ->owner_opr(); @@ -576,7 +576,7 @@ using ConvBiasForwardV4 = ConvBiasForward; MGB_SEREG_OPR(ConvBiasForwardV4, 0); MGB_SEREG_OPR(BatchNorm, 0); -MGB_SEREG_OPR(BatchNormBackward, 5); +MGB_SEREG_OPR(BatchNormBackward, 6); using LocalShareForwardV1 = LocalShareForward; using LocalShareBackwardDataV1 = LocalShareBackwardData; diff 
--git a/src/opr/impl/internal/megdnn_opr_wrapper.inl b/src/opr/impl/internal/megdnn_opr_wrapper.inl index 3ad5e83e81fc387289ce1ed5acb48b886be6a1c2..f44802834172d9e577211d9e4513e1d5192b3dbb 100644 --- a/src/opr/impl/internal/megdnn_opr_wrapper.inl +++ b/src/opr/impl/internal/megdnn_opr_wrapper.inl @@ -183,6 +183,10 @@ namespace { #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2) #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" +#define _NR_INPUTS 6 +#define _NR_OUTPUTS 3 +#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2) +#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" } // anonymous namespace /* ======================= MegDNNOprWrapperFwd ======================= */ diff --git a/src/opr/include/megbrain/opr/dnn/batch_norm.h b/src/opr/include/megbrain/opr/dnn/batch_norm.h index 774582dea7c9bc903442b728b8859a81a1ddeb10..0e264b52abdfcc1f6b542e3b2cd73bfc3045c86c 100644 --- a/src/opr/include/megbrain/opr/dnn/batch_norm.h +++ b/src/opr/include/megbrain/opr/dnn/batch_norm.h @@ -24,7 +24,7 @@ namespace opr { /* input: * x, scale, bias, [running_mean, running_variance] * output: - * running_mean, running_variance, save_mean, save_inv_variance, y + * running_mean, running_variance, save_mean, save_inv_variance, reserve, y * * All params have the same definition with cudnn batch normalization. * @@ -35,6 +35,9 @@ namespace opr { * * For statistic(mean and variance) update: * running_mean = (1 - moving_average) * running_mean + moving_average * new_mean + * + * Output reserve is used for cudnnBatchNormalizationForwardTrainingEx, and should + * be preserved for backward. 
*/ MGB_DEFINE_OPR_CLASS(BatchNormForward, cg::OutshapePureByInshapeOpr< @@ -86,7 +89,7 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward, using BatchNorm = BatchNormForward; /* input: - * x, y_grad, save_mean, save_inv_variance, scale + * x, y_grad, save_mean, save_inv_variance, scale, reserve * output: * scale_grad, bias_grad, x_grad */ @@ -97,15 +100,17 @@ MGB_DEFINE_OPR_CLASS(BatchNormBackward, public: BatchNormBackward(VarNode *x, VarNode *y_grad, VarNode *save_mean, VarNode *save_variance, - VarNode *scale, + VarNode *scale, VarNode *reserve, const Param &param, const OperatorNodeConfig &config); static SymbolVarArray make(SymbolVar x, SymbolVar y_grad, SymbolVar save_mean, SymbolVar save_variance, SymbolVar scale, + SymbolVar reserve, const Param &param = {}, const OperatorNodeConfig &config = {}); private: + NodeProp* do_make_node_prop() const override; void init_output_static_infer_desc() override; void init_output_dtype() override; }; diff --git a/src/opr/test/dnn/batch_norm.cpp b/src/opr/test/dnn/batch_norm.cpp index e3a28c35f22fbfabc2d46cbcbb76001907f968ce..1fb6a2a5749d47e3d8955617fba9a2b80d3f07c9 100644 --- a/src/opr/test/dnn/batch_norm.cpp +++ b/src/opr/test/dnn/batch_norm.cpp @@ -95,13 +95,13 @@ SymbolVarArray batch_norm(const SymbolVarArray& inputs, const Param &param) { SymbolVarArray ret; if (inputs.size() == 3) { ret = opr::BatchNorm::make(inputs[0], inputs[1], inputs[2], param); - return {ret[4], ret[2], ret[3]}; + return {ret[5], ret[2], ret[3]}; } else { mgb_assert(inputs.size() == 5); ret = opr::BatchNorm::make(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], param); - return {ret[4], ret[0], ret[1]}; + return {ret[5], ret[0], ret[1]}; } }