未验证 提交 25ef38bc 编写于 作者: C ceci3 提交者: GitHub

Fix bce_loss (#23746)

* fix bce_loss,test=develop
上级 cd1de0e2
...@@ -27,15 +27,9 @@ class BCELossOp : public framework::OperatorWithKernel { ...@@ -27,15 +27,9 @@ class BCELossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ( OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELoss");
ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELoss");
platform::errors::InvalidArgument("Input(X) should be not null.")); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BCELoss");
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"), true,
platform::errors::InvalidArgument("Input(Label) should be not null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument("Output(Out) should be not null."));
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
...@@ -74,18 +68,12 @@ class BCELossGradOp : public framework::OperatorWithKernel { ...@@ -74,18 +68,12 @@ class BCELossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ( OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BCELossGrad");
ctx->HasInput("X"), true, OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "BCELossGrad");
platform::errors::InvalidArgument("Input(X) should be not null.")); OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
PADDLE_ENFORCE_EQ( framework::GradVarName("Out"), "BCELossGrad");
ctx->HasInput("Label"), true, OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
platform::errors::InvalidArgument("Input(Label) should be not null.")); framework::GradVarName("X"), "BCELossGrad");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) shoudl be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(X@GRAD) should be not null."));
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
...@@ -152,7 +140,6 @@ class BCELossGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -152,7 +140,6 @@ class BCELossGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Label", this->Input("Label")); op->SetInput("Label", this->Input("Label"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
// op->SetAttrMap(this->Attrs());
} }
}; };
......
...@@ -131,5 +131,15 @@ class TestBceLossOp(OpTest): ...@@ -131,5 +131,15 @@ class TestBceLossOp(OpTest):
self.shape = [10, 10] self.shape = [10, 10]
class TestBceLossOpCase1(OpTest):
def init_test_cast(self):
self.shape = [2, 3, 4, 5]
class TestBceLossOpCase2(OpTest):
def init_test_cast(self):
self.shape = [2, 3, 20]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册