From 0dce16a697713b848d2afa5d8a5ee3b8108b150a Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Fri, 22 Sep 2017 10:20:02 +0800
Subject: [PATCH] Use bool type for attr in cross_entropy_op.

---
 paddle/operators/cross_entropy_op.cc          | 25 ++++++++-----------
 paddle/operators/cross_entropy_op.cu          |  4 +--
 paddle/operators/cross_entropy_op.h           |  4 +--
 .../framework/tests/test_cross_entropy_op.py  |  6 ++---
 4 files changed, 17 insertions(+), 22 deletions(-)

diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index 953367eb8b..559fc5a8d7 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -35,19 +35,16 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
                       "Input(Label)'s rank must be 2.");
-    // TODO(xinghai-sun): remove this check after swtiching to bool
-    PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
-                   ctx.Attr<int>("soft_label") == 1);
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
                       "The 1st dimension of Input(X) and Input(Label) must "
                       "be equal.");
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == 1, The 2nd dimension of "
+                        "If Attr(soft_label) == true, The 2nd dimension of "
                         "Input(X) and Input(Label) must be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == 0, The 2nd dimension of "
+                        "If Attr(soft_label) == false, The 2nd dimension of "
                         "Input(Label) must be 1.");
     }
 
@@ -74,9 +71,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
                       "Input(Label)'s rank must be 2.");
-    // TODO(xinghai-sun): remove this check after swtiching to bool
-    PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
-                   ctx.Attr<int>("soft_label") == 1);
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
                       "The 1st dimension of Input(X) and Input(Label) must "
                       "be equal.");
@@ -85,13 +79,13 @@
                       "be equal.");
     PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
                       "The 2nd dimension of Input(Y@Grad) must be 1.");
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == 1, The 2nd dimension of "
+                        "If Attr(soft_label) == true, The 2nd dimension of "
                         "Input(X) and Input(Label) must be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == 0, The 2nd dimension of "
+                        "If Attr(soft_label) == false, The 2nd dimension of "
                         "Input(Label) must be 1.");
     }
 
@@ -108,7 +102,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The first input of CrossEntropyOp");
     AddInput("Label", "The second input of CrossEntropyOp");
     AddOutput("Y", "The output of CrossEntropyOp");
-    AddAttr<int>("soft_label", "Is soft label. Default zero.").SetDefault(0);
+    AddAttr<bool>("soft_label", "Is soft label. Default zero.")
+        .SetDefault(false);
 
     AddComment(R"DOC(
 CrossEntropy Operator.
@@ -116,12 +111,12 @@ CrossEntropy Operator.
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.
 1) One-hot cross-entropy:
-    soft_label = 0, Label[i, 0] indicates the class index for sample i:
+    soft_label = False, Label[i, 0] indicates the class index for sample i:
 
                 Y[i] = -log(X[i, Label[i]])
 
 2) Soft-label cross-entropy:
-    soft_label = 1, Label[i, j] indicates the soft label of class j
+    soft_label = True, Label[i, j] indicates the soft label of class j
     for sample i:
 
                 Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index ab6ad0e062..1d6361a814 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -102,7 +102,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
     int grid = (n + block - 1) / block;
     // TODO(qingqing) launch kernel on specified stream
     // base on ExecutionContext.
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
       SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data,
                                                  n, d);
@@ -137,7 +137,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
     grid = (n + block - 1) / block;
     // TODO(qingqing): launch kernel on specified stream
     // base on ExecutionContext.
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
      auto* label_data = label->data<T>();
      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
          dx_data, dy_data, x_data, label_data, n, d);
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index 1b4b23ac20..69caba5ff3 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -51,7 +51,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
     int batch_size = x->dims()[0];
     int class_num = x->dims()[1];
 
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
       int index = 0;
       for (int i = 0; i < batch_size; ++i) {
@@ -92,7 +92,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
     int class_num = x->dims()[1];
 
     // TODO(qingqing): make zero setting an common function.
-    if (ctx.Attr<int>("soft_label") == 1) {
+    if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
       int index = 0;
       for (int i = 0; i < batch_size; ++i) {
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index 0206ca064b..f10db78322 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -19,7 +19,7 @@ class TestCrossEntropyOp1(OpTest):
             dtype="float32")
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
-        self.attrs = {'soft_label': 0}
+        self.attrs = {'soft_label': False}
 
     def test_check_output(self):
         self.check_output()
@@ -45,7 +45,7 @@ class TestCrossEntropyOp2(OpTest):
                                axis=1, keepdims=True).astype("float32")
         self.inputs = {'X': X, 'Label': label}
         self.outputs = {'Y': cross_entropy}
-        self.attrs = {'soft_label': 1}
+        self.attrs = {'soft_label': True}
 
     def test_check_output(self):
         self.check_output()
@@ -76,7 +76,7 @@ class TestCrossEntropyOp3(OpTest):
                                axis=1, keepdims=True).astype("float32")
         self.inputs = {'X': X, 'Label': label}
         self.outputs = {'Y': cross_entropy}
-        self.attrs = {'soft_label': 1}
+        self.attrs = {'soft_label': True}
 
     def test_check_output(self):
         self.check_output()
--
GitLab
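
For reference, the two loss modes documented in the DOC block above can be reproduced with a few lines of numpy, mirroring what the updated tests compute. This is a minimal illustrative sketch, not part of the patch; all variable names in it are local to the example.

import numpy as np

batch_size, class_num = 3, 4

# X plays the role of Input(X): each row is a normalized probability
# distribution over classes, as a softmax output would be.
X = np.random.uniform(0.1, 1.0, (batch_size, class_num)).astype("float32")
X /= X.sum(axis=1, keepdims=True)

# soft_label = False: Label has shape (batch_size, 1) and holds class
# indices, so Y[i] = -log(X[i, Label[i]]).
label = np.random.randint(0, class_num, (batch_size, 1)).astype("int32")
y_hard = np.asarray(
    [[-np.log(X[i][label[i][0]])] for i in range(batch_size)],
    dtype="float32")

# soft_label = True: Label has shape (batch_size, class_num) and each row
# is itself a distribution, so Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}.
label_soft = np.random.uniform(0.1, 1.0,
                               (batch_size, class_num)).astype("float32")
label_soft /= label_soft.sum(axis=1, keepdims=True)
y_soft = (-label_soft * np.log(X)).sum(axis=1, keepdims=True)

print(y_hard.shape, y_soft.shape)  # both (batch_size, 1), matching Output(Y)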