未验证 提交 e7e7cb5f 编写于 作者: K Kaipeng Deng 提交者: GitHub

Split inplace_abn & batch_norm infershape (#23755)

* Fix elementwise compile error, test=develop

* split inplace_abn & batch_norm InferShape. test=develop

* fix type. test=develop

* fix message. test=develop

* fix ENFORCE. test=develop
Co-authored-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>
上级 afb2cb7c
...@@ -468,16 +468,8 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -468,16 +468,8 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
"in gradient op kernel of batch_norm_mkldnn_op now.")); "in gradient op kernel of batch_norm_mkldnn_op now."));
} }
// batch_norm_grad with inplace takes Y as input, without inplace OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormGrad");
// takes X as input. HasInput will throw exception in compile time, const auto x_dims = ctx->GetInputDim("X");
// so only infer shape in run time here.
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->HasInput("X") || ctx->HasInput("Y"), true,
platform::errors::NotFound(
"Input(X) and Input(Y) should not be all null."));
auto input_name = "Y";
if (ctx->HasInput("X")) input_name = "X";
const auto x_dims = ctx->GetInputDim(input_name);
const DataLayout data_layout = framework::StringToDataLayout( const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout")); ctx->Attrs().Get<std::string>("data_layout"));
...@@ -492,7 +484,6 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -492,7 +484,6 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
} }
}
} }
framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
......
...@@ -62,6 +62,58 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp { ...@@ -62,6 +62,58 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
public: public:
using paddle::operators::BatchNormGradOp::BatchNormGradOp; using paddle::operators::BatchNormGradOp::BatchNormGradOp;
void InferShape(framework::InferShapeContext* ctx) const {
// check input
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InplaceABNGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
"Y@GRAD", "InplaceABNGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
"InplaceABNGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
"InplaceABNGrad");
// check output
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"X@GRAD", "InplaceABNGrad");
const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
PADDLE_ENFORCE_EQ(
has_scale_grad, has_bias_grad,
platform::errors::InvalidArgument(
"Output(Scale@GRAD) and Output(Bias@GRAD) must be null "
"or not be null at same time. But now, "
"has Scale@Grad=[%d], has Bias@GRAD=[%d]",
has_scale_grad, has_bias_grad));
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) {
PADDLE_ENFORCE_EQ(
!ctx->Attrs().Get<bool>("use_mkldnn"), true,
platform::errors::InvalidArgument(
"Using global stats during training is not supported "
"in gradient op kernel of batch_norm_mkldnn_op now."));
}
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "InplaceABNGrad");
const auto y_dims = ctx->GetInputDim("Y");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? y_dims[1]
: y_dims[y_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), y_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册