diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index 61d2104b95c9ce826dbd29bc5b4f241aa2d9761e..953367eb8bcd1282ab6c7e1189d778f0ce3da541 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -25,25 +25,32 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of CrossEntropyOp must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) of CrossEntropyOp must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
-                            "Output(Y) of CrossEntropyOp must not be null.");
-
-    auto *x = ctx.Input<Tensor>("X");
-    auto *label = ctx.Input<Tensor>("Label");
-
-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "X's rank must be 2.");
-    PADDLE_ASSERT(label->dims().size() == 1 || label->dims().size() == 2);
-    if (label->dims().size() == 2) {
-      // soft cross entropy
-      PADDLE_ENFORCE_EQ(x->dims(), label->dims());
+                            "Input(Label) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
+
+    auto x = ctx.Input<Tensor>("X");
+    auto label = ctx.Input<Tensor>("Label");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
+                      "Input(Label)'s rank must be 2.");
+    // TODO(xinghai-sun): remove this check after switching to bool.
+    PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
+                   ctx.Attr<int>("soft_label") == 1);
+    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
+                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "be equal.");
+    if (ctx.Attr<int>("soft_label") == 1) {
+      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
+                        "If Attr(soft_label) == 1, the 2nd dimension of "
+                        "Input(X) and Input(Label) must be equal.");
     } else {
-      // normal cross entropy
-      PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0]);
+      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
+                        "If Attr(soft_label) == 0, the 2nd dimension of "
+                        "Input(Label) must be 1.");
     }
+    ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
   }
 };
 
@@ -54,12 +61,41 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of CrossEntropyOp must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
+                            "Input(Y@GRAD) must not be null.");
 
-    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto x = ctx.Input<Tensor>("X");
+    auto label = ctx.Input<Tensor>("Label");
+    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
+                      "Input(Label)'s rank must be 2.");
+    // TODO(xinghai-sun): remove this check after switching to bool.
+    PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
+                   ctx.Attr<int>("soft_label") == 1);
+    PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
+                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "be equal.");
+    PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
+                      "The 1st dimension of Input(X) and Input(Y@Grad) must "
+                      "be equal.");
+    PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
+                      "The 2nd dimension of Input(Y@Grad) must be 1.");
+    if (ctx.Attr<int>("soft_label") == 1) {
+      PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
+                        "If Attr(soft_label) == 1, the 2nd dimension of "
+                        "Input(X) and Input(Label) must be equal.");
+    } else {
+      PADDLE_ENFORCE_EQ(label->dims()[1], 1,
+                        "If Attr(soft_label) == 0, the 2nd dimension of "
+                        "Input(Label) must be 1.");
+    }
+    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
 
     dx->Resize(x->dims());
   }
 };
@@ -72,22 +108,31 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The first input of CrossEntropyOp");
     AddInput("Label", "The second input of CrossEntropyOp");
     AddOutput("Y", "The output of CrossEntropyOp");
+    AddAttr<int>("soft_label", "Is soft label. Default zero.").SetDefault(0);
+
     AddComment(R"DOC(
 CrossEntropy Operator.
 
-The second input (Label tensor) supports two kinds of shapes:
-1) Rank(Label) = 1, Label[i] indicates the class index for sample i:
+It supports both standard cross-entropy and soft-label cross-entropy loss
+computation.
+1) One-hot cross-entropy:
+    soft_label = 0, Label[i, 0] indicates the class index for sample i:
 
                 Y[i] = -log(X[i, Label[i]])
 
-2) Rank(Label) = 2, Label[i, j] indicates the soft label of class j
-   for sample i:
+2) Soft-label cross-entropy:
+    soft_label = 1, Label[i, j] indicates the soft label of class j
+    for sample i:
 
                 Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
 
    Please make sure that in this case the summation of each row of Label
-   equals one. If each row of Label has only one non-zero element (equals 1),
-   it degenerates to a standard one-hot representation.
+   equals one.
+
+3) One-hot cross-entropy with vectorized Input(Label):
+    As a special case of 2), when each row of Input(Label) has only one
+    non-zero element (equals 1), soft-label cross-entropy degenerates to a
+    one-hot cross-entropy with one-hot label representation.
 )DOC");
   }
 };
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index e80dcec8e2511242b73290147820eb88fa5aae06..ab6ad0e062269483948bf70e492c9431991221fb 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -13,27 +13,13 @@
    limitations under the License. */
