提交 325ee637 编写于 作者: Q qiaolongfei

fix SoftmaxWithCrossEntropyOp

上级 dcfd31d7
...@@ -18,7 +18,7 @@ namespace paddle { ...@@ -18,7 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template class SoftmaxFunctor<platform::GPUPlace, float>; template class SoftmaxFunctor<platform::CPUPlace, float>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -82,40 +82,38 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -82,40 +82,38 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"), PADDLE_ENFORCE(ctx->HasInput("Logits"),
"Input(Logits) should be not null."); "Input(Logits) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
"Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"), "Output(Softmax) should be not null.");
"Output(Softmax) should be not null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"),
"Output(Loss) should be not null."); auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
logits->dims().size(), 2UL, logits_dims.size(), 2UL,
"The input of softmax_with_cross_entropy should be a 2-D tensor."); "The input of softmax_with_cross_entropy should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL, PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
"The labels should be a 2-D tensor."); "The labels should be a 2-D tensor.");
if (ctx.Attr<bool>("softLabel")) { if (ctx->Attrs().Get<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1], PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
"If Attr(softLabel) == true, the 2nd dimension of " "If Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal."); "Input(X) and Input(Label) should be equal.");
} else { } else {
PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL, PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
"If Attr(softLabel) == false, the 2nd dimension of " "If Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
} }
ctx.Output<framework::Tensor>("Softmax")->Resize(logits->dims()); ctx->SetOutputDim("Softmax", logits_dims);
ctx.Output<framework::Tensor>("Loss")->Resize({logits->dims()[0], 1}); ctx->SetOutputDim("Loss", {logits_dims[0], 1});
ctx.ShareLoD("Logits", /*->*/ "Softmax"); ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx.ShareLoD("Logits", /*->*/ "Loss"); ctx->ShareLoD("Logits", /*->*/ "Loss");
} }
}; };
...@@ -124,33 +122,32 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -124,33 +122,32 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
"Input(Loss@Grad) should not be null."); "Input(Loss@Grad) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"), PADDLE_ENFORCE(ctx->HasInput("Softmax"),
"Input(Softmax) should be not null."); "Input(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
"Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")), "Output(Logits@Grad) should be not null.");
"Output(Logits@Grad) should be not null.");
auto softmax_dims = ctx->GetInputDim("Softmax");
const Tensor* softmax = ctx.Input<Tensor>("Softmax"); auto labels_dims = ctx->GetInputDim("Label");
const Tensor* labels = ctx.Input<Tensor>("Label"); PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
"The labels should be a 2-D tensor."); "The labels should be a 2-D tensor.");
if (ctx.Attr<bool>("softLabel")) { if (ctx->Attrs().Get<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1], PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
"When Attr(softLabel) == true, the 2nd dimension of " "When Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal."); "Input(X) and Input(Label) should be equal.");
} else { } else {
PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL, PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
"When Attr(softLabel) == false, the 2nd dimension of " "When Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1."); "Input(Label) should be 1.");
} }
ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits")) ctx->SetOutputDim(framework::GradVarName("Logits"),
->Resize(ctx.Input<Tensor>("Softmax")->dims()); ctx->GetInputDim("Softmax"));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册