// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/operators/inplace_abn_op.h" #include #include #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/operators/batch_norm_op.h" namespace paddle { namespace operators { class InplaceABNOp : public paddle::operators::BatchNormOp { public: using paddle::operators::BatchNormOp::BatchNormOp; protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). auto bn_param_type = framework::proto::VarType::FP32; if (input_data_type == framework::proto::VarType::FP64) { bn_param_type = framework::proto::VarType::FP64; } PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Scale")->type(), platform::errors::InvalidArgument( "Scale input should be of float type")); PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Bias")->type(), platform::errors::InvalidArgument( "Bias input should be of float type")); PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Mean")->type(), platform::errors::InvalidArgument( "Mean input should be of float type")); PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Variance")->type(), platform::errors::InvalidArgument( "Variance input should be of float type")); framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, library); } }; class InplaceABNGradOp : public paddle::operators::BatchNormGradOp { public: using paddle::operators::BatchNormGradOp::BatchNormGradOp; protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { const auto* var = ctx.InputVar(framework::GradVarName("Y")); auto input_data_type = ctx.Input("Y")->type(); if (var == nullptr) { PADDLE_THROW(platform::errors::InvalidArgument( "can't find gradient variable of Y")); } const Tensor* t = nullptr; if (var->IsType()) { t = &var->Get(); } else if (var->IsType()) { t = &var->Get(); } if (t == nullptr) { PADDLE_THROW( platform::errors::InvalidArgument("gradient variable of Y is empty")); } framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, library); } }; class InplaceABNOpMaker : public paddle::operators::BatchNormOpMaker { public: void Make() override { BatchNormOpMaker::Make(); AddAttr( "activation", "(enum string, default identity, can be identity|elu|leaky-relu) " "The activation type used for output candidate {h}_t.") .SetDefault(""); AddAttr("alpha", "(float, default 1.0) Only used in inplace-abn kernel," "the activation type(identity|elu|leakyrelu) would be fused " "with batch_norm, " "this is the alpha value for elu|leakyrelu.") .SetDefault(0.1f); AddAttr("use_sync_bn", "(bool, default false) Whether use synchronize batch " "normalization.") .SetDefault(false); } }; template class InplaceABNOpGradMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; protected: void Apply(GradOpPtr op) const override { op->SetType(this->ForwardOpType() + "_grad"); op->SetInput("Y", this->Output("Y")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetInput("Scale", this->Input("Scale")); op->SetInput("Bias", this->Input("Bias")); op->SetInput("SavedMean", this->Output("SavedMean")); op->SetInput("SavedVariance", this->Output("SavedVariance")); // used when setting use_global_stats True during training if (boost::get(this->GetAttr("use_global_stats"))) { op->SetInput("Mean", this->Output("MeanOut")); op->SetInput("Variance", this->Output("VarianceOut")); } op->SetAttrMap(this->Attrs()); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale")); op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); } }; template class InplaceABNKernel : public paddle::operators::BatchNormKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* x = ctx.Input("X"); auto* y = ctx.Output("Y"); PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument( "X and Y not inplaced in inplace mode")); auto activation = GetInplaceABNActivationType(ctx.Attr("activation")); auto& place = *ctx.template device_context().eigen_device(); BatchNormKernel::Compute(ctx); auto cur_y = EigenVector::Flatten(*y); InplaceABNActivation functor; functor.Compute(ctx, activation, place, cur_y, cur_y); } }; template class InplaceABNGradKernel : public paddle::operators::BatchNormGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* y = ctx.Input("Y"); auto* d_y = ctx.Input(framework::GradVarName("Y")); auto* d_x = ctx.Output(framework::GradVarName("X")); PADDLE_ENFORCE_EQ(d_x, d_y, platform::errors::InvalidArgument( "X@GRAD and Y@GRAD not inplaced in inplace mode")); auto& place = *ctx.template device_context().eigen_device(); auto activation = GetInplaceABNActivationType(ctx.Attr("activation")); auto py = *y; auto pd_y = *d_y; auto cur_y = EigenVector::Flatten(py); auto cur_dy = EigenVector::Flatten(pd_y); InplaceABNActivation functor; functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy); BatchNormGradKernel::Compute(ctx); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(inplace_abn, ops::InplaceABNOp, ops::InplaceABNOpMaker, ops::BatchNormOpInferVarType, ops::InplaceABNOpGradMaker, ops::InplaceABNOpGradMaker) REGISTER_OPERATOR(inplace_abn_grad, ops::InplaceABNGradOp) REGISTER_OP_CPU_KERNEL( inplace_abn, ops::InplaceABNKernel, ops::InplaceABNKernel); REGISTER_OP_CPU_KERNEL( inplace_abn_grad, ops::InplaceABNGradKernel, ops::InplaceABNGradKernel);