diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 49660fc0923a096629e0ac80f2bda1b676318f84..50ffef72baa1c5f210fd6e92de05d24a39ac86b4 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -88,6 +88,7 @@ paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'poo
 paddle.fluid.layers.adaptive_pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
 paddle.fluid.layers.adaptive_pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
 paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False))
+paddle.fluid.layers.data_norm ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'use_mkldnn', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, False, None, None, None, False))
 paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
 paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d5bc25d19cba4de6f059612e3e8c4a65b2edd0f9
--- /dev/null
+++ b/paddle/fluid/operators/data_norm_op.cc
@@ -0,0 +1,445 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/data_norm_op.h"
+#include <memory>
+#include <string>
+#include "paddle/fluid/framework/data_layout.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using DataLayout = framework::DataLayout;
+
+// Eigen views over raw tensor buffers: 2-D arrays for the data and 1-D
+// arrays for the per-channel statistics.
+template <typename T>
+using EigenArrayMap =
+    Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
+template <typename T>
+using ConstEigenArrayMap =
+    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
+template <typename T>
+using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
+template <typename T>
+using ConstEigenVectorArrayMap =
+    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
+
+// Forward operator definition: shape inference and kernel selection.
+class DataNormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  // Checks that every input/output is present and that the three statistic
+  // tensors are 1-D with one entry per channel of X.
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of DataNormOp is missing.");
+    PADDLE_ENFORCE(ctx->HasInput("BatchSize"),
+                   "Input(BatchSize) of DataNormOp is missing.");
+    PADDLE_ENFORCE(ctx->HasInput("BatchSum"),
+                   "Input(BatchSum) of DataNormOp is missing.");
+    PADDLE_ENFORCE(ctx->HasInput("BatchSquareSum"),
+                   "Input(BatchSquareSum) of DataNormOp is missing.");
+    PADDLE_ENFORCE(ctx->HasOutput("Means"),
+                   "Output(Means) of DataNormOp is missing.");
+    PADDLE_ENFORCE(ctx->HasOutput("Scales"),
+                   "Output(Scales) of DataNormOp is missing.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) of DataNormOp is missing.");
+
+    const auto x_dims = ctx->GetInputDim("X");
+    const DataLayout data_layout = framework::StringToDataLayout(
+        ctx->Attrs().Get<std::string>("data_layout"));
+
+    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
+                   "Input X must have 2 to 5 dimensions.");
+
+    const int64_t C =
+        (data_layout == DataLayout::kNCHW ? x_dims[1]
+                                          : x_dims[x_dims.size() - 1]);
+
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C);
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C);
+
+    ctx->SetOutputDim("Y", x_dims);
+    ctx->SetOutputDim("Means", {C});
+    ctx->SetOutputDim("Scales", {C});
+    ctx->ShareLoD("X", "Y");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type = ctx.Input<Tensor>("X")->type();
+    // The BatchSize, BatchSum and BatchSquareSum tensors must be float
+    // (for float or float16 input) or double (for double input).
+    auto dn_param_type = framework::proto::VarType::FP32;
+    if (input_data_type == framework::proto::VarType::FP64) {
+      dn_param_type = framework::proto::VarType::FP64;
+    }
+    PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input<Tensor>("BatchSize")->type(),
+                      "BatchSize input should be of float type");
+    PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input<Tensor>("BatchSum")->type(),
+                      "BatchSum input should be of float type");
+    PADDLE_ENFORCE_EQ(dn_param_type,
+                      ctx.Input<Tensor>("BatchSquareSum")->type(),
+                      "BatchSquareSum input should be of float type");
+
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+#ifdef PADDLE_WITH_MKLDNN
+    if (library == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library = framework::LibraryType::kMKLDNN;
+      layout = framework::DataLayout::kMKLDNN;
+    }
+#endif
+
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
+                                   library);
+  }
+};
+
+// Declares the attributes, inputs and outputs of data_norm.
+class DataNormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    // AddAttr<bool>("is_test", "").SetDefault(false);
+    AddAttr<float>("epsilon",
+                   "(float, default 1e-4) Multiplied by the batch size and "
+                   "added to the gradient of BatchSquareSum.")
+        .SetDefault(1e-4)
+        .AddCustomChecker([](const float &epsilon) {
+          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
+                         "'epsilon' should be between 0.0 and 0.001.");
+        });
+    AddAttr<std::string>("data_layout",
+                         "(string, default NCHW) NCHW or NHWC. Selects "
+                         "which dimension of X holds the channels.")
+        .SetDefault("NCHW");
+    AddInput("X", "The input tensor");
+    AddInput("BatchSize",
+             "BatchSize is a 1-dimensional tensor of size C "
+             "that divides BatchSum to form the per-channel mean");
+    AddInput("BatchSum",
+             "BatchSum is a 1-dimensional tensor of size C "
+             "that is divided by BatchSize to form the per-channel mean");
+    AddInput("BatchSquareSum",
+             "BatchSquareSum is a 1-dimensional tensor of size C "
+             "that divides BatchSize to form the squared per-channel scale");
+    AddOutput("Y", "result after normalization");
+    AddOutput("Means",
+              "Per-channel mean, BatchSum / BatchSize, "
+              "reused by the gradient kernel")
+        .AsIntermediate();
+    AddOutput("Scales",
+              "Per-channel scale, sqrt(BatchSize / BatchSquareSum), "
+              "reused by the gradient kernel")
+        .AsIntermediate();
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddComment(R"DOC(
+Data Normalization.
+
+Normalizes every channel of X with the accumulated statistics:
+
+    Means  = BatchSum / BatchSize
+    Scales = sqrt(BatchSize / BatchSquareSum)
+    Y      = (X - Means) * Scales
+
+The data_layout attribute (NCHW or NHWC) selects the channel dimension.
+The CPU kernel only accepts 2-dimensional input.
+
+)DOC");
+  }
+};
+
+// CPU forward kernel.
+template <typename T>
+class DataNormKernel<platform::CPUDeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    // const bool is_test = ctx.Attr<bool>("is_test");
+    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
+    const DataLayout data_layout =
+        framework::StringToDataLayout(data_layout_str);
+
+    const auto *x = ctx.Input<Tensor>("X");
+    const auto &x_dims = x->dims();
+    PADDLE_ENFORCE(x_dims.size() == 2, "The Input dim size should be 2");
+    const int N = x_dims[0];
+    const int C =
+        (data_layout == DataLayout::kNCHW ? x_dims[1]
+                                          : x_dims[x_dims.size() - 1]);
+    auto *y = ctx.Output<Tensor>("Y");
+    auto *mean_out = ctx.Output<Tensor>("Means");
+    auto *scales = ctx.Output<Tensor>("Scales");
+
+    // alloc memory
+    y->mutable_data<T>(ctx.GetPlace());
+
+    ConstEigenVectorArrayMap<T> b_size_arr(
+        ctx.Input<Tensor>("BatchSize")->data<T>(), C);
+    ConstEigenVectorArrayMap<T> b_sum_arr(
+        ctx.Input<Tensor>("BatchSum")->data<T>(), C);
+    ConstEigenVectorArrayMap<T> b_square_sum_arr(
+        ctx.Input<Tensor>("BatchSquareSum")->data<T>(), C);
+    EigenVectorArrayMap<T> means_arr(
+        mean_out->mutable_data<T>(ctx.GetPlace()), C);
+    EigenVectorArrayMap<T> scales_arr(
+        scales->mutable_data<T>(ctx.GetPlace()), C);
+    means_arr = b_sum_arr / b_size_arr;
+    scales_arr = (b_size_arr / b_square_sum_arr).sqrt();
+
+    switch (data_layout) {
+      case DataLayout::kNCHW:  // because it's two dimensions, so make no
+                               // difference
+      case DataLayout::kNHWC: {
+        // X is viewed as a C x N array; the statistics apply to each column.
+        EigenArrayMap<T>(y->mutable_data<T>(ctx.GetPlace()), C, N) =
+            (ConstEigenArrayMap<T>(x->data<T>(), C, N).colwise() - means_arr)
+                .colwise() *
+            scales_arr;
+        break;
+      }
+      default:
+        PADDLE_THROW("Unknown storage order: %s", data_layout_str);
+    }
+  }
+};
+
+// Gradient operator definition: shape inference and kernel selection.
+class DataNormGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    // check input
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of DataNormGradOp is missing.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+                   "Input(Y@GRAD) of DataNormGradOp is missing.");
+    PADDLE_ENFORCE(ctx->HasInput("BatchSize"),
+                   "Input(BatchSize) of DataNormGradOp is missing.");
+    PADDLE_ENFORCE(ctx->HasInput("BatchSum"),
+                   "Input(BatchSum) of DataNormGradOp is missing.");
+    PADDLE_ENFORCE(ctx->HasInput("BatchSquareSum"),
+                   "Input(BatchSquareSum) of DataNormGradOp is missing.");
+    PADDLE_ENFORCE(ctx->HasInput("Means"),
+                   "Input(Means) of DataNormGradOp is missing.");
+    PADDLE_ENFORCE(ctx->HasInput("Scales"),
+                   "Input(Scales) of DataNormGradOp is missing.");
+
+    // check output
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) of DataNormGradOp is missing.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSize")),
+                   "Output(BatchSize@GRAD) of DataNormGradOp is missing.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSum")),
+                   "Output(BatchSum@GRAD) of DataNormGradOp is missing.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSquareSum")),
+                   "Output(BatchSquareSum@GRAD) of DataNormGradOp is missing.");
+
+    const auto x_dims = ctx->GetInputDim("X");
+    const DataLayout data_layout = framework::StringToDataLayout(
+        ctx->Attrs().Get<std::string>("data_layout"));
+    // int64_t, as in DataNormOp::InferShape, avoids narrowing the dimension.
+    const int64_t C =
+        (data_layout == DataLayout::kNCHW ? x_dims[1]
+                                          : x_dims[x_dims.size() - 1]);
+
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->SetOutputDim(framework::GradVarName("BatchSize"), {C});
+    ctx->SetOutputDim(framework::GradVarName("BatchSum"), {C});
+    ctx->SetOutputDim(framework::GradVarName("BatchSquareSum"), {C});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    const auto *var = ctx.InputVar(framework::GradVarName("Y"));
+    if (var == nullptr) {
+      PADDLE_THROW("can't find Y@GRAD");
+    }
+    const Tensor *t = nullptr;
+    if (var->IsType<Tensor>()) {
+      t = &var->Get<Tensor>();
+    } else if (var->IsType<LoDTensor>()) {
+      t = &var->Get<LoDTensor>();
+    }
+    if (t == nullptr) {
+      PADDLE_THROW("can't find Y@GRAD");
+    }
+
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (library == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library = framework::LibraryType::kMKLDNN;
+      layout = framework::DataLayout::kMKLDNN;
+    }
+#endif
+
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.GetPlace(), layout, library);
+  }
+};
+
+// CPU gradient kernel.
+template <typename T>
+class DataNormGradKernel<platform::CPUDeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *x = ctx.Input<Tensor>("X");
+    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const auto *batch_size = ctx.Input<Tensor>("BatchSize");
+    const auto *batch_sum = ctx.Input<Tensor>("BatchSum");
+    const auto *batch_square_sum = ctx.Input<Tensor>("BatchSquareSum");
+    const auto *scales = ctx.Input<Tensor>("Scales");
+    const auto *means = ctx.Input<Tensor>("Means");
+
+    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
+    const DataLayout data_layout =
+        framework::StringToDataLayout(data_layout_str);
+
+    // Get the size for each dimension.
+    // NCHW [batch_size, in_channels, in_height, in_width]
+    const auto &x_dims = x->dims();
+    PADDLE_ENFORCE(x_dims.size() == 2, "The Input dim size should be 2");
+    const int N = x_dims[0];
+    const int C =
+        (data_layout == DataLayout::kNCHW ? x_dims[1]
+                                          : x_dims[x_dims.size() - 1]);
+
+    // init output
+    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *d_batch_size =
+        ctx.Output<Tensor>(framework::GradVarName("BatchSize"));
+    auto *d_batch_sum = ctx.Output<Tensor>(framework::GradVarName("BatchSum"));
+    auto *d_batch_square_sum =
+        ctx.Output<Tensor>(framework::GradVarName("BatchSquareSum"));
+
+    EigenVectorArrayMap<T> d_batch_size_arr(
+        d_batch_size->mutable_data<T>(ctx.GetPlace()), C);
+    EigenVectorArrayMap<T> d_batch_sum_arr(
+        d_batch_sum->mutable_data<T>(ctx.GetPlace()), C);
+    EigenVectorArrayMap<T> d_batch_square_sum_arr(
+        d_batch_square_sum->mutable_data<T>(ctx.GetPlace()), C);
+
+    d_batch_size_arr.setZero();
+    d_batch_sum_arr.setZero();
+    d_batch_square_sum_arr.setZero();
+
+    const float epsilon = ctx.Attr<float>("epsilon");
+    switch (
+        data_layout) {  // because it's two dimensions, so make no difference
+      case DataLayout::kNCHW:
+      case DataLayout::kNHWC: {
+        ConstEigenVectorArrayMap<T> scales_arr(scales->data<T>(), C);
+        ConstEigenVectorArrayMap<T> means_arr(means->data<T>(), C);
+        ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N);
+        ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N);
+        EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C, N);
+        d_x_arr.setZero();
+        for (int nc = 0; nc < N; ++nc) {
+          d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr;
+        }
+
+        // calculate data sum and square sum
+        ConstEigenVectorArrayMap<T> batch_size_arr(batch_size->data<T>(), C);
+        ConstEigenVectorArrayMap<T> batch_sum_arr(batch_sum->data<T>(), C);
+        ConstEigenVectorArrayMap<T> batch_square_sum_arr(
+            batch_square_sum->data<T>(), C);
+        Eigen::Array<T, Eigen::Dynamic, 1> sample_sum(C);
+        Eigen::Array<T, Eigen::Dynamic, 1> sample_square_sum(C);
+        // calculate data sample sum and square sum
+        sample_sum.setZero();
+        sample_square_sum.setZero();
+        for (int nc = 0; nc < N; ++nc) {
+          sample_sum += x_arr.col(nc);
+          sample_square_sum += (x_arr.col(nc) - means_arr).square();
+        }
+        // calculate gradient
+        d_batch_size_arr.setConstant(N);
+        d_batch_sum_arr = sample_sum;
+        d_batch_square_sum_arr = sample_square_sum + d_batch_size_arr * epsilon;
+        break;
+      }
+      default:
+        PADDLE_THROW("Unknown storage order: %s", data_layout_str);
+    }
+  }
+};
+
+// Builds the data_norm_grad op description from the forward op.
+class DataNormGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    // Owned from the start so the OpDesc is released if a setter throws.
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("data_norm_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+
+    op->SetInput("BatchSize", Input("BatchSize"));
+    op->SetInput("BatchSum", Input("BatchSum"));
+    op->SetInput("BatchSquareSum", Input("BatchSquareSum"));
+    op->SetInput("Scales", Output("Scales"));
+    op->SetInput("Means", Output("Means"));
+
+    op->SetAttrMap(Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("BatchSize"), InputGrad("BatchSize"));
+    op->SetOutput(framework::GradVarName("BatchSum"), InputGrad("BatchSum"));
+    op->SetOutput(framework::GradVarName("BatchSquareSum"),
+                  InputGrad("BatchSquareSum"));
+
+    return op;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(data_norm, ops::DataNormOp, ops::DataNormOpMaker,
+                  ops::DataNormGradMaker);
+REGISTER_OPERATOR(data_norm_grad, ops::DataNormGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    data_norm, ops::DataNormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::DataNormKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    data_norm_grad,
+    ops::DataNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::DataNormGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/data_norm_op.h
b/paddle/fluid/operators/data_norm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..63451214bcf649d0a7a949f391db9b651d237d22
--- /dev/null
+++ b/paddle/fluid/operators/data_norm_op.h
@@ -0,0 +1,37 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+// Forward kernel of the data_norm operator.
+template <typename DeviceContext, typename T>
+class DataNormKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override;
+};
+
+// Gradient kernel of the data_norm operator (data_norm_grad).
+template <typename DeviceContext, typename T>
+class DataNormGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index fcf43a654e1c8275cdc05c61343f10612f397649..a4787e769f62ebbefd3ea6b70b402e660c02b576 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -58,6 +58,7 @@ __all__ = [
     'adaptive_pool2d',
     'adaptive_pool3d',
     'batch_norm',
+    'data_norm',
     'beam_search_decode',
     'conv2d_transpose',
     'conv3d_transpose',
@@ -2897,6 +2898,146 @@ def batch_norm(input,
     return helper.append_activation(batch_norm_out)
 
 
+def data_norm(input,
+              act=None,
+              epsilon=1e-05,
+              param_attr=None,
+              data_layout='NCHW',
+              in_place=False,
+              use_mkldnn=False,
+              name=None,
+              moving_mean_name=None,
+              moving_variance_name=None,
+              do_model_average_for_mean_and_var=False):
+    """
+    **Data Normalization Layer**
+
+    Can be used as a normalizer function for conv2d and fully_connected operations.
+    The required data format for this layer is one of the following:
+
+    1. NHWC `[batch, in_height, in_width, in_channels]`
+
+    2. NCHW `[batch, in_channels, in_height, in_width]`
+
+    :math:`input` is the input features over a mini-batch.
+
+    Every channel is normalized with three statistics that the layer keeps as
+    parameters, `batch_size`, `batch_sum` and `batch_square_sum`:
+
+    ..  math::
+
+        mean &\\gets \\frac{batch\\_sum}{batch\\_size} \\\\
+        scale &\\gets \\sqrt{\\frac{batch\\_size}{batch\\_square\\_sum}} \\\\
+        y_i &\\gets (x_i - mean) \\cdot scale
+
+    Args:
+        input(variable): The input variable which is a LoDTensor.
+        act(string, Default None): Activation type, linear|relu|prelu|...
+        epsilon(float, Default 1e-05): The `epsilon` attribute of the
+            data_norm operator.
+        param_attr(dict, Default None): Initial values of the statistics, read
+            from the keys `batch_size`, `batch_sum` and `batch_square`. A
+            missing key, or a `param_attr` that is not a dict, selects the
+            defaults 1e4, 0.0 and 1e4.
+        data_layout(string, default NCHW): NCHW|NHWC
+        in_place(bool, Default False): Make the input and output of data norm reuse memory.
+        use_mkldnn(bool, Default False): The `use_mkldnn` attribute of the
+            data_norm operator.
+        name(string, Default None): A name for this layer(optional). If set None, the layer
+            will be named automatically. The name prefixes the parameter names.
+        moving_mean_name(string, Default None): Not used by this layer.
+        moving_variance_name(string, Default None): Not used by this layer.
+        do_model_average_for_mean_and_var(bool, Default False): Not used by this layer.
+
+    Returns:
+        Variable: A tensor variable which is the result after applying data normalization on the input.
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[200], dtype='float32')
+            hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
+            hidden2 = fluid.layers.data_norm(name='hidden2', input=hidden1)
+    """
+    helper = LayerHelper('data_norm', **locals())
+    dtype = helper.input_dtype()
+
+    input_shape = input.shape
+    if data_layout == 'NCHW':
+        channel_num = input_shape[1]
+    elif data_layout == 'NHWC':
+        channel_num = input_shape[-1]
+    else:
+        raise ValueError("unsupported data layout:" + data_layout)
+
+    param_shape = [channel_num]
+
+    batch_size_default = 1e4
+    batch_sum_default = 0.0
+    batch_square_sum_default = 1e4
+
+    if param_attr and isinstance(param_attr, dict):
+        batch_size_default = param_attr.get("batch_size", 1e4)
+        batch_sum_default = param_attr.get("batch_sum", 0.0)
+        batch_square_sum_default = param_attr.get("batch_square", 1e4)
+
+    # `name` defaults to None and cannot be concatenated with a suffix, so
+    # use the name held by the helper instead.
+    # NOTE(review): relies on LayerHelper generating a unique name when
+    # `name` is None -- confirm against LayerHelper.
+    param_prefix = helper.name
+
+    # create parameter
+    batch_size = helper.create_parameter(
+        attr=ParamAttr(
+            name=param_prefix + '.batch_size',
+            initializer=Constant(value=float(batch_size_default)),
+            trainable=True),
+        shape=param_shape,
+        dtype=input.dtype)
+
+    batch_sum = helper.create_parameter(
+        attr=ParamAttr(
+            name=param_prefix + '.batch_sum',
+            initializer=Constant(value=float(batch_sum_default)),
+            trainable=True),
+        shape=param_shape,
+        dtype=input.dtype)
+
+    batch_square_sum = helper.create_parameter(
+        attr=ParamAttr(
+            name=param_prefix + '.batch_square_sum',
+            initializer=Constant(value=float(batch_square_sum_default)),
+            trainable=True),
+        shape=param_shape,
+        dtype=input.dtype)
+
+    means = helper.create_variable(dtype=dtype, stop_gradient=True)
+    scales = helper.create_variable(dtype=dtype, stop_gradient=True)
+
+    data_norm_out = input if in_place else helper.create_variable(dtype=dtype)
+
+    helper.append_op(
+        type="data_norm",
+        inputs={
+            "X": input,
+            "BatchSize": batch_size,
+            "BatchSum": batch_sum,
+            "BatchSquareSum": batch_square_sum
+        },
+        outputs={"Y": data_norm_out,
+                 "Means": means,
+                 "Scales": scales},
+        attrs={
+            "epsilon": epsilon,
+            "data_layout": data_layout,
+            "use_mkldnn": use_mkldnn
+        })
+
+    return helper.append_activation(data_norm_out)
+
+
 @templatedoc()
 def layer_norm(input,
                scale=True,
@@ -3065,9 +3206,9 @@ def group_norm(input,
         inputs['Bias'] = bias
 
     # create output
-    mean_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
-    variance_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
-    group_norm_out = helper.create_tmp_variable(dtype)
+    mean_out = helper.create_variable(dtype=dtype, stop_gradient=True)
+    variance_out = helper.create_variable(dtype=dtype, stop_gradient=True)
+    group_norm_out = helper.create_variable(dtype)
 
     helper.append_op(
         type="group_norm",