// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/inplace_abn_op.h"

#include <memory>
#include <string>

#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"

namespace paddle {
namespace operators {

class InplaceABNOp : public paddle::operators::BatchNormOp {
 public:
  using paddle::operators::BatchNormOp::BatchNormOp;

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    // By default, the dtype of the scale, bias, mean, and variance tensors
    // should be float (for float or float16 input tensors) or double (for
    // double input tensors).
    auto bn_param_type = framework::proto::VarType::FP32;
    if (input_data_type == framework::proto::VarType::FP64) {
      bn_param_type = framework::proto::VarType::FP64;
    }
    PADDLE_ENFORCE_EQ(bn_param_type,
                      framework::TransToProtoVarType(
                          ctx.Input<phi::DenseTensor>("Scale")->dtype()),
                      platform::errors::InvalidArgument(
                          "Scale input should be of float type"));
    PADDLE_ENFORCE_EQ(bn_param_type,
                      framework::TransToProtoVarType(
                          ctx.Input<phi::DenseTensor>("Bias")->dtype()),
                      platform::errors::InvalidArgument(
                          "Bias input should be of float type"));
    PADDLE_ENFORCE_EQ(bn_param_type,
                      framework::TransToProtoVarType(
                          ctx.Input<phi::DenseTensor>("Mean")->dtype()),
                      platform::errors::InvalidArgument(
                          "Mean input should be of float type"));
    PADDLE_ENFORCE_EQ(bn_param_type,
                      framework::TransToProtoVarType(
                          ctx.Input<phi::DenseTensor>("Variance")->dtype()),
                      platform::errors::InvalidArgument(
                          "Variance input should be of float type"));
    return phi::KernelKey(input_data_type, ctx.GetPlace());
  }
};

class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
 public:
  using paddle::operators::BatchNormGradOp::BatchNormGradOp;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // check inputs
    OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InplaceABNGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")),
                   "Input",
                   "Y@GRAD",
                   "InplaceABNGrad");
    OP_INOUT_CHECK(
        ctx->HasInput("SavedMean"), "Input", "SavedMean", "InplaceABNGrad");
    OP_INOUT_CHECK(ctx->HasInput("SavedVariance"),
                   "Input",
                   "SavedVariance",
                   "InplaceABNGrad");

    // check outputs
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
                   "Output",
                   "X@GRAD",
                   "InplaceABNGrad");

    const bool has_scale_grad =
        ctx->HasOutput(framework::GradVarName("Scale"));
    const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));

    PADDLE_ENFORCE_EQ(
        has_scale_grad,
        has_bias_grad,
        platform::errors::InvalidArgument(
            "Output(Scale@GRAD) and Output(Bias@GRAD) must be null "
            "or not be null at the same time. But now, "
But now, " "has Scale@Grad=[%d], has Bias@GRAD=[%d]", has_scale_grad, has_bias_grad)); const bool use_global_stats = ctx->Attrs().Get("use_global_stats"); if (use_global_stats) { PADDLE_ENFORCE_EQ( !ctx->Attrs().Get("use_mkldnn"), true, platform::errors::InvalidArgument( "Using global stats during training is not supported " "in oneDNN version of batch_norm_gradient kernel now.")); } OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "InplaceABNGrad"); const auto y_dims = ctx->GetInputDim("Y"); const DataLayout data_layout = phi::StringToDataLayout(ctx->Attrs().Get("data_layout")); const int C = ((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW) ? y_dims[1] : y_dims[y_dims.size() - 1]); ctx->SetOutputDim(framework::GradVarName("X"), y_dims); // has_scale_grad == has_bias_grad, judge has_scale_grad is enough if (has_scale_grad) { ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); } } protected: phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { const auto* var = ctx.InputVar(framework::GradVarName("Y")); auto input_data_type = framework::TransToProtoVarType( ctx.Input("Y")->dtype()); if (var == nullptr) { PADDLE_THROW(platform::errors::InvalidArgument( "can't find gradient variable of Y")); } const phi::DenseTensor* t = nullptr; if (var->IsType()) { t = &var->Get(); } else if (var->IsType()) { t = &var->Get(); } if (t == nullptr) { PADDLE_THROW( platform::errors::InvalidArgument("gradient variable of Y is empty")); } return phi::KernelKey(input_data_type, ctx.GetPlace()); } }; class InplaceABNOpMaker : public paddle::operators::BatchNormOpMaker { public: void Make() override { BatchNormOpMaker::Make(); AddAttr( "activation", "(enum string, default identity, can be identity|elu|leaky-relu) " "The activation type used for output candidate {h}_t.") .SetDefault(""); AddAttr("alpha", "(float, default 1.0) Only used in inplace-abn kernel," "the activation type(identity|elu|leakyrelu) would be fused " "with batch_norm, " "this is the alpha value for elu|leakyrelu.") .SetDefault(0.1f); AddAttr("use_sync_bn", "(bool, default false) Whether use synchronize batch " "normalization.") .SetDefault(false); } }; template class InplaceABNOpGradMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; protected: void Apply(GradOpPtr op) const override { op->SetType(this->ForwardOpType() + "_grad"); op->SetInput("Y", this->Output("Y")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetInput("Scale", this->Input("Scale")); op->SetInput("Bias", this->Input("Bias")); op->SetInput("SavedMean", this->Output("SavedMean")); op->SetInput("SavedVariance", this->Output("SavedVariance")); if (this->HasOutput("ReserveSpace")) { op->SetInput("ReserveSpace", this->Output("ReserveSpace")); } // used when setting use_global_stats True during training if (PADDLE_GET_CONST(bool, this->GetAttr("use_global_stats"))) { op->SetInput("Mean", this->Output("MeanOut")); op->SetInput("Variance", this->Output("VarianceOut")); } op->SetAttrMap(this->Attrs()); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale")); op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); } }; template class InplaceABNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* x = ctx.Input("X"); 
auto* y = ctx.Output("Y"); PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument( "X and Y not inplaced in inplace mode")); auto activation = GetInplaceABNActivationType(ctx.Attr("activation")); auto& place = *ctx.template device_context().eigen_device(); auto* scale = ctx.Input("Scale"); auto* bias = ctx.Input("Bias"); auto* mean = ctx.Input("Mean"); auto* variance = ctx.Input("Variance"); auto momentum = ctx.Attr("momentum"); auto epsilon = ctx.Attr("epsilon"); auto data_layout = ctx.Attr("data_layout"); auto is_test = ctx.Attr("is_test"); auto use_global_stats = ctx.Attr("use_global_stats"); auto trainable_statistics = ctx.Attr("trainable_statistics"); auto* mean_out = ctx.Output("MeanOut"); auto* variance_out = ctx.Output("VarianceOut"); auto* saved_mean = ctx.Output("SavedMean"); auto* saved_variance = ctx.Output("SavedVariance"); auto* reserve_space = ctx.Output("ReserveSpace"); auto& dev_ctx = ctx.device_context(); phi::BatchNormKernel( static_cast::TYPE&>(dev_ctx), *x, *mean, *variance, *scale, *bias, is_test, momentum, epsilon, data_layout, use_global_stats, trainable_statistics, y, mean_out, variance_out, saved_mean, saved_variance, reserve_space); auto cur_y = EigenVector::Flatten(*y); InplaceABNActivation functor; functor.Compute(ctx, activation, place, cur_y, cur_y); } }; template class InplaceABNGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* y = ctx.Input("Y"); auto* d_y = ctx.Input(framework::GradVarName("Y")); auto* d_x = ctx.Output(framework::GradVarName("X")); PADDLE_ENFORCE_EQ(d_x, d_y, platform::errors::InvalidArgument( "X@GRAD and Y@GRAD not inplaced in inplace mode")); auto& place = *ctx.template device_context().eigen_device(); auto activation = GetInplaceABNActivationType(ctx.Attr("activation")); auto py = *y; auto pd_y = *d_y; auto cur_y = EigenVector::Flatten(py); auto cur_dy = EigenVector::Flatten(pd_y); InplaceABNActivation functor; functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy); // BatchNormGradKernel::Compute(ctx); auto* scale = ctx.Input("Scale"); auto* bias = ctx.Input("Bias"); auto* saved_mean = ctx.Input("SavedMean"); auto* saved_variance = ctx.Input("SavedVariance"); auto momentum = ctx.Attr("momentum"); auto epsilon = ctx.Attr("epsilon"); auto data_layout = ctx.Attr("data_layout"); auto is_test = ctx.Attr("is_test"); auto use_global_stats = ctx.Attr("use_global_stats"); auto trainable_statistics = ctx.Attr("trainable_statistics"); auto* scale_grad = ctx.Output(framework::GradVarName("Scale")); auto* bias_grad = ctx.Output(framework::GradVarName("Bias")); auto* reserve_space = ctx.Input("ReserveSpace"); auto* mean = ctx.Input("ReserveSpace"); auto* variance = ctx.Input("ReserveSpace"); paddle::optional space_opt; paddle::optional mean_opt; paddle::optional variance_opt; if (reserve_space != nullptr) { space_opt = *reserve_space; } if (mean != nullptr) { mean_opt = *mean; } if (variance != nullptr) { variance_opt = *variance; } auto& dev_ctx = ctx.device_context(); phi::BatchNormGradRawKernel( static_cast::TYPE&>(dev_ctx), *y, *scale, *bias, mean_opt, variance_opt, *saved_mean, *saved_variance, space_opt, *d_y, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics, true, d_x, scale_grad, bias_grad); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; DECLARE_INPLACE_OP_INFERER(InplaceAbnOpInplaceInferer, {"X", "Y"}); REGISTER_OPERATOR(inplace_abn, ops::InplaceABNOp, 
                  ops::InplaceABNOpMaker,
                  ops::BatchNormOpInferVarType,
                  ops::InplaceABNOpGradMaker<paddle::framework::OpDesc>,
                  ops::InplaceABNOpGradMaker<paddle::imperative::OpBase>,
                  InplaceAbnOpInplaceInferer)
REGISTER_OPERATOR(inplace_abn_grad, ops::InplaceABNGradOp)

PD_REGISTER_STRUCT_KERNEL(
    inplace_abn, CPU, ALL_LAYOUT, ops::InplaceABNKernel, float, double) {}
PD_REGISTER_STRUCT_KERNEL(inplace_abn_grad,
                          CPU,
                          ALL_LAYOUT,
                          ops::InplaceABNGradKernel,
                          float,
                          double) {}