From 325ee63746917a3fe8268d9934e556b9fae5339e Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 26 Sep 2017 22:29:11 -0700 Subject: [PATCH] fix SoftmaxWithCrossEntropyOp --- paddle/operators/math/softmax.cc | 2 +- .../softmax_with_cross_entropy_op.cc | 77 +++++++++---------- 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc index 1224c058105..ac9f3c4bf61 100644 --- a/paddle/operators/math/softmax.cc +++ b/paddle/operators/math/softmax.cc @@ -18,7 +18,7 @@ namespace paddle { namespace operators { namespace math { -template class SoftmaxFunctor; +template class SoftmaxFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index b6f33ad9e03..e2299b25445 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -82,40 +82,38 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"), - "Input(Logits) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) should be not null."); - - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"), - "Output(Softmax) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"), - "Output(Loss) should be not null."); - - const Tensor* logits = ctx.Input("Logits"); - const Tensor* labels = ctx.Input("Label"); + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Logits"), + "Input(Logits) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + + PADDLE_ENFORCE(ctx->HasOutput("Softmax"), + "Output(Softmax) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null."); + + auto logits_dims = ctx->GetInputDim("Logits"); + auto labels_dims = ctx->GetInputDim("Label"); PADDLE_ENFORCE_EQ( - logits->dims().size(), 2UL, + logits_dims.size(), 2UL, "The input of softmax_with_cross_entropy should be a 2-D tensor."); - PADDLE_ENFORCE_EQ(ctx.Input("Label")->dims().size(), 2UL, + PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL, "The labels should be a 2-D tensor."); - if (ctx.Attr("softLabel")) { - PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1], + if (ctx->Attrs().Get("softLabel")) { + PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1], "If Attr(softLabel) == true, the 2nd dimension of " "Input(X) and Input(Label) should be equal."); } else { - PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL, + PADDLE_ENFORCE_EQ(labels_dims[1], 1UL, "If Attr(softLabel) == false, the 2nd dimension of " "Input(Label) should be 1."); } - ctx.Output("Softmax")->Resize(logits->dims()); - ctx.Output("Loss")->Resize({logits->dims()[0], 1}); + ctx->SetOutputDim("Softmax", logits_dims); + ctx->SetOutputDim("Loss", {logits_dims[0], 1}); - ctx.ShareLoD("Logits", /*->*/ "Softmax"); - ctx.ShareLoD("Logits", /*->*/ "Loss"); + ctx->ShareLoD("Logits", /*->*/ "Softmax"); + ctx->ShareLoD("Logits", /*->*/ "Loss"); } }; @@ -124,33 +122,32 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")), - "Input(Loss@Grad) should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"), - "Input(Softmax) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) should be not null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")), - "Output(Logits@Grad) should be not null."); - - const Tensor* softmax = ctx.Input("Softmax"); - const Tensor* labels = ctx.Input("Label"); - PADDLE_ENFORCE_EQ(ctx.Input("Label")->dims().size(), 2UL, + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), + "Input(Loss@Grad) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Softmax"), + "Input(Softmax) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")), + "Output(Logits@Grad) should be not null."); + + auto softmax_dims = ctx->GetInputDim("Softmax"); + auto labels_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL, "The labels should be a 2-D tensor."); - if (ctx.Attr("softLabel")) { - PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1], + if (ctx->Attrs().Get("softLabel")) { + PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1], "When Attr(softLabel) == true, the 2nd dimension of " "Input(X) and Input(Label) should be equal."); } else { - PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL, + PADDLE_ENFORCE_EQ(labels_dims[1], 1UL, "When Attr(softLabel) == false, the 2nd dimension of " "Input(Label) should be 1."); } - ctx.Output(framework::GradVarName("Logits")) - ->Resize(ctx.Input("Softmax")->dims()); + ctx->SetOutputDim(framework::GradVarName("Logits"), + ctx->GetInputDim("Softmax")); } }; -- GitLab