From 6d60352e7e5d4a01a61de395fc87438cf814b5c7 Mon Sep 17 00:00:00 2001
From: Xinghai Sun
Date: Wed, 13 Sep 2017 22:28:29 +0800
Subject: [PATCH] Add soft-label support for cross-entropy operator.

---
 paddle/operators/cross_entropy_op.cc          |  64 ++++++----
 paddle/operators/cross_entropy_op.cu          | 119 ++++++++++++------
 paddle/operators/cross_entropy_op.h           |  92 +++++++++-----
 paddle/pybind/pybind.cc                       |   2 +-
 .../framework/tests/test_cross_entropy_op.py  |  25 +++-
 .../paddle/v2/framework/tests/test_mnist.py   |   2 +-
 6 files changed, 205 insertions(+), 99 deletions(-)

diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index ab1e1c101a1..32ad0e82fa4 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -17,48 +17,62 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
+class CrossEntropyOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto *X = ctx.Input<Tensor>("X");
-    auto *label = ctx.Input<Tensor>("label");
+    auto *x = ctx.Input<Tensor>("X");
+    auto *label = ctx.Input<Tensor>("Label");
 
-    PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2.");
-    PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1.");
-    PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]);
-    ctx.Output<Tensor>("Y")->Resize({X->dims()[0]});
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "X's rank must be 2.");
+    PADDLE_ASSERT(label->dims().size() == 1 || label->dims().size() == 2);
+    if (label->dims().size() == 2) {
+      // soft cross entropy
+      PADDLE_ENFORCE_EQ(x->dims(), label->dims());
+    } else {
+      // normal cross entropy
+      PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0]);
+    }
+    ctx.Output<Tensor>("Y")->Resize({x->dims()[0]});
   }
 };
 
-class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
+class CrossEntropyGradientOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto X = ctx.Input<Tensor>("X");
+    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto x = ctx.Input<Tensor>("X");
 
-    dX->Resize(X->dims());
+    dx->Resize(x->dims());
   }
 };
 
-class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
+class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  OnehotCrossEntropyOpMaker(framework::OpProto *proto,
-                            framework::OpAttrChecker *op_checker)
+  CrossEntropyOpMaker(framework::OpProto *proto,
+                      framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The first input of OnehotCrossEntropyOp");
-    AddInput("label", "The second input of OnehotCrossEntropyOp");
-    AddOutput("Y", "The output of OnehotCrossEntropyOp");
+    AddInput("X", "The first input of CrossEntropyOp");
+    AddInput("Label", "The second input of CrossEntropyOp");
+    AddOutput("Y", "The output of CrossEntropyOp");
     AddComment(R"DOC(
-OnehotCrossEntropy Operator.
+CrossEntropy Operator.
 
-                Y[i] = -log(X[i][j])
+The second input (Label tensor) supports two kinds of shapes:
+1) Rank(Label) = 1, Label[i] indicates the class index for sample i:
+       Y[i] = -log(X[i, Label[i]])
+2) Rank(Label) = 2, Label[i, j] indicates the soft label of class j
+   for sample i:
+       Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
+   Please make sure that in this case the summation of each row of Label
+   equals one. If each row of Label has only one non-zero element (equals 1),
+   it degenerates to a standard one-hot representation.
 )DOC");
   }
 };
 
@@ -66,10 +80,8 @@ OnehotCrossEntropy Operator.
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
-            ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
-            ops::OnehotCrossEntropyGradientOp);
-REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
-                       ops::OnehotCrossEntropyOpKernel<float>);
-REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
-                       ops::OnehotCrossEntropyGradientOpKernel<float>);
+REGISTER_OP(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
+            cross_entropy_grad, ops::CrossEntropyGradientOp);
+REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
+                       ops::CrossEntropyGradientOpKernel<float>);

diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index d999bfce58c..1f5e9c1b04e 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -21,17 +21,16 @@ namespace operators {
 using Tensor = framework::Tensor;
 
 template <typename T>
-__host__ __device__ T clipping_log(const T x) {
+__host__ __device__ T tolerable_value(const T x) {
   PADDLE_ASSERT(std::is_floating_point<T>::value);
   const T kApproInf = 1e20;
-  T v = log(x);
-  if (v == INFINITY) {
+  if (x == INFINITY) {
     return kApproInf;
   }
-  if (v == -INFINITY) {
+  if (x == -INFINITY) {
     return -kApproInf;
   }
-  return v;
+  return x;
 }
 
 template <typename T>
@@ -42,7 +41,20 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -clipping_log(X[i * D + label[i]]);
+    Y[i] = -tolerable_value(log(X[i * D + label[i]]));
+  }
+}
+
+template <typename T>
+__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
+                                       const int N, const int D) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    T sum = static_cast<T>(0);
+    for (int j = 0; j < D; j++) {
+      sum += label[i * D + j] * log(X[i * D + j]);
+    }
+    Y[i] = -tolerable_value(sum);
+  }
 }
 
@@ -69,57 +81,89 @@ __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
 }
 
 template <typename T>
-class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
+__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
+                                               const T* label, const int N,
+                                               const int D) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    for (int j = 0; j < D; ++j) {
+      int idx = i * D + j;
+      dX[idx] = -label[idx] * dY[i] / X[idx];
+    }
+  }
+}
+
+template <typename T>
+class CrossEntropyOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "It must use GPUPlace.");
 
-    auto X = ctx.Input<Tensor>("X");
-    const T* Xdata = X->data<T>();
-    const int* label_data = ctx.Input<Tensor>("label")->data<int>();
-    auto Y = ctx.Output<Tensor>("Y");
-    Y->mutable_data<T>(ctx.GetPlace());
-    T* Ydata = Y->data<T>();
+    auto x = ctx.Input<Tensor>("X");
ctx.Input("X"); + auto y = ctx.Output("Y"); + auto label = ctx.Input("Label"); + + auto* x_data = x->data(); + y->mutable_data(ctx.GetPlace()); + auto* y_data = y->data(); - int N = X->dims()[0]; - int D = X->dims()[1]; + int n = x->dims()[0]; + int d = x->dims()[1]; int block = 512; - int grid = (N + block - 1) / block; + int grid = (n + block - 1) / block; // TODO(qingqing) launch kernel on specified stream // base on ExecutionContext. - CrossEntropyKernel<<>>(Ydata, Xdata, label_data, N, D); + int label_rank = label->dims().size(); + if (label_rank == 2) { + // soft cross entropy + auto* label_data = ctx.Input("Label")->data(); + SoftCrossEntropyKernel<<>>(y_data, x_data, label_data, n, + d); + } else { + // normal cross entropy + auto* label_data = ctx.Input("Label")->data(); + CrossEntropyKernel<<>>(y_data, x_data, label_data, n, d); + } } }; template -class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel { +class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "It must use GPUPlace."); - auto X = ctx.Input("X"); - auto dX = ctx.Output(framework::GradVarName("X")); - auto dY = ctx.Input(framework::GradVarName("Y")); - auto label = ctx.Input("label"); + auto x = ctx.Input("X"); + auto dx = ctx.Output(framework::GradVarName("X")); + auto dy = ctx.Input(framework::GradVarName("Y")); + auto label = ctx.Input("Label"); - auto* dXdata = dX->template mutable_data(ctx.GetPlace()); - auto* dYdata = dY->template data(); - auto* Xdata = X->template data(); - auto* label_data = label->data(); + auto* dx_data = dx->mutable_data(ctx.GetPlace()); + auto* dy_data = dy->data(); + auto* x_data = x->data(); - int N = X->dims()[0]; - int D = X->dims()[1]; + int n = x->dims()[0]; + int d = x->dims()[1]; int block = 512; - int grid = (N * D + block - 1) / block; - zero<<>>(dXdata, N * D); - - grid = (N + block - 1) / block; + int grid = (n * d + block - 1) / block; + zero<<>>(dx_data, n * d); + grid = (n + block - 1) / block; // TODO(qingqing): launch kernel on specified stream // base on ExecutionContext. 
-    CrossEntropyGradientKernel<T><<<grid, block>>>(dXdata, dYdata, Xdata,
-                                                   label_data, N, D);
+    int label_rank = label->dims().size();
+    if (label_rank == 2) {
+      // soft cross entropy
+      auto* label_data = label->data<T>();
+      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
+          dx_data, dy_data, x_data, label_data, n, d);
+    } else {
+      // normal cross entropy
+      auto* label_data = label->data<int>();
+      CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
+                                                     label_data, n, d);
+    }
   }
 };
 
@@ -127,7 +171,6 @@ class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
-                       ops::OnehotCrossEntropyOpCUDAKernel<float>);
-REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad,
-                       ops::OnehotCrossEntropyGradientOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
+                       ops::CrossEntropyGradientOpCUDAKernel<float>);

diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index eb4d1348de1..9a661cb9cf2 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -40,56 +40,86 @@ inline T tolerable_value(const T x) {
 }
 
 template <typename T>
-class OnehotCrossEntropyOpKernel : public framework::OpKernel {
+class CrossEntropyOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
 
-    auto X = ctx.Input<Tensor>("X");
-    const T* Xdata = X->data<T>();
-    const int* label_data = ctx.Input<Tensor>("label")->data<int>();
-    auto Y = ctx.Output<Tensor>("Y");
-
-    Y->mutable_data<T>(ctx.GetPlace());
-
-    T* Ydata = Y->data<T>();
-
-    int batch_size = X->dims()[0];
-    int class_num = X->dims()[1];
-
-    for (int i = 0; i < batch_size; ++i) {
-      int index = i * class_num + label_data[i];
-      Ydata[i] = -tolerable_value(std::log(Xdata[index]));
+    auto x = ctx.Input<Tensor>("X");
+    auto y = ctx.Output<Tensor>("Y");
+
+    auto* x_data = x->data<T>();
+    y->mutable_data<T>(ctx.GetPlace());
+    auto* y_data = y->data<T>();
+
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
+    int label_rank = ctx.Input<Tensor>("Label")->dims().size();
+
+    if (label_rank == 2) {
+      // soft cross entropy
+      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
+      int index = 0;
+      for (int i = 0; i < batch_size; ++i) {
+        T sum = static_cast<T>(0);
+        for (int j = 0; j < class_num; ++j) {
+          sum += label_data[index] * std::log(x_data[index]);
+          index++;
+        }
+        y_data[i] = -tolerable_value(sum);
+      }
+    } else {
+      // normal cross entropy
+      auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
+      for (int i = 0; i < batch_size; ++i) {
+        int index = i * class_num + label_data[i];
+        y_data[i] = -tolerable_value(std::log(x_data[index]));
+      }
     }
   }
 };
 
 template <typename T>
-class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
+class CrossEntropyGradientOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
 
-    auto X = ctx.Input<Tensor>("X");
-    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto label = ctx.Input<Tensor>("label");
+    auto x = ctx.Input<Tensor>("X");
+    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto label = ctx.Input<Tensor>("Label");
 
-    auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
-    auto* dYdata = dY->template data<T>();
-    auto* Xdata = X->template data<T>();
-    auto* label_data = label->data<int>();
+    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    auto* dy_data = dy->data<T>();
+    auto* x_data = x->data<T>();
 
-    const int batch_size = X->dims()[0];
-    const int class_num = X->dims()[1];
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
+    int label_rank = ctx.Input<Tensor>("Label")->dims().size();
 
     // TODO(qingqing): make zero setting an common function.
-    memset(dXdata, 0, sizeof(T) * batch_size * class_num);
-    for (int i = 0; i < batch_size; ++i) {
-      int index = i * class_num + label_data[i];
-      dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
+    if (label_rank == 2) {
+      // soft cross entropy
+      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
+      int index = 0;
+      for (int i = 0; i < batch_size; ++i) {
+        for (int j = 0; j < class_num; ++j) {
+          dx_data[index] = -label_data[index] * dy_data[i] / x_data[index];
+          index++;
+        }
+      }
+    } else {
+      // normal cross entropy
+      auto* label_data = label->data<int>();
+      memset(dx_data, 0, sizeof(T) * batch_size * class_num);
+      for (int i = 0; i < batch_size; ++i) {
+        PADDLE_ASSERT(label_data[i] >= 0 && label_data[i] < class_num);
+        int index = i * class_num + label_data[i];
+        dx_data[index] = -dy_data[i] / x_data[index];
+      }
     }
   }
 };

diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 16a2368aae5..13e11fe82a3 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -32,7 +32,7 @@ limitations under the License. */
 namespace py = pybind11;
 
 USE_OP(add);
-USE_OP(onehot_cross_entropy);
+USE_OP(cross_entropy);
 USE_OP(sgd);
 USE_OP(mul);
 USE_OP(mean);

diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index c2fc102a8b8..b845bbc6809 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -5,13 +5,13 @@ from op_test import OpTest
 
 class TestCrossEntropy(OpTest):
     def setUp(self):
-        self.op_type = "onehot_cross_entropy"
+        self.op_type = "cross_entropy"
         batch_size = 30
         class_num = 10
         X = numpy.random.uniform(0.1, 1.0,
                                  [batch_size, class_num]).astype("float32")
         label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
-        self.inputs = {'X': X, 'label': label}
+        self.inputs = {'X': X, 'Label': label}
         Y = []
         for i in range(0, batch_size):
             Y.append(-numpy.log(X[i][label[i]]))
@@ -24,5 +24,26 @@ class TestCrossEntropy(OpTest):
         self.check_grad(['X'], 'Y')
 
 
+class TestCrossEntropySoftLabel(OpTest):
+    def setUp(self):
+        self.op_type = "cross_entropy"
+        batch_size = 30
+        class_num = 10
+        X = numpy.random.uniform(0.1, 1.0,
+                                 [batch_size, class_num]).astype("float32")
+        label = numpy.random.uniform(0.1, 1.0,
+                                     [batch_size, class_num]).astype("float32")
+        label /= label.sum(axis=1, keepdims=True)
+        self.inputs = {'X': X, 'Label': label}
+        Y = (-label * numpy.log(X)).sum(axis=1)
+        self.outputs = {'Y': numpy.array(Y).astype("float32")}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.05)
+
+
 if __name__ == "__main__":
     unittest.main()

diff --git a/python/paddle/v2/framework/tests/test_mnist.py b/python/paddle/v2/framework/tests/test_mnist.py
index f6f8f49b797..10f2810ad0b 100644
--- a/python/paddle/v2/framework/tests/test_mnist.py
+++ b/python/paddle/v2/framework/tests/test_mnist.py
@@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
 
 def cross_entropy_layer(net, input, label):
     cost_name = "cross_entropy_%d" % uniq_id()
     cross_entropy_op = Operator(
-        "onehot_cross_entropy", X=input, label=label, Y=cost_name)
+        "cross_entropy", X=input, Label=label, Y=cost_name)
     net.append_op(cross_entropy_op)
     scope.new_var(cost_name)
     net.infer_shape(scope)
--
GitLab
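
Note for reviewers (not part of the patch): the soft-label path added above computes Y[i] = -sum_j Label[i, j] * log(X[i, j]) in the forward pass and dX[i, j] = -Label[i, j] * dY[i] / X[i, j] in the backward pass. The standalone NumPy sketch below mirrors TestCrossEntropySoftLabel and can be used to sanity-check the kernels; the function and variable names are illustrative only.

import numpy


def soft_cross_entropy(x, label):
    # Forward: Y[i] = -sum_j Label[i, j] * log(X[i, j])
    return (-label * numpy.log(x)).sum(axis=1)


def soft_cross_entropy_grad(x, label, dy):
    # Backward: dX[i, j] = -Label[i, j] * dY[i] / X[i, j]
    return -label * dy[:, numpy.newaxis] / x


if __name__ == "__main__":
    batch_size, class_num = 30, 10
    x = numpy.random.uniform(0.1, 1.0,
                             [batch_size, class_num]).astype("float32")
    label = numpy.random.uniform(0.1, 1.0,
                                 [batch_size, class_num]).astype("float32")
    label /= label.sum(axis=1, keepdims=True)  # each row sums to one
    y = soft_cross_entropy(x, label)
    dx = soft_cross_entropy_grad(x, label, numpy.ones(batch_size, "float32"))
    print(y.shape, dx.shape)  # (30,), (30, 10)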