From 38e5cd000ec24406927ea4d334e3ad30f6a81d18 Mon Sep 17 00:00:00 2001
From: gouzil <66515297+gouzil@users.noreply.github.com>
Date: Wed, 17 May 2023 13:00:46 +0800
Subject: [PATCH] [fluid] decoupling abn op (#53826)

---
 paddle/fluid/operators/inplace_abn_op.cc | 254 ++++++++++++++++++++++-
 1 file changed, 244 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/operators/inplace_abn_op.cc b/paddle/fluid/operators/inplace_abn_op.cc
index 5deffde2b56..62197787f3b 100644
--- a/paddle/fluid/operators/inplace_abn_op.cc
+++ b/paddle/fluid/operators/inplace_abn_op.cc
@@ -17,17 +17,159 @@
 #include <memory>
 #include <string>
 #include <vector>
-
-#include "paddle/fluid/operators/batch_norm_op.h"
 #include "paddle/phi/kernels/batch_norm_grad_kernel.h"
 #include "paddle/phi/kernels/batch_norm_kernel.h"
 
 namespace paddle {
 namespace operators {
 
-class InplaceABNOp : public paddle::operators::BatchNormOp {
+class InplaceABNOp : public framework::OperatorWithKernel {
  public:
-  using paddle::operators::BatchNormOp::BatchNormOp;
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNorm");
+    OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "BatchNorm");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "BatchNorm");
+    OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "BatchNorm");
+    OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance", "BatchNorm");
+    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "BatchNorm");
+
+    bool is_test = ctx->Attrs().Get<bool>("is_test");
+    bool trainable_stats = ctx->Attrs().Get<bool>("trainable_statistics");
+    bool test_mode = is_test && (!trainable_stats);
+    if (!test_mode) {
+      OP_INOUT_CHECK(
+          ctx->HasOutput("MeanOut"), "Output", "MeanOut", "BatchNorm");
+      OP_INOUT_CHECK(
+          ctx->HasOutput("VarianceOut"), "Output", "VarianceOut", "BatchNorm");
+      OP_INOUT_CHECK(
+          ctx->HasOutput("SavedMean"), "Output", "SavedMean", "BatchNorm");
+      OP_INOUT_CHECK(ctx->HasOutput("SavedVariance"),
+                     "Output",
+                     "SavedVariance",
+                     "BatchNorm");
+    }
+
+    // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
+    PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0],
+                      ctx->Outputs("MeanOut")[0],
+                      platform::errors::InvalidArgument(
+                          "Mean and MeanOut should share the same memory"));
+    PADDLE_ENFORCE_EQ(
+        ctx->Inputs("Variance")[0],
+        ctx->Outputs("VarianceOut")[0],
+        platform::errors::InvalidArgument(
+            "Variance and VarianceOut should share the same memory"));
+
+    const auto x_dims = ctx->GetInputDim("X");
+
+    for (int i = 0; i < x_dims.size(); i++) {
+      PADDLE_ENFORCE_EQ(
+          (x_dims[i] == -1) || (x_dims[i] > 0),
+          true,
+          platform::errors::InvalidArgument(
+              "Each dimension of input tensor is expected to be -1 or a "
+              "positive number, but received %d. Input's shape is [%s].",
+              x_dims[i],
+              x_dims));
+    }
+
+    const DataLayout data_layout =
+        phi::StringToDataLayout(ctx->Attrs().Get<std::string>("data_layout"));
+
+    if (ctx->IsRuntime() && ctx->HasInput("MomentumTensor")) {
+      auto mom = ctx->Inputs("MomentumTensor");
+      PADDLE_ENFORCE_EQ(mom.size(),
+                        1,
+                        platform::errors::InvalidArgument(
+                            "The input tensor MomentumTensor's size must be 1. "
+                            "But received: MomentumTensor's size is [%d]",
+                            mom.size()));
+    }
+
+    PADDLE_ENFORCE_GE(x_dims.size(),
+                      2,
+                      platform::errors::InvalidArgument(
+                          "ShapeError: the dimension of input "
+                          "X must be greater than or equal to 2. But received: "
But received: " + "the shape of input " + "X = [%s], the dimension of input X =[%d]", + x_dims, + x_dims.size())); + PADDLE_ENFORCE_LE(x_dims.size(), + 5, + platform::errors::InvalidArgument( + "ShapeError: the dimension of input X " + "must smaller than or equal to 5. But received: the " + "shape of input X " + "= [%s], the dimension of input X = [%d]", + x_dims, + x_dims.size())); + VLOG(4) << ctx->IsRunMKLDNNKernel(); + VLOG(4) << data_layout; + const int64_t C = ((ctx->IsRunMKLDNNKernel() == true) || + (data_layout == DataLayout::kNCHW) + ? x_dims[1] + : x_dims[x_dims.size() - 1]); + + auto scale_dim = ctx->GetInputDim("Scale"); + auto bias_dim = ctx->GetInputDim("Bias"); + + PADDLE_ENFORCE_EQ( + scale_dim.size(), + 1UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of scale must equal to 1." + "But received: the shape of scale is [%s], the dimension " + "of scale is [%d]", + scale_dim, + scale_dim.size())); + PADDLE_ENFORCE_EQ( + bias_dim.size(), + 1UL, + platform::errors::InvalidArgument( + "ShapeError: the dimension of bias must equal to 1." + "But received: the shape of bias is [%s],the dimension " + "of bias is [%d]", + bias_dim, + bias_dim.size())); + + bool check = true; + if ((!ctx->IsRuntime()) && + (phi::product(scale_dim) <= 0 || phi::product(bias_dim) <= 0)) { + check = false; + } + + if (check) { + PADDLE_ENFORCE_EQ(scale_dim[0], + C, + platform::errors::InvalidArgument( + "ShapeError: the shape of scale must equal to [%d]" + "But received: the shape of scale is [%d]", + C, + scale_dim[0])); + PADDLE_ENFORCE_EQ(bias_dim[0], + C, + platform::errors::InvalidArgument( + "ShapeError: the shape of bias must equal to [%d]" + "But received: the shape of bias is [%d]", + C, + bias_dim[0])); + } + ctx->SetOutputDim("Y", x_dims); + ctx->ShareLoD("X", "Y"); + VLOG(4) << x_dims; + ctx->SetOutputDim("MeanOut", {C}); + ctx->SetOutputDim("VarianceOut", {C}); + if (!test_mode) { + ctx->SetOutputDim("SavedMean", {C}); + ctx->SetOutputDim("SavedVariance", {C}); + } + if (ctx->HasOutput("ReserveSpace")) { + ctx->SetOutputDim("ReserveSpace", {-1}); + } + } protected: phi::KernelKey GetExpectedKernelType( @@ -65,10 +207,9 @@ class InplaceABNOp : public paddle::operators::BatchNormOp { } }; -class InplaceABNGradOp : public paddle::operators::BatchNormGradOp { +class InplaceABNGradOp : public framework::OperatorWithKernel { public: - using paddle::operators::BatchNormGradOp::BatchNormGradOp; - + using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { // check input OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InplaceABNGrad"); @@ -155,10 +296,82 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp { } }; -class InplaceABNOpMaker : public paddle::operators::BatchNormOpMaker { +class InplaceABNOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - BatchNormOpMaker::Make(); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. 
+        .SetDefault(false);
+    AddAttr<float>("momentum", "").SetDefault(0.9);
+    AddAttr<float>("epsilon", "")
+        .SetDefault(1e-5)
+        .AddCustomChecker([](const float& epsilon) {
+          PADDLE_ENFORCE_GE(
+              epsilon,
+              0.0f,
+              platform::errors::InvalidArgument(
+                  "'epsilon' should be greater than or equal to 0.0."));
+          PADDLE_ENFORCE_LE(
+              epsilon,
+              0.001f,
+              platform::errors::InvalidArgument(
+                  "'epsilon' should be less than or equal to 0.001."));
+        });
+    AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
+    AddInput("X", "The input tensor");
+    AddInput("Scale",
+             "Scale is a 1-dimensional tensor of size C "
+             "that is applied to the output");
+    AddInput("Bias",
+             "Bias is a 1-dimensional tensor of size C "
+             "that is applied to the output");
+    AddInput("Mean",
+             "The global mean (for training) or "
+             "estimated mean (for testing)");
+    AddInput("Variance",
+             "The global variance (for training) "
+             "or estimated Variance (for testing)");
+    AddInput(
+        "MomentumTensor",
+        "(phi::DenseTensor, optional) If provided, batch_norm will "
+        "use this as momentum; this has a higher priority than "
+        "attr(momentum). The shape of this tensor MUST BE [1].")
+        .AsDispensable();
+    AddOutput("Y", "result after normalization");
+    AddOutput("MeanOut",
+              "Share memory with Mean. "
+              "Store the global mean when training");
+    AddOutput("VarianceOut",
+              "Share memory with Variance. "
+              "Store the global Variance when training");
+    AddOutput("SavedMean",
+              "Mean of the current mini batch, "
+              "will apply to output when training")
+        .AsIntermediate();
+    AddOutput("SavedVariance",
+              "Variance of the current mini batch, "
+              "will apply to output when training")
+        .AsIntermediate();
+    AddOutput("ReserveSpace",
+              "Reserve GPU space for triggering the new semi-persistent "
+              "NHWC kernel")
+        .AsDispensable()
+        .AsExtra();
+    AddAttr<bool>("use_global_stats",
+                  "(bool, default false) Whether to use global mean and "
+                  "variance. In inference or test mode, setting "
+                  "use_global_stats to true or is_test to true is "
+                  "equivalent. In train mode, when use_global_stats is set "
+                  "to true, the global mean and variance are also used "
+                  "during training, and BN acts as scaling and shifting.")
+        .SetDefault(false);
+    AddAttr<bool>(
+        "trainable_statistics",
+        "(bool, default false) Whether to calculate mean and variance "
+        "in test mode. If set to true in test mode, mean and variance "
+        "will be calculated from the current batch statistics.")
+        .SetDefault(false);
     AddAttr<std::string>(
         "activation",
         "(enum string, default identity, can be identity|elu|leaky-relu) "
@@ -174,6 +387,17 @@ class InplaceABNOpMaker : public paddle::operators::BatchNormOpMaker {
                   "(bool, default false) Whether use synchronize batch "
                   "normalization.")
         .SetDefault(false);
+    AddComment(R"DOC(
+Batch Normalization.
+
+Batch Norm has been implemented as discussed in the paper:
+https://arxiv.org/pdf/1502.03167.pdf
+Can be used as a normalizer function for conv2d and fully_connected operations.
+The required data format for this layer is one of the following:
+1. NHWC `[batch, in_height, in_width, in_channels]`
+2. NCHW `[batch, in_channels, in_height, in_width]`
+
+)DOC");
   }
 };
 
@@ -358,6 +582,16 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
   }
 };
 
+class InplaceABNOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
+      const override {
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
+    return m;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -367,7 +601,7 @@ DECLARE_INPLACE_OP_INFERER(InplaceAbnOpInplaceInferer, {"X", "Y"});
 REGISTER_OPERATOR(inplace_abn,
                   ops::InplaceABNOp,
                   ops::InplaceABNOpMaker,
-                  ops::BatchNormOpInferVarType,
+                  ops::InplaceABNOpInferVarType,
                   ops::InplaceABNOpGradMaker<paddle::framework::OpDesc>,
                   ops::InplaceABNOpGradMaker<paddle::imperative::OpBase>,
                   InplaceAbnOpInplaceInferer)
-- 
GitLab
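
For reference, the per-channel transform the operator computes is the batch normalization described in the paper cited in the DOC string above,

    y = Scale * (x - Mean) / sqrt(Variance + epsilon) + Bias

followed by the optional activation selected through the "activation" attribute. During training the statistics of the current mini-batch are exposed as SavedMean/SavedVariance, while MeanOut/VarianceOut hold the updated running estimates and, as the InferShape checks above enforce, must share memory with Mean/Variance so the update happens in place; at inference the running estimates are used directly.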