提交 d8046da0 编写于 作者: X Xinghai Sun

Use soft_label attribute for cross-entropy.

上级 8e7fe8ca
...@@ -25,25 +25,32 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -25,25 +25,32 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
"Input(X) of CrossEntropyOp must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) of CrossEntropyOp must not be null."); "Input(Label) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
"Output(Y) of CrossEntropyOp must not be null.");
auto x = ctx.Input<Tensor>("X");
auto *x = ctx.Input<Tensor>("X"); auto label = ctx.Input<Tensor>("Label");
auto *label = ctx.Input<Tensor>("Label"); PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "X's rank must be 2."); "Input(Label)'s rank must be 2.");
PADDLE_ASSERT(label->dims().size() == 1 || label->dims().size() == 2); // TODO(xinghai-sun): remove this check after swtiching to bool
if (label->dims().size() == 2) { PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
// soft cross entropy ctx.Attr<int>("soft_label") == 1);
PADDLE_ENFORCE_EQ(x->dims(), label->dims()); PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"be equal.");
if (ctx.Attr<int>("soft_label") == 1) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == 1, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
} else { } else {
// normal cross entropy PADDLE_ENFORCE_EQ(label->dims()[1], 1,
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0]); "If Attr(soft_label) == 0, The 2nd dimension of "
"Input(Label) must be 1.");
} }
ctx.Output<LoDTensor>("Y")->Resize({x->dims()[0], 1}); ctx.Output<LoDTensor>("Y")->Resize({x->dims()[0], 1});
} }
}; };
...@@ -54,12 +61,41 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -54,12 +61,41 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
"Input(X) of CrossEntropyOp must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) must not be null.");
auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto x = ctx.Input<Tensor>("X"); auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2.");
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
ctx.Attr<int>("soft_label") == 1);
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"be equal.");
PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
"The 1st dimension of Input(X) and Input(Y@Grad) must "
"be equal.");
PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
"The 2nd dimension of Input(Y@Grad) must be 1.");
if (ctx.Attr<int>("soft_label") == 1) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == 1, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == 0, The 2nd dimension of "
"Input(Label) must be 1.");
}
auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
dx->Resize(x->dims()); dx->Resize(x->dims());
} }
}; };
...@@ -72,22 +108,31 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -72,22 +108,31 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of CrossEntropyOp"); AddInput("X", "The first input of CrossEntropyOp");
AddInput("Label", "The second input of CrossEntropyOp"); AddInput("Label", "The second input of CrossEntropyOp");
AddOutput("Y", "The output of CrossEntropyOp"); AddOutput("Y", "The output of CrossEntropyOp");
AddAttr<int>("soft_label", "Is soft label. Default zero.").SetDefault(0);
AddComment(R"DOC( AddComment(R"DOC(
CrossEntropy Operator. CrossEntropy Operator.
The second input (Label tensor) supports two kinds of shapes: It supports both standard cross-entropy and soft-label cross-entropy loss
1) Rank(Label) = 1, Label[i] indicates the class index for sample i: computation.
1) One-hot cross-entropy:
soft_label = 0, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]]) Y[i] = -log(X[i, Label[i]])
2) Rank(Label) = 2, Label[i, j] indicates the soft label of class j 2) Soft-label cross-entropy:
soft_label = 1, Label[i, j] indicates the soft label of class j
for sample i: for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
Please make sure that in this case the summuation of each row of Label Please make sure that in this case the summuation of each row of Label
equals one. If each row of Label has only one non-zero element (equals 1), equals one.
it degenerates to a standard one-hot representation.
3) One-hot cross-entropy with vecterized Input(Label):
As a special case of 2), when each row of Input(Label) has only one
non-zero element (equals 1), soft-label cross-entropy degenerates to a
one-hot cross-entropy with one-hot label representation.
)DOC"); )DOC");
} }
}; };
......
...@@ -13,27 +13,13 @@ ...@@ -13,27 +13,13 @@
limitations under the License. */ limitations under the License. */
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h" #include "paddle/platform/hostdevice.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T>
HOSTDEVICE T tolerable_value(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
if (x == INFINITY) {
return kApproInf;
}
if (x == -INFINITY) {
return -kApproInf;
}
return x;
}
template <typename T> template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
const int N, const int D) { const int N, const int D) {
...@@ -53,9 +39,9 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, ...@@ -53,9 +39,9 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
i += blockDim.x * gridDim.x) { i += blockDim.x * gridDim.x) {
T sum = static_cast<T>(0); T sum = static_cast<T>(0);
for (int j = 0; j < D; j++) { for (int j = 0; j < D; j++) {
sum += label[i * D + j] * log(X[i * D + j]); sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
} }
Y[i] = -tolerable_value(sum); Y[i] = -sum;
} }
} }
...@@ -85,6 +71,7 @@ template <typename T> ...@@ -85,6 +71,7 @@ template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const T* label, const int N, const T* label, const int N,
const int D) { const int D) {
// TOOD(qingqing): optimize for this kernel
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) { i += blockDim.x * gridDim.x) {
for (int j = 0; j < D; ++j) { for (int j = 0; j < D; ++j) {
...@@ -115,14 +102,11 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { ...@@ -115,14 +102,11 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
int grid = (n + block - 1) / block; int grid = (n + block - 1) / block;
// TODO(qingqing) launch kernel on specified stream // TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext. // base on ExecutionContext.
int label_rank = label->dims().size(); if (ctx.Attr<int>("soft_label") == 1) {
if (label_rank == 2) {
// soft cross entropy
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
d); d);
} else { } else {
// normal cross entropy
auto* label_data = ctx.Input<Tensor>("Label")->data<int>(); auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d); CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
} }
...@@ -153,14 +137,11 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { ...@@ -153,14 +137,11 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
grid = (n + block - 1) / block; grid = (n + block - 1) / block;
// TODO(qingqing): launch kernel on specified stream // TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext. // base on ExecutionContext.
int label_rank = label->dims().size(); if (ctx.Attr<int>("soft_label") == 1) {
if (label_rank == 2) {
// soft cross entropy
auto* label_data = label->data<T>(); auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block>>>( SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
dx_data, dy_data, x_data, label_data, n, d); dx_data, dy_data, x_data, label_data, n, d);
} else { } else {
// normal cross entropy
auto* label_data = label->data<int>(); auto* label_data = label->data<int>();
CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data, CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
label_data, n, d); label_data, n, d);
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,21 +22,15 @@ namespace operators { ...@@ -21,21 +22,15 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
inline T tolerable_value(const T x) { HOSTDEVICE T tolerable_value(const T x) {
static_assert(std::is_floating_point<T>::value, PADDLE_ASSERT(std::is_floating_point<T>::value);
"tolerable_value works only on float, "
"double and double double.");
const T kApproInf = 1e20; const T kApproInf = 1e20;
if (x == INFINITY) { if (x == INFINITY) {
return kApproInf; return kApproInf;
} }
if (x == -INFINITY) { if (x == -INFINITY) {
return -kApproInf; return -kApproInf;
} }
return x; return x;
} }
...@@ -55,22 +50,19 @@ class CrossEntropyOpKernel : public framework::OpKernel { ...@@ -55,22 +50,19 @@ class CrossEntropyOpKernel : public framework::OpKernel {
int batch_size = x->dims()[0]; int batch_size = x->dims()[0];
int class_num = x->dims()[1]; int class_num = x->dims()[1];
int label_rank = ctx.Input<Tensor>("Label")->dims().size();
if (label_rank == 2) { if (ctx.Attr<int>("soft_label") == 1) {
// soft cross entropy
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int index = 0; int index = 0;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
T sum = static_cast<T>(0); T sum = static_cast<T>(0);
for (int j = 0; j < class_num; ++j) { for (int j = 0; j < class_num; ++j) {
sum += label_data[index] * std::log(x_data[index]); sum += label_data[index] * tolerable_value(std::log(x_data[index]));
y_data[i] = -tolerable_value(sum); y_data[i] = -sum;
index++; index++;
} }
} }
} else { } else {
// normal cross entropy
auto* label_data = ctx.Input<Tensor>("Label")->data<int>(); auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i]; int index = i * class_num + label_data[i];
...@@ -98,11 +90,9 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { ...@@ -98,11 +90,9 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
int batch_size = x->dims()[0]; int batch_size = x->dims()[0];
int class_num = x->dims()[1]; int class_num = x->dims()[1];
int label_rank = ctx.Input<Tensor>("Label")->dims().size();
// TODO(qingqing): make zero setting an common function. // TODO(qingqing): make zero setting an common function.
if (label_rank == 2) { if (ctx.Attr<int>("soft_label") == 1) {
// soft cross entropy
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int index = 0; int index = 0;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
...@@ -112,7 +102,6 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { ...@@ -112,7 +102,6 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
} }
} }
} else { } else {
// normal cross entropy
auto* label_data = label->data<int>(); auto* label_data = label->data<int>();
memset(dx_data, 0, sizeof(T) * batch_size * class_num); memset(dx_data, 0, sizeof(T) * batch_size * class_num);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
......
import unittest import unittest
import numpy import numpy as np
from op_test import OpTest from op_test import OpTest
class TestOnehotCrossEntropyOp(OpTest): class TestCrossEntropyOp1(OpTest):
"""Test standard cross-entropy, with index representation of labels.
"""
def setUp(self): def setUp(self):
self.op_type = "cross_entropy" self.op_type = "cross_entropy"
batch_size = 30 batch_size = 30
class_num = 10 class_num = 10
X = np.random.uniform(0.1, 1.0,
X = numpy.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
labels = numpy.random.randint(0, class_num, batch_size, dtype="int32") label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
cross_entropy = np.asmatrix(
cross_entropy = numpy.asmatrix( [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
[[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])],
dtype="float32") dtype="float32")
self.inputs = {"X": X, "Label": labels} self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
self.attrs = {'soft_label': 0}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -26,20 +28,55 @@ class TestOnehotCrossEntropyOp(OpTest): ...@@ -26,20 +28,55 @@ class TestOnehotCrossEntropyOp(OpTest):
self.check_grad(["X"], "Y") self.check_grad(["X"], "Y")
class TestCrossEntropySoftLabel(OpTest): class TestCrossEntropyOp2(OpTest):
"""Test soft-label cross-entropy, with vecterized soft labels.
"""
def setUp(self): def setUp(self):
self.op_type = "cross_entropy" self.op_type = "cross_entropy"
batch_size = 30 batch_size = 10
class_num = 10 class_num = 5
X = numpy.random.uniform(0.1, 1.0, X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
label = numpy.random.uniform(0.1, 1.0, label = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
label /= label.sum(axis=1, keepdims=True) label /= label.sum(axis=1, keepdims=True)
cross_entropy = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label} self.inputs = {'X': X, 'Label': label}
cross_entropy = (-label * numpy.log(X)).sum( self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': 1}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
class TestCrossEntropyOp3(OpTest):
"""Test one-hot cross-entropy, with vecterized one-hot representation of
labels.
"""
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 30
class_num = 10
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label_index = np.random.randint(
0, class_num, (batch_size), dtype="int32")
label = np.zeros(X.shape)
label[np.arange(batch_size), label_index] = 1
cross_entropy = np.asmatrix(
[[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])],
dtype="float32")
cross_entropy2 = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32") axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy} self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': 1}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册