提交 40aee48a 编写于 作者: C caoying03

follow comments.

上级 3d77360b
......@@ -23,11 +23,6 @@ class SoftmaxWithCrossEntropyOpMaker
SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<bool>(
"softLabel",
"(bool, default: false), A flag to indicate whether to interpretate "
"the given labels as soft labels.")
.SetDefault(false);
AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
......@@ -49,6 +44,11 @@ class SoftmaxWithCrossEntropyOpMaker
AddOutput("Loss",
"(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
"entropy loss with shape [N x 1].");
AddAttr<bool>(
"softLabel",
"(bool, default: false), A flag to indicate whether to interpretate "
"the given labels as soft labels.")
.SetDefault(false);
AddComment(R"DOC(
Cross entropy loss with softmax are used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
......@@ -95,10 +95,10 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE(
logits->dims().size() == 2UL,
PADDLE_ENFORCE_EQ(
logits->dims().size(), 2UL,
"The input of softmax_with_cross_entropy should be a 2-D tensor.");
PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 2UL,
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
"The labels should be a 2-D tensor.");
if (ctx.Attr<bool>("softLabel")) {
......@@ -106,7 +106,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
"If Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(labels->dims()[1], 1,
PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
"If Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
......@@ -130,13 +130,13 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
"Input(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Lable) should be not null.");
"Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")),
"Output(Logits@Grad) should be not null.");
const Tensor* softmax = ctx.Input<Tensor>("Softmax");
const Tensor* labels = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 2UL,
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
"The labels should be a 2-D tensor.");
if (ctx.Attr<bool>("softLabel")) {
......@@ -144,7 +144,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
"When Attr(softLabel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(labels->dims()[1], 1,
PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
"When Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册