Commit 0dce16a6 authored by dangqingqing

Use bool type for attr in cross_entropy_op.

Parent 58e3ad0a
@@ -35,19 +35,16 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
                       "Input(Label)'s rank must be 2.");
-    // TODO(xinghai-sun): remove this check after switching to bool
-    PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
-                   ctx.Attr<int>("soft_label") == 1);
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
                       "The 1st dimension of Input(X) and Input(Label) must "
                       "be equal.");
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == 1, The 2nd dimension of "
+                        "If Attr(soft_label) == true, The 2nd dimension of "
                         "Input(X) and Input(Label) must be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == 0, The 2nd dimension of "
+                        "If Attr(soft_label) == false, The 2nd dimension of "
                         "Input(Label) must be 1.");
     }
@@ -74,9 +71,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
                       "Input(Label)'s rank must be 2.");
-    // TODO(xinghai-sun): remove this check after switching to bool
-    PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
-                   ctx.Attr<int>("soft_label") == 1);
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
                       "The 1st dimension of Input(X) and Input(Label) must "
                       "be equal.");
@@ -85,13 +79,13 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
                       "be equal.");
     PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
                       "The 2nd dimension of Input(Y@Grad) must be 1.");
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == 1, The 2nd dimension of "
+                        "If Attr(soft_label) == true, The 2nd dimension of "
                         "Input(X) and Input(Label) must be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == 0, The 2nd dimension of "
+                        "If Attr(soft_label) == false, The 2nd dimension of "
                         "Input(Label) must be 1.");
     }
@@ -108,7 +102,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The first input of CrossEntropyOp");
     AddInput("Label", "The second input of CrossEntropyOp");
     AddOutput("Y", "The output of CrossEntropyOp");
-    AddAttr<int>("soft_label", "Is soft label. Default zero.").SetDefault(0);
+    AddAttr<bool>("soft_label", "Is soft label. Default zero.")
+        .SetDefault(false);
     AddComment(R"DOC(
 CrossEntropy Operator.
@@ -116,12 +111,12 @@ CrossEntropy Operator.
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.

 1) One-hot cross-entropy:
-    soft_label = 0, Label[i, 0] indicates the class index for sample i:
+    soft_label = False, Label[i, 0] indicates the class index for sample i:
                 Y[i] = -log(X[i, Label[i]])

 2) Soft-label cross-entropy:
-    soft_label = 1, Label[i, j] indicates the soft label of class j
+    soft_label = True, Label[i, j] indicates the soft label of class j
    for sample i:
                 Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
......
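As a quick sanity check of the two formulas in the DOC block above, here is a small NumPy sketch (mine, not part of the commit) that computes both loss variants exactly as documented:

import numpy as np

batch_size, class_num = 4, 5

# Rows of X act as predicted probability distributions.
X = np.random.uniform(0.1, 1.0, (batch_size, class_num)).astype("float32")
X /= X.sum(axis=1, keepdims=True)

# 1) One-hot: soft_label = False, Label[i, 0] is a class index.
#    Y[i] = -log(X[i, Label[i]])
label = np.random.randint(0, class_num, (batch_size, 1))
y_hard = -np.log(X[np.arange(batch_size), label.flatten()])

# 2) Soft label: soft_label = True, each row of Label is a distribution.
#    Y[i] = sum_j(-Label[i, j] * log(X[i, j]))
soft = np.random.uniform(0.1, 1.0, (batch_size, class_num))
soft /= soft.sum(axis=1, keepdims=True)
y_soft = (-soft * np.log(X)).sum(axis=1)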
@@ -102,7 +102,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
     int grid = (n + block - 1) / block;
     // TODO(qingqing) launch kernel on specified stream
     // based on ExecutionContext.
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
       SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
                                                  d);
@@ -137,7 +137,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
     grid = (n + block - 1) / block;
     // TODO(qingqing): launch kernel on specified stream
     // based on ExecutionContext.
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = label->data<T>();
       SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
           dx_data, dy_data, x_data, label_data, n, d);
......
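Both CUDA kernels use the usual ceiling-division idiom for the launch configuration: grid = (n + block - 1) / block rounds up so that grid * block >= n and every element gets a thread. A tiny illustrative check of the arithmetic (Python here only for brevity):

n, block = 1000, 512
grid = (n + block - 1) // block  # integer ceil(n / block)
assert grid * block >= n         # enough threads to cover all n elements
assert (grid - 1) * block < n    # and no whole block to spare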
@@ -51,7 +51,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
     int batch_size = x->dims()[0];
     int class_num = x->dims()[1];
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
       int index = 0;
       for (int i = 0; i < batch_size; ++i) {
@@ -92,7 +92,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
     int class_num = x->dims()[1];
     // TODO(qingqing): make zero setting a common function.
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
       int index = 0;
       for (int i = 0; i < batch_size; ++i) {
......
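For the soft-label branch of the gradient kernel, the backward rule follows directly from Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}: dX[i, j] = -dY[i] * Label[i, j] / X[i, j]. A NumPy sketch of that rule with a finite-difference check (the function name is mine, not from the repo):

import numpy as np

def soft_cross_entropy_grad(X, label, dY):
    # dX[i, j] = dY[i] * d(-sum_k label[i, k] * log(X[i, k])) / dX[i, j]
    #          = -dY[i] * label[i, j] / X[i, j]
    return -dY * label / X

X = np.full((1, 3), 1.0 / 3.0)
label = np.array([[0.2, 0.3, 0.5]])
dY = np.ones((1, 1))
grad = soft_cross_entropy_grad(X, label, dY)

# Numerical check on one entry: perturb X[0, 0] and re-evaluate the loss.
loss = lambda X: (-label * np.log(X)).sum(axis=1)
eps = 1e-6
Xp = X.copy()
Xp[0, 0] += eps
numeric = (loss(Xp) - loss(X)) / eps
assert np.allclose(numeric[0], grad[0, 0], atol=1e-3)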
@@ -19,7 +19,7 @@ class TestCrossEntropyOp1(OpTest):
             dtype="float32")
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
-        self.attrs = {'soft_label': 0}
+        self.attrs = {'soft_label': False}

     def test_check_output(self):
         self.check_output()
@@ -45,7 +45,7 @@ class TestCrossEntropyOp2(OpTest):
             axis=1, keepdims=True).astype("float32")
         self.inputs = {'X': X, 'Label': label}
         self.outputs = {'Y': cross_entropy}
-        self.attrs = {'soft_label': 1}
+        self.attrs = {'soft_label': True}

     def test_check_output(self):
         self.check_output()
@@ -76,7 +76,7 @@ class TestCrossEntropyOp3(OpTest):
             axis=1, keepdims=True).astype("float32")
         self.inputs = {'X': X, 'Label': label}
         self.outputs = {'Y': cross_entropy}
-        self.attrs = {'soft_label': 1}
+        self.attrs = {'soft_label': True}

     def test_check_output(self):
         self.check_output()
......
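After this change, a test passes a Python boolean straight through as the attribute value. A minimal sketch of a soft-label test case, patterned on the three tests above (the class name is mine, and the setUp skeleton assumes the usual OpTest convention; neither is copied from the commit):

import numpy as np
from op_test import OpTest  # the harness the tests above subclass

class TestCrossEntropySoftLabel(OpTest):
    def setUp(self):
        self.op_type = "cross_entropy"
        batch_size, class_num = 5, 10
        X = np.random.uniform(0.1, 1.0,
                              [batch_size, class_num]).astype("float32")
        label = np.random.uniform(0.1, 1.0,
                                  [batch_size, class_num]).astype("float32")
        label /= label.sum(axis=1, keepdims=True)  # rows sum to 1
        cross_entropy = (-label * np.log(X)).sum(
            axis=1, keepdims=True).astype("float32")
        self.inputs = {'X': X, 'Label': label}
        self.outputs = {'Y': cross_entropy}
        self.attrs = {'soft_label': True}  # a bool now, not 0/1

    def test_check_output(self):
        self.check_output()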