From 25ef38bc0537b6230fb1e830a7429e250f325b43 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Sun, 12 Apr 2020 13:15:24 +0800 Subject: [PATCH] Fix bce_loss (#23746) * fix bce_loss,test=develop --- paddle/fluid/operators/bce_loss_op.cc | 31 ++++++------------- .../fluid/tests/unittests/test_bce_loss.py | 10 ++++++ 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/bce_loss_op.cc b/paddle/fluid/operators/bce_loss_op.cc index 4cbcd1dd775..50797a100b1 100644 --- a/paddle/fluid/operators/bce_loss_op.cc +++ b/paddle/fluid/operators/bce_loss_op.cc @@ -27,15 +27,9 @@ class BCELossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::InvalidArgument("Input(X) should be not null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Label"), true, - platform::errors::InvalidArgument("Input(Label) should be not null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument("Output(Out) should be not null.")); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELoss"); + OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELoss"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BCELoss"); auto x_dims = ctx->GetInputDim("X"); auto label_dims = ctx->GetInputDim("Label"); @@ -74,18 +68,12 @@ class BCELossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::InvalidArgument("Input(X) should be not null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Label"), true, - platform::errors::InvalidArgument("Input(Label) should be not null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, - platform::errors::InvalidArgument( - "Input(Out@GRAD) shoudl be not null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true, - platform::errors::InvalidArgument( - "Output(X@GRAD) should be not null.")); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELossGrad"); + OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELossGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "BCELossGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + framework::GradVarName("X"), "BCELossGrad"); auto x_dims = ctx->GetInputDim("X"); auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); @@ -152,7 +140,6 @@ class BCELossGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("Label", this->Input("Label")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - // op->SetAttrMap(this->Attrs()); } }; diff --git a/python/paddle/fluid/tests/unittests/test_bce_loss.py b/python/paddle/fluid/tests/unittests/test_bce_loss.py index f3351e36a69..738d30ad938 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_loss.py @@ -131,5 +131,15 @@ class TestBceLossOp(OpTest): self.shape = [10, 10] +class TestBceLossOpCase1(OpTest): + def init_test_cast(self): + self.shape = [2, 3, 4, 5] + + +class TestBceLossOpCase2(OpTest): + def init_test_cast(self): + self.shape = [2, 3, 20] + + if __name__ == "__main__": unittest.main() -- GitLab