提交 325ee637 编写于 作者: Q qiaolongfei

fix SoftmaxWithCrossEntropyOp

上级 dcfd31d7
......@@ -18,7 +18,7 @@ namespace paddle {
namespace operators {
namespace math {
template class SoftmaxFunctor<platform::GPUPlace, float>;
template class SoftmaxFunctor<platform::CPUPlace, float>;
} // namespace math
} // namespace operators
......
......@@ -82,40 +82,38 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"),
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Logits"),
"Input(Logits) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"),
PADDLE_ENFORCE(ctx->HasOutput("Softmax"),
"Output(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"),
"Output(Loss) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(
logits->dims().size(), 2UL,
logits_dims.size(), 2UL,
"The input of softmax_with_cross_entropy should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
"The labels should be a 2-D tensor.");
if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1],
if (ctx->Attrs().Get<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
"If Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
"If Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
ctx.Output<framework::Tensor>("Softmax")->Resize(logits->dims());
ctx.Output<framework::Tensor>("Loss")->Resize({logits->dims()[0], 1});
ctx->SetOutputDim("Softmax", logits_dims);
ctx->SetOutputDim("Loss", {logits_dims[0], 1});
ctx.ShareLoD("Logits", /*->*/ "Softmax");
ctx.ShareLoD("Logits", /*->*/ "Loss");
ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss");
}
};
......@@ -124,33 +122,32 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
"Input(Loss@Grad) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
PADDLE_ENFORCE(ctx->HasInput("Softmax"),
"Input(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")),
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
"Output(Logits@Grad) should be not null.");
const Tensor* softmax = ctx.Input<Tensor>("Softmax");
const Tensor* labels = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
auto softmax_dims = ctx->GetInputDim("Softmax");
auto labels_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
"The labels should be a 2-D tensor.");
if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1],
if (ctx->Attrs().Get<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
"When Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
"When Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
->Resize(ctx.Input<Tensor>("Softmax")->dims());
ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Softmax"));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册