未验证 提交 e7e7cb5f 编写于 作者: K Kaipeng Deng 提交者: GitHub

Split inplace_abn & batch_norm infershape (#23755)

* Fix elementwise compile error, test=develop

* split inplace_abn & batch_norm InferShape. test=develop

* fix type. test=develop

* fix message. test=develop

* fix ENFORCE. test=develop
Co-authored-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>
上级 afb2cb7c
......@@ -468,30 +468,21 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
"in gradient op kernel of batch_norm_mkldnn_op now."));
}
// batch_norm_grad with inplace takes Y as input, without inplace
// takes X as input. HasInput will throw exception in compile time,
// so only infer shape in run time here.
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(ctx->HasInput("X") || ctx->HasInput("Y"), true,
platform::errors::NotFound(
"Input(X) and Input(Y) should not be all null."));
auto input_name = "Y";
if (ctx->HasInput("X")) input_name = "X";
const auto x_dims = ctx->GetInputDim(input_name);
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormGrad");
const auto x_dims = ctx->GetInputDim("X");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
}
......
......@@ -62,6 +62,58 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
public:
using paddle::operators::BatchNormGradOp::BatchNormGradOp;
void InferShape(framework::InferShapeContext* ctx) const {
// check input
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InplaceABNGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
"Y@GRAD", "InplaceABNGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
"InplaceABNGrad");
OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
"InplaceABNGrad");
// check output
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
"X@GRAD", "InplaceABNGrad");
const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
PADDLE_ENFORCE_EQ(
has_scale_grad, has_bias_grad,
platform::errors::InvalidArgument(
"Output(Scale@GRAD) and Output(Bias@GRAD) must be null "
"or not be null at same time. But now, "
"has Scale@Grad=[%d], has Bias@GRAD=[%d]",
has_scale_grad, has_bias_grad));
const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
if (use_global_stats) {
PADDLE_ENFORCE_EQ(
!ctx->Attrs().Get<bool>("use_mkldnn"), true,
platform::errors::InvalidArgument(
"Using global stats during training is not supported "
"in gradient op kernel of batch_norm_mkldnn_op now."));
}
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "InplaceABNGrad");
const auto y_dims = ctx->GetInputDim("Y");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? y_dims[1]
: y_dims[y_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), y_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if (has_scale_grad) {
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册