From e7e7cb5f5e5dc52726da261d79613f73f5cf4bf7 Mon Sep 17 00:00:00 2001
From: Kaipeng Deng
Date: Mon, 13 Apr 2020 21:51:59 +0800
Subject: [PATCH] Split inplace_abn & batch_norm infershape (#23755)

* Fix elementwise compile error, test=develop

* split inplace_abn & batch_norm InferShape. test=develop

* fix type. test=develop

* fix message. test=develop

* fix ENFORCE. test=develop

Co-authored-by: zhaoyuchen
---
 paddle/fluid/operators/batch_norm_op.cc  | 37 +++++++----------
 paddle/fluid/operators/inplace_abn_op.cc | 52 ++++++++++++++++++++++++
 2 files changed, 66 insertions(+), 23 deletions(-)

diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index ec9f9e9c7dc..9aeccde1256 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -468,30 +468,21 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
             "in gradient op kernel of batch_norm_mkldnn_op now."));
   }
 
-  // batch_norm_grad with inplace takes Y as input, without inplace
-  // takes X as input. HasInput will throw exception in compile time,
-  // so only infer shape in run time here.
-  if (ctx->IsRuntime()) {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("X") || ctx->HasInput("Y"), true,
-                      platform::errors::NotFound(
-                          "Input(X) and Input(Y) should not be all null."));
-    auto input_name = "Y";
-    if (ctx->HasInput("X")) input_name = "X";
-    const auto x_dims = ctx->GetInputDim(input_name);
-    const DataLayout data_layout = framework::StringToDataLayout(
-        ctx->Attrs().Get<std::string>("data_layout"));
+  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormGrad");
+  const auto x_dims = ctx->GetInputDim("X");
+  const DataLayout data_layout = framework::StringToDataLayout(
+      ctx->Attrs().Get<std::string>("data_layout"));
 
-    const int C =
-        ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
-             ? x_dims[1]
-             : x_dims[x_dims.size() - 1]);
-
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    // has_scale_grad == has_bias_grad, judge has_scale_grad is enough
-    if (has_scale_grad) {
-      ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
-      ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
-    }
+  const int C =
+      ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+           ? x_dims[1]
+           : x_dims[x_dims.size() - 1]);
+
+  ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+  // has_scale_grad == has_bias_grad, judge has_scale_grad is enough
+  if (has_scale_grad) {
+    ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
+    ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
   }
 }
 
diff --git a/paddle/fluid/operators/inplace_abn_op.cc b/paddle/fluid/operators/inplace_abn_op.cc
index 0b65699348f..50e079672f8 100644
--- a/paddle/fluid/operators/inplace_abn_op.cc
+++ b/paddle/fluid/operators/inplace_abn_op.cc
@@ -62,6 +62,58 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
  public:
   using paddle::operators::BatchNormGradOp::BatchNormGradOp;
 
+  void InferShape(framework::InferShapeContext* ctx) const {
+    // check input
+    OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InplaceABNGrad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
+                   "Y@GRAD", "InplaceABNGrad");
+    OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
+                   "InplaceABNGrad");
+    OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
+                   "InplaceABNGrad");
+
+    // check output
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   "X@GRAD", "InplaceABNGrad");
+
+    const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
+    const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
+
+    PADDLE_ENFORCE_EQ(
+        has_scale_grad, has_bias_grad,
+        platform::errors::InvalidArgument(
+            "Output(Scale@GRAD) and Output(Bias@GRAD) must be null "
+            "or not be null at same time. But now, "
+            "has Scale@Grad=[%d], has Bias@GRAD=[%d]",
+            has_scale_grad, has_bias_grad));
+
+    const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
+    if (use_global_stats) {
+      PADDLE_ENFORCE_EQ(
+          !ctx->Attrs().Get<bool>("use_mkldnn"), true,
+          platform::errors::InvalidArgument(
+              "Using global stats during training is not supported "
+              "in gradient op kernel of batch_norm_mkldnn_op now."));
+    }
+
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "InplaceABNGrad");
+    const auto y_dims = ctx->GetInputDim("Y");
+    const DataLayout data_layout = framework::StringToDataLayout(
+        ctx->Attrs().Get<std::string>("data_layout"));
+
+    const int C =
+        ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+             ? y_dims[1]
+             : y_dims[y_dims.size() - 1]);
+
+    ctx->SetOutputDim(framework::GradVarName("X"), y_dims);
+    // has_scale_grad == has_bias_grad, judge has_scale_grad is enough
+    if (has_scale_grad) {
+      ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
+      ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
+    }
+  }
+
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-- 
GitLab