*/ #include "paddle/framework/op_registry.h" +#include "paddle/operators/cross_entropy_op.h" #include "paddle/platform/assert.h" #include "paddle/platform/hostdevice.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; - -template -HOSTDEVICE T tolerable_value(const T x) { - PADDLE_ASSERT(std::is_floating_point::value); - const T kApproInf = 1e20; - if (x == INFINITY) { - return kApproInf; - } - if (x == -INFINITY) { - return -kApproInf; - } - return x; -} - template __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, const int N, const int D) { @@ -53,9 +39,9 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, i += blockDim.x * gridDim.x) { T sum = static_cast(0); for (int j = 0; j < D; j++) { - sum += label[i * D + j] * log(X[i * D + j]); + sum += label[i * D + j] * tolerable_value(log(X[i * D + j])); } - Y[i] = -tolerable_value(sum); + Y[i] = -sum; } } @@ -85,6 +71,7 @@ template __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, const T* label, const int N, const int D) { + // TOOD(qingqing): optimize for this kernel for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { for (int j = 0; j < D; ++j) { @@ -115,14 +102,11 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { int grid = (n + block - 1) / block; // TODO(qingqing) launch kernel on specified stream // base on ExecutionContext. - int label_rank = label->dims().size(); - if (label_rank == 2) { - // soft cross entropy + if (ctx.Attr("soft_label") == 1) { auto* label_data = ctx.Input("Label")->data(); SoftCrossEntropyKernel<<>>(y_data, x_data, label_data, n, d); } else { - // normal cross entropy auto* label_data = ctx.Input("Label")->data(); CrossEntropyKernel<<>>(y_data, x_data, label_data, n, d); } @@ -153,14 +137,11 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { grid = (n + block - 1) / block; // TODO(qingqing): launch kernel on specified stream // base on ExecutionContext. - int label_rank = label->dims().size(); - if (label_rank == 2) { - // soft cross entropy + if (ctx.Attr("soft_label") == 1) { auto* label_data = label->data(); SoftCrossEntropyGradientKernel<<>>( dx_data, dy_data, x_data, label_data, n, d); } else { - // normal cross entropy auto* label_data = label->data(); CrossEntropyGradientKernel<<>>(dx_data, dy_data, x_data, label_data, n, d); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 9a661cb9cf2d757723787f721f13b7b570f6e2d6..1b4b23ac2029138afadef0168262203ac2e20430 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -14,6 +14,7 @@ limitations under the License. 
    limitations under the License. */
 
 #pragma once
 #include "paddle/framework/op_registry.h"
