diff --git a/paddle/fluid/operators/bce_loss_op.cc b/paddle/fluid/operators/bce_loss_op.cc index 4cbcd1dd7757f0130c8c50dc2dea3e95c3d6f064..50797a100b1a67244b7c7b40b47404b60dc6af65 100644 --- a/paddle/fluid/operators/bce_loss_op.cc +++ b/paddle/fluid/operators/bce_loss_op.cc @@ -27,15 +27,9 @@ class BCELossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::InvalidArgument("Input(X) should be not null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Label"), true, - platform::errors::InvalidArgument("Input(Label) should be not null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument("Output(Out) should be not null.")); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELoss"); + OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELoss"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BCELoss"); auto x_dims = ctx->GetInputDim("X"); auto label_dims = ctx->GetInputDim("Label"); @@ -74,18 +68,12 @@ class BCELossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::InvalidArgument("Input(X) should be not null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Label"), true, - platform::errors::InvalidArgument("Input(Label) should be not null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, - platform::errors::InvalidArgument( - "Input(Out@GRAD) shoudl be not null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true, - platform::errors::InvalidArgument( - "Output(X@GRAD) should be not null.")); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELossGrad"); + OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELossGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "BCELossGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + framework::GradVarName("X"), "BCELossGrad"); auto x_dims = ctx->GetInputDim("X"); auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); @@ -152,7 +140,6 @@ class BCELossGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("Label", this->Input("Label")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - // op->SetAttrMap(this->Attrs()); } }; diff --git a/python/paddle/fluid/tests/unittests/test_bce_loss.py b/python/paddle/fluid/tests/unittests/test_bce_loss.py index f3351e36a694dbe12dc5c0f8a361ac5508b72992..738d30ad93804a11b34022de794a2a2d5f0c4fa4 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_loss.py @@ -131,5 +131,15 @@ class TestBceLossOp(OpTest): self.shape = [10, 10] +class TestBceLossOpCase1(OpTest): + def init_test_cast(self): + self.shape = [2, 3, 4, 5] + + +class TestBceLossOpCase2(OpTest): + def init_test_cast(self): + self.shape = [2, 3, 20] + + if __name__ == "__main__": unittest.main()