Commit d8046da0 authored by Xinghai Sun

Use soft_label attribute for cross-entropy.

Parent 8e7fe8ca
......@@ -25,25 +25,32 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of CrossEntropyOp must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) of CrossEntropyOp must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
"Output(Y) of CrossEntropyOp must not be null.");
auto *x = ctx.Input<Tensor>("X");
auto *label = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "X's rank must be 2.");
PADDLE_ASSERT(label->dims().size() == 1 || label->dims().size() == 2);
if (label->dims().size() == 2) {
// soft cross entropy
PADDLE_ENFORCE_EQ(x->dims(), label->dims());
"Input(Label) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2.");
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
ctx.Attr<int>("soft_label") == 1);
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"be equal.");
if (ctx.Attr<int>("soft_label") == 1) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == 1, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
} else {
// normal cross entropy
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0]);
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == 0, The 2nd dimension of "
"Input(Label) must be 1.");
}
ctx.Output<LoDTensor>("Y")->Resize({x->dims()[0], 1});
}
};
......@@ -54,12 +61,41 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of CrossEntropyOp must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) must not be null.");
auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2.");
// TODO(xinghai-sun): remove this check after switching to bool
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
ctx.Attr<int>("soft_label") == 1);
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"be equal.");
PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
"The 1st dimension of Input(X) and Input(Y@Grad) must "
"be equal.");
PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
"The 2nd dimension of Input(Y@Grad) must be 1.");
if (ctx.Attr<int>("soft_label") == 1) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == 1, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == 0, The 2nd dimension of "
"Input(Label) must be 1.");
}
auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
dx->Resize(x->dims());
}
};
......@@ -72,22 +108,31 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of CrossEntropyOp");
AddInput("Label", "The second input of CrossEntropyOp");
AddOutput("Y", "The output of CrossEntropyOp");
AddAttr<int>("soft_label", "Is soft label. Default zero.").SetDefault(0);
AddComment(R"DOC(
CrossEntropy Operator.
- The second input (Label tensor) supports two kinds of shapes:
- 1) Rank(Label) = 1, Label[i] indicates the class index for sample i:
+ It supports both standard cross-entropy and soft-label cross-entropy loss
+ computation.
+ 1) One-hot cross-entropy:
+    soft_label = 0, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
- 2) Rank(Label) = 2, Label[i, j] indicates the soft label of class j
- for sample i:
+ 2) Soft-label cross-entropy:
+    soft_label = 1, Label[i, j] indicates the soft label of class j
+    for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
Please make sure that in this case the summation of each row of Label
- equals one. If each row of Label has only one non-zero element (equals 1),
- it degenerates to a standard one-hot representation.
+ equals one.
+ 3) One-hot cross-entropy with vectorized Input(Label):
+    As a special case of 2), when each row of Input(Label) has only one
+    non-zero element (equals 1), soft-label cross-entropy degenerates to a
+    one-hot cross-entropy with one-hot label representation.
)DOC");
}
};
......
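For reference, below is a minimal NumPy sketch of the two loss definitions described in the operator's DOC comment above. The function and array names are illustrative only and are not part of the patch:

    import numpy as np

    def cross_entropy_forward(x, label, soft_label):
        # x: (N, D) class probabilities; returns Y with shape (N, 1).
        if soft_label == 1:
            # Soft-label mode: Y[i] = -sum_j Label[i, j] * log(X[i, j])
            return (-label * np.log(x)).sum(axis=1, keepdims=True)
        # Index (one-hot) mode: Y[i] = -log(X[i, Label[i, 0]])
        n = x.shape[0]
        return -np.log(x[np.arange(n), label[:, 0]]).reshape(n, 1)

    # Example shapes match the checks in CrossEntropyOp::InferShape:
    x = np.random.uniform(0.1, 1.0, (4, 3)).astype("float32")
    hard_label = np.random.randint(0, 3, (4, 1))
    print(cross_entropy_forward(x, hard_label, soft_label=0))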
......@@ -13,27 +13,13 @@
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
- template <typename T>
- HOSTDEVICE T tolerable_value(const T x) {
-   PADDLE_ASSERT(std::is_floating_point<T>::value);
-   const T kApproInf = 1e20;
-   if (x == INFINITY) {
-     return kApproInf;
-   }
-   if (x == -INFINITY) {
-     return -kApproInf;
-   }
-   return x;
- }
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
const int N, const int D) {
......@@ -53,9 +39,9 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
i += blockDim.x * gridDim.x) {
T sum = static_cast<T>(0);
for (int j = 0; j < D; j++) {
- sum += label[i * D + j] * log(X[i * D + j]);
+ sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
}
- Y[i] = -tolerable_value(sum);
+ Y[i] = -sum;
}
}
......@@ -85,6 +71,7 @@ template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const T* label, const int N,
const int D) {
// TODO(qingqing): optimize for this kernel
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
for (int j = 0; j < D; ++j) {
......@@ -115,14 +102,11 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
int grid = (n + block - 1) / block;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
- int label_rank = label->dims().size();
- if (label_rank == 2) {
-   // soft cross entropy
+ if (ctx.Attr<int>("soft_label") == 1) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
d);
} else {
-   // normal cross entropy
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
}
......@@ -153,14 +137,11 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
grid = (n + block - 1) / block;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
- int label_rank = label->dims().size();
- if (label_rank == 2) {
-   // soft cross entropy
+ if (ctx.Attr<int>("soft_label") == 1) {
auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
dx_data, dy_data, x_data, label_data, n, d);
} else {
-   // normal cross entropy
auto* label_data = label->data<int>();
CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
label_data, n, d);
......
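A side note on the numerics changed in this file: tolerable_value() now wraps each log() term instead of the accumulated sum, so a zero probability contributes a large finite penalty (kApproInf = 1e20) rather than turning the whole row sum into -inf. A small Python sketch of that clamping behaviour (illustrative only, mirroring the C++ helper):

    import numpy as np

    def tolerable_value(v, cap=1e20):
        # Clamp +/-inf to a large finite value, like kApproInf in the C++ code.
        if np.isposinf(v):
            return cap
        if np.isneginf(v):
            return -cap
        return v

    label = np.array([0.5, 0.5])
    x = np.array([1.0, 0.0])  # one class has zero predicted probability
    loss = -sum(l * tolerable_value(np.log(p)) for l, p in zip(label, x))
    print(loss)  # finite (0.5 * 1e20) instead of inf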
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
......@@ -21,21 +22,15 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
- inline T tolerable_value(const T x) {
-   static_assert(std::is_floating_point<T>::value,
-                 "tolerable_value works only on float, "
-                 "double and double double.");
+ HOSTDEVICE T tolerable_value(const T x) {
+   PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
if (x == INFINITY) {
return kApproInf;
}
if (x == -INFINITY) {
return -kApproInf;
}
return x;
}
......@@ -55,22 +50,19 @@ class CrossEntropyOpKernel : public framework::OpKernel {
int batch_size = x->dims()[0];
int class_num = x->dims()[1];
- int label_rank = ctx.Input<Tensor>("Label")->dims().size();
- if (label_rank == 2) {
-   // soft cross entropy
+ if (ctx.Attr<int>("soft_label") == 1) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int index = 0;
for (int i = 0; i < batch_size; ++i) {
T sum = static_cast<T>(0);
for (int j = 0; j < class_num; ++j) {
- sum += label_data[index] * std::log(x_data[index]);
- y_data[i] = -tolerable_value(sum);
+ sum += label_data[index] * tolerable_value(std::log(x_data[index]));
+ y_data[i] = -sum;
index++;
}
}
} else {
- // normal cross entropy
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
......@@ -98,11 +90,9 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
int batch_size = x->dims()[0];
int class_num = x->dims()[1];
- int label_rank = ctx.Input<Tensor>("Label")->dims().size();
// TODO(qingqing): make zero setting a common function.
- if (label_rank == 2) {
-   // soft cross entropy
+ if (ctx.Attr<int>("soft_label") == 1) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int index = 0;
for (int i = 0; i < batch_size; ++i) {
......@@ -112,7 +102,6 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
}
}
} else {
- // normal cross entropy
auto* label_data = label->data<int>();
memset(dx_data, 0, sizeof(T) * batch_size * class_num);
for (int i = 0; i < batch_size; ++i) {
......
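As a cross-check on the CPU and GPU gradient kernels above, here is a simplified NumPy sketch of the backward pass for both label modes. This is a reference illustration under the same conventions (dY has shape (N, 1)), not the kernel code itself:

    import numpy as np

    def cross_entropy_grad(x, label, dy, soft_label):
        # Returns dX with the same shape as x.
        dx = np.zeros_like(x)
        if soft_label == 1:
            # dY/dX[i, j] = -Label[i, j] / X[i, j]
            dx = dy * (-label / x)
        else:
            # Only the true-class column of each row receives a gradient.
            rows = np.arange(x.shape[0])
            cols = label[:, 0]
            dx[rows, cols] = -dy[:, 0] / x[rows, cols]
        return dx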
import unittest
- import numpy
+ import numpy as np
from op_test import OpTest
- class TestOnehotCrossEntropyOp(OpTest):
+ class TestCrossEntropyOp1(OpTest):
+     """Test standard cross-entropy, with index representation of labels.
+     """
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 30
class_num = 10
- X = numpy.random.uniform(0.1, 1.0,
-                          [batch_size, class_num]).astype("float32")
- labels = numpy.random.randint(0, class_num, batch_size, dtype="int32")
- cross_entropy = numpy.asmatrix(
-     [[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])],
+ X = np.random.uniform(0.1, 1.0,
+                       [batch_size, class_num]).astype("float32")
+ label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
+ cross_entropy = np.asmatrix(
+     [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
dtype="float32")
- self.inputs = {"X": X, "Label": labels}
+ self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy}
+ self.attrs = {'soft_label': 0}
def test_check_output(self):
self.check_output()
......@@ -26,20 +28,55 @@ class TestOnehotCrossEntropyOp(OpTest):
self.check_grad(["X"], "Y")
- class TestCrossEntropySoftLabel(OpTest):
+ class TestCrossEntropyOp2(OpTest):
+     """Test soft-label cross-entropy, with vectorized soft labels.
+     """
def setUp(self):
self.op_type = "cross_entropy"
- batch_size = 30
- class_num = 10
- X = numpy.random.uniform(0.1, 1.0,
-                          [batch_size, class_num]).astype("float32")
- label = numpy.random.uniform(0.1, 1.0,
-                              [batch_size, class_num]).astype("float32")
+ batch_size = 10
+ class_num = 5
+ X = np.random.uniform(0.1, 1.0,
+                       [batch_size, class_num]).astype("float32")
+ label = np.random.uniform(0.1, 1.0,
+                           [batch_size, class_num]).astype("float32")
label /= label.sum(axis=1, keepdims=True)
+ cross_entropy = (-label * np.log(X)).sum(
+     axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label}
- cross_entropy = (-label * numpy.log(X)).sum(
self.outputs = {'Y': cross_entropy}
+ self.attrs = {'soft_label': 1}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
class TestCrossEntropyOp3(OpTest):
"""Test one-hot cross-entropy, with vecterized one-hot representation of
labels.
"""
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 30
class_num = 10
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label_index = np.random.randint(
0, class_num, (batch_size), dtype="int32")
label = np.zeros(X.shape)
label[np.arange(batch_size), label_index] = 1
cross_entropy = np.asmatrix(
[[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])],
dtype="float32")
cross_entropy2 = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': 1}
def test_check_output(self):
self.check_output()
......