+#include "paddle/platform/hostdevice.h"
 
 namespace paddle {
 namespace operators {
@@ -21,21 +22,15 @@ namespace operators {
 using Tensor = framework::Tensor;
 
 template <typename T>
-inline T tolerable_value(const T x) {
-  static_assert(std::is_floating_point<T>::value,
-                "tolerable_value works only on float, "
-                "double and double double.");
-
+HOSTDEVICE T tolerable_value(const T x) {
+  PADDLE_ASSERT(std::is_floating_point<T>::value);
   const T kApproInf = 1e20;
-
   if (x == INFINITY) {
     return kApproInf;
   }
-
   if (x == -INFINITY) {
     return -kApproInf;
   }
-
   return x;
 }
@@ -55,22 +50,19 @@ class CrossEntropyOpKernel : public framework::OpKernel {
     int batch_size = x->dims()[0];
     int class_num = x->dims()[1];
 
-    int label_rank = ctx.Input<Tensor>("Label")->dims().size();
-    if (label_rank == 2) {
-      // soft cross entropy
+    if (ctx.Attr<int>("soft_label") == 1) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
       int index = 0;
       for (int i = 0; i < batch_size; ++i) {
         T sum = static_cast<T>(0);
         for (int j = 0; j < class_num; ++j) {
-          sum += label_data[index] * std::log(x_data[index]);
-          y_data[i] = -tolerable_value(sum);
+          sum += label_data[index] * tolerable_value(std::log(x_data[index]));
+          y_data[i] = -sum;
           index++;
         }
       }
     } else {
-      // normal cross entropy
       auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
       for (int i = 0; i < batch_size; ++i) {
         int index = i * class_num + label_data[i];
@@ -98,11 +90,9 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
     int batch_size = x->dims()[0];
     int class_num = x->dims()[1];
 
-    int label_rank = ctx.Input<Tensor>("Label")->dims().size();
     // TODO(qingqing): make zero setting a common function.
-    if (label_rank == 2) {
-      // soft cross entropy
+    if (ctx.Attr<int>("soft_label") == 1) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
       int index = 0;
       for (int i = 0; i < batch_size; ++i) {
@@ -112,7 +102,6 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
         }
       }
     } else {
-      // normal cross entropy
       auto* label_data = label->data<int>();
       memset(dx_data, 0, sizeof(T) * batch_size * class_num);
       for (int i = 0; i < batch_size; ++i) {
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index ccff2a386d368b7e0b5f0f5a2a5029ccd48d3d42..0206ca064be87afe204aa99021979b7ddc3c5d63 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -1,23 +1,25 @@
 import unittest
-import numpy
+import numpy as np
 from op_test import OpTest
 
 
-class TestOnehotCrossEntropyOp(OpTest):
+class TestCrossEntropyOp1(OpTest):
+    """Test standard cross-entropy, with index representation of labels.
+ """ + def setUp(self): self.op_type = "cross_entropy" batch_size = 30 class_num = 10 - - X = numpy.random.uniform(0.1, 1.0, - [batch_size, class_num]).astype("float32") - labels = numpy.random.randint(0, class_num, batch_size, dtype="int32") - - cross_entropy = numpy.asmatrix( - [[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])], + X = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32") + cross_entropy = np.asmatrix( + [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])], dtype="float32") - self.inputs = {"X": X, "Label": labels} + self.inputs = {"X": X, "Label": label} self.outputs = {"Y": cross_entropy} + self.attrs = {'soft_label': 0} def test_check_output(self): self.check_output() @@ -26,20 +28,55 @@ class TestOnehotCrossEntropyOp(OpTest): self.check_grad(["X"], "Y") -class TestCrossEntropySoftLabel(OpTest): +class TestCrossEntropyOp2(OpTest): + """Test soft-label cross-entropy, with vecterized soft labels. + """ + def setUp(self): self.op_type = "cross_entropy" - batch_size = 30 - class_num = 10 - X = numpy.random.uniform(0.1, 1.0, - [batch_size, class_num]).astype("float32") - label = numpy.random.uniform(0.1, 1.0, - [batch_size, class_num]).astype("float32") + batch_size = 10 + class_num = 5 + X = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") label /= label.sum(axis=1, keepdims=True) + cross_entropy = (-label * np.log(X)).sum( + axis=1, keepdims=True).astype("float32") self.inputs = {'X': X, 'Label': label} - cross_entropy = (-label * numpy.log(X)).sum( + self.outputs = {'Y': cross_entropy} + self.attrs = {'soft_label': 1} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y') + + +class TestCrossEntropyOp3(OpTest): + """Test one-hot cross-entropy, with vecterized one-hot representation of + labels. + """ + + def setUp(self): + self.op_type = "cross_entropy" + batch_size = 30 + class_num = 10 + X = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float32") + label_index = np.random.randint( + 0, class_num, (batch_size), dtype="int32") + label = np.zeros(X.shape) + label[np.arange(batch_size), label_index] = 1 + cross_entropy = np.asmatrix( + [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])], + dtype="float32") + cross_entropy2 = (-label * np.log(X)).sum( axis=1, keepdims=True).astype("float32") + self.inputs = {'X': X, 'Label': label} self.outputs = {'Y': cross_entropy} + self.attrs = {'soft_label': 1} def test_check_output(self): self.check_output()