Commit 8b8ad6b1 authored by caoying03

Fix the implementation of soft label support.

Parent bb58b63b
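For reference, the sketch below (not part of the patch; names are illustrative only) shows in NumPy what the forward pass of this operator computes in the two label modes handled by the commit: hard integer labels of shape [N x 1] and soft per-class distributions of shape [N x K].

```python
import numpy as np

def softmax(x):
    # Numerically stable row-wise softmax.
    shifted = x - x.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def softmax_with_cross_entropy(logits, labels, soft_label=False):
    prob = softmax(logits)
    if soft_label:
        # labels: [N x K] probability distributions over classes.
        return -(labels * np.log(prob)).sum(axis=1, keepdims=True)
    # labels: [N x 1] integer class ids.
    n = logits.shape[0]
    return -np.log(prob[np.arange(n), labels.ravel()]).reshape(n, 1)
```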
......@@ -28,7 +28,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
Y[i] = -tolerable_value(log(X[i * D + label[i]]));
Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
}
}
......@@ -39,7 +39,7 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
i += blockDim.x * gridDim.x) {
T sum = static_cast<T>(0);
for (int j = 0; j < D; j++) {
sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
sum += label[i * D + j] * TolerableValue<T>()(log(X[i * D + j]));
}
Y[i] = -sum;
}
......
......@@ -22,17 +22,16 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
HOSTDEVICE T tolerable_value(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
if (x == INFINITY) {
return kApproInf;
struct TolerableValue {
HOSTDEVICE T operator()(const T& x) const {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
if (x == INFINITY) return kApproInf;
if (x == -INFINITY) return -kApproInf;
return x;
}
if (x == -INFINITY) {
return -kApproInf;
}
return x;
}
};
template <typename T>
class CrossEntropyOpKernel : public framework::OpKernel {
......@@ -57,7 +56,8 @@ class CrossEntropyOpKernel : public framework::OpKernel {
for (int i = 0; i < batch_size; ++i) {
T sum = static_cast<T>(0);
for (int j = 0; j < class_num; ++j) {
sum += label_data[index] * tolerable_value(std::log(x_data[index]));
sum +=
label_data[index] * TolerableValue<T>()(std::log(x_data[index]));
y_data[i] = -sum;
index++;
}
......@@ -66,7 +66,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
y_data[i] = -tolerable_value(std::log(x_data[index]));
y_data[i] = -TolerableValue<T>()(std::log(x_data[index]));
}
}
}
......
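The TolerableValue functor introduced above keeps log(0) from propagating infinities into the loss sum. A rough Python equivalent (illustrative only, mirroring the kApproInf constant) would be:

```python
def tolerable_value(x, appro_inf=1e20):
    # Clamp +/-inf (e.g. from log(0)) to a large finite value,
    # mirroring the C++ TolerableValue functor.
    if x == float("inf"):
        return appro_inf
    if x == float("-inf"):
        return -appro_inf
    return x

print(tolerable_value(float("-inf")))  # -1e+20
```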
......@@ -18,7 +18,7 @@ namespace paddle {
namespace operators {
namespace math {
template class SoftmaxFunctor<platform::CPUPlace, float>;
template class SoftmaxFunctor<platform::GPUPlace, float>;
} // namespace math
} // namespace operators
......
......@@ -28,8 +28,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class SoftmaxFunctor {
public:
void operator()(const framework::Tensor* X, framework::Tensor* Y,
const framework::ExecutionContext& context) {
void operator()(const framework::ExecutionContext& context,
const framework::Tensor* X, framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
......
......@@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel {
// allocate memory on device.
Y->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<Place, T>()(X, Y, context);
math::SoftmaxFunctor<Place, T>()(context, X, Y);
}
};
......
......@@ -23,31 +23,31 @@ class SoftmaxWithCrossEntropyOpMaker
SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
//(TODO caoying) replace int with boolean
AddAttr<int>("soft_label",
"(int, default 0), A flag to indicate whether to interpretate "
"the given labels as soft labels.")
.SetDefault(0);
AddAttr<bool>(
"softLabel",
"(bool, default: false), A flag to indicate whether to interpretate "
"the given labels as soft labels.")
.SetDefault(false);
AddInput("Logits",
"(Tensor, default Tensor<float>), The unscaled log probabilities "
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number.")
.NotInGradient();
AddInput(
"Label",
"(Tensor, default Tensor<int>), The ground truth which is "
"a 1-D or 2-D tensor. "
"If soft_label is set to 0, Label is a Tensor<int> with shape [N x 1]. "
"If soft_label is set to 1, Label is a Tensor<float/double> "
"(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
"tensor. "
"If softLable is set to 0, Label is a Tensor<int> with shape [N x 1]. "
"If softLable is set to 1, Label is a Tensor<float/double> "
"with shape [N x K].");
AddOutput(
"Softmax",
"(Tensor, default Tensor<float>), A 2-D tensor with shape [N x K]. "
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
"The outputs value of softmax activation by given the input batch, "
"which will be used in backward calculation.")
.AsIntermediate();
AddOutput("Loss",
"(Tensor, default Tensor<float>), A 1-D tensor. The cross "
"(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
"entropy loss with shape [N x 1].");
AddComment(R"DOC(
Cross entropy loss with softmax is used extensively as the output layer. This
......@@ -83,15 +83,39 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"),
"Input(Logits) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"),
"Output(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"),
"Output(Loss) should be not null.");
const Tensor* logits = ctx.Input<Tensor>("Logits");
const Tensor* labels = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE(
logits->dims().size() == 2UL,
"The input of softmax_with_cross_entropy should be a 2-d tensor.");
PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 1UL,
"The label should be a 1-d tensor.");
ctx.Output<framework::LoDTensor>("Softmax")->Resize(logits->dims());
ctx.Output<framework::LoDTensor>("Loss")->Resize({logits->dims()[0], 1});
"The input of softmax_with_cross_entropy should be a 2-D tensor.");
PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 2UL,
"The labels should be a 2-D tensor.");
if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1],
"If Attr(softLabel) == true, the 2nd dimension of "
"Input(Logits) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(labels->dims()[1], 1,
"If Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
ctx.Output<framework::Tensor>("Softmax")->Resize(logits->dims());
ctx.Output<framework::Tensor>("Loss")->Resize({logits->dims()[0], 1});
ctx.ShareLoD("Logits", /*->*/ "Softmax");
ctx.ShareLoD("Logits", /*->*/ "Loss");
}
};
......@@ -102,11 +126,28 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
"Input(Loss@Grad) should not be null");
"Input(Loss@Grad) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
"Input(Softmax) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Lable) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")),
"Output(Logits@Grad) should be not null.");
const Tensor* softmax = ctx.Input<Tensor>("Softmax");
const Tensor* labels = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE(ctx.Input<Tensor>("Label")->dims().size() == 2UL,
"The labels should be a 2-D tensor.");
if (ctx.Attr<bool>("softLabel")) {
PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1],
"When Attr(softLabel) == true, the 2nd dimension of "
"Input(Softmax) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(labels->dims()[1], 1,
"When Attr(softLabel) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
->Resize(ctx.Input<Tensor>("Softmax")->dims());
......
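The shape contract enforced by the two InferShape methods above can be summarized with a small NumPy sketch (values below are illustrative only):

```python
import numpy as np

N, K = 4, 6  # batch size and class number (illustrative)

logits = np.random.uniform(0.1, 1.0, [N, K]).astype("float32")

# softLabel == false: integer class ids, shape [N x 1].
hard_labels = np.random.randint(0, K, [N, 1], dtype="int32")

# softLabel == true: per-class probabilities, shape [N x K]
# (the second dimension must match that of Logits).
soft_labels = np.random.uniform(0.1, 1.0, [N, K]).astype("float32")
soft_labels /= soft_labels.sum(axis=1, keepdims=True)

# Outputs: Softmax has shape [N x K], Loss has shape [N x 1].
```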
......@@ -24,25 +24,78 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void CrossEntropyKernel(T* out, const T* softmax_out,
const int* label, const int batch_size,
const int class_num) {
__global__ void CrossEntropy(T* out, const T* softmax_out, const int* labels,
const int batch_size, const int class_num) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < batch_size) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
out[i] = -tolerable_value(std::log(softmax_out[i * class_num + label[i]]));
PADDLE_ASSERT(labels[i] >= 0 && labels[i] < class_num);
out[i] =
-TolerableValue<T>()(std::log(softmax_out[i * class_num + labels[i]]));
}
}
template <typename T>
__global__ void CrossEntropyWithSoftmaxGradKernel(T* softmax_out,
const int* label,
const int batch_size,
const int class_num) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < batch_size) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < class_num);
softmax_out[i * class_num + label[i]] -= 1.;
__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
const int* labels, const int batch_size,
const int class_num) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int sample_idx = tid / class_num;
if (tid < batch_size * class_num) out_grad[tid] *= in_grad[sample_idx];
__syncthreads();
if (tid < batch_size) {
PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
out_grad[tid * class_num + labels[tid]] -= 1.;
}
}
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
val += __shfl_down(val, 16);
val += __shfl_down(val, 8);
val += __shfl_down(val, 4);
val += __shfl_down(val, 2);
val += __shfl_down(val, 1);
return val;
}
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int class_num) {
int tid = threadIdx.x;
extern __shared__ T d_sum[];
d_sum[tid] = 0;
int cur_idx = tid;
int next_idx = blockIdx.x * class_num + tid;
while (cur_idx < class_num) {
d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
next_idx += blockDim.x;
cur_idx += blockDim.x;
}
__syncthreads();
for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
__syncthreads();
}
T val = d_sum[tid];
val = sum_single_warp<T>(val);
if (tid == 0) Y[blockIdx.x] = -val;
}
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
const T* labels,
const int batch_size,
const int class_num) {
int ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < batch_size * class_num) {
int row_ids = ids / class_num;
logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids];
}
}
......@@ -52,27 +105,36 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
T* loss_data =
context.Output<Tensor>("Loss")->mutable_data<T>(context.GetPlace());
// Calculate the softmax outputs.
const Tensor* logits = context.Input<Tensor>("Logits");
Tensor* softmax = context.Output<Tensor>("Softmax");
softmax->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::GPUPlace, T>()(logits, softmax, context);
T* softmax_out = softmax->data<T>();
// Calculate the cross entropy loss based on hard labels.
const int* label_data = context.Input<Tensor>("Label")->data<int>();
Tensor* loss = context.Output<Tensor>("Loss");
loss->mutable_data<T>(context.GetPlace());
T* loss_data = loss->data<T>();
T* softmax_out = softmax->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);
const int batch_size = logits->dims()[0];
const int class_num = logits->dims()[1];
int block = 512;
int grid = (batch_size + block - 1) / block;
CrossEntropyKernel<T><<<grid, block>>>(loss_data, softmax_out, label_data,
batch_size, class_num);
if (context.Attr<bool>("softLabel")) {
const T* label_data = context.Input<Tensor>("Label")->data<T>();
block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
SoftCrossEntropyKernel<
T><<<batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(loss_data, softmax_out, label_data, class_num);
} else {
const int* label_data = context.Input<Tensor>("Label")->data<int>();
CrossEntropy<T><<<grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(loss_data, softmax_out, label_data,
batch_size, class_num);
}
}
};
......@@ -82,7 +144,9 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
const Tensor* labels = context.Input<Tensor>("Label");
const T* loss_grad_data =
context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
......@@ -90,14 +154,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
const int batch_size = logit_grad->dims()[0];
const int class_num = logit_grad->dims()[1];
const int* label_data = context.Input<Tensor>("Label")->data<int>();
const int block = 512;
const int grid = (batch_size + block - 1) / block;
CrossEntropyWithSoftmaxGradKernel<T><<<grid, block>>>(
logit_grad_data, label_data, batch_size, class_num);
int block = 512;
int grid = (batch_size * class_num + block - 1) / block;
if (context.Attr<bool>("softLabel")) {
const T* label_data = labels->data<T>();
SoftCrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(logit_grad_data, loss_grad_data,
label_data, batch_size, class_num);
} else {
const int* label_data = labels->data<int>();
CrossEntropyGrad<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(logit_grad_data, loss_grad_data,
label_data, batch_size, class_num);
}
}
};
......
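SoftCrossEntropyKernel above launches one block per sample and reduces the per-class terms with a strided accumulation, a shared-memory tree reduction, and a final warp shuffle. The pure-Python sketch below (illustrative only, assuming a power-of-two thread count) shows the same accumulation and stride-halving pattern:

```python
import numpy as np

def block_row_sum(values, num_threads=8):
    # Each "thread" accumulates a strided partial sum, like the
    # cur_idx += blockDim.x loop in SoftCrossEntropyKernel.
    partial = np.zeros(num_threads)
    for tid in range(num_threads):
        partial[tid] = values[tid::num_threads].sum()
    # Tree reduction by halving the stride; the warp-shuffle stage in
    # the kernel is just the last few halving steps done in registers.
    stride = num_threads // 2
    while stride >= 1:
        for tid in range(stride):
            partial[tid] += partial[tid + stride]
        stride //= 2
    return partial[0]

values = np.arange(20, dtype=np.float64)
assert np.isclose(block_row_sum(values), values.sum())
```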
......@@ -32,28 +32,35 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
"This kernel only runs on CPU.");
// Calculate the softmax outputs.
const Tensor* logits = context.Input<Tensor>("Logits");
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax = context.Output<Tensor>("Softmax");
softmax->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::CPUPlace, T>()(logits, softmax, context);
Tensor* loss = context.Output<Tensor>("Loss");
// Calculate the cross entropy loss based on hard labels.
T* softmax_out = softmax->data<T>();
const int* label_data = context.Input<Tensor>("Label")->data<int>();
T* softmax_data = softmax->mutable_data<T>(context.GetPlace());
T* loss_data = loss->mutable_data<T>(context.GetPlace());
Tensor* loss = context.Output<Tensor>("Loss");
loss->mutable_data<T>(context.GetPlace());
T* loss_data = loss->data<T>();
math::SoftmaxFunctor<platform::CPUPlace, T>()(context, logits, softmax);
const int batch_size = logits->dims()[0];
const int class_num = logits->dims()[1];
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
loss_data[i] = -tolerable_value(std::log(softmax_out[index]));
if (context.Attr<bool>("softLabel")) {
//(TODO caoying) the forward implementation can be further optimized.
// Current implementation is exactly cross entropy after softmax.
auto prob = EigenMatrix<T>::From(*softmax);
auto lbl_mat = EigenMatrix<T>::From(*labels);
auto loss_mat = EigenMatrix<T>::From(*loss);
loss_mat.device(context.GetEigenDevice<platform::CPUPlace>()) =
-((lbl_mat * prob.log().unaryExpr(TolerableValue<T>()))
.sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
} else {
const int* label_data = labels->data<int>();
const int class_num = logits->dims()[1];
for (int i = 0; i < batch_size; ++i)
loss_data[i] = -TolerableValue<T>()(
std::log(softmax_data[i * class_num + label_data[i]]));
}
}
};
......@@ -62,18 +69,34 @@ template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Loss"));
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits"));
logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
T* logit_grad_data = logit_grad->data<T>();
const int batch_size = logit_grad->dims()[0];
const int class_num = logit_grad->dims()[1];
const int* label_data = context.Input<Tensor>("Label")->data<int>();
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
logit_grad_data[index] -= 1.;
if (context.Attr<bool>("softLabel")) {
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
auto lbl_mat = EigenMatrix<T>::From(*labels);
logit_grad_mat.device(context.GetEigenDevice<platform::CPUPlace>()) =
logit_grad_mat *
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) -
lbl_mat;
} else {
const int batch_size = logit_grad->dims()[0];
const int* label_data = labels->data<int>();
const T* out_grad_data = out_grad->data<T>();
T* logit_grad_data = logit_grad->data<T>();
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
logit_grad_data[index] =
(out_grad_data[i] * logit_grad_data[index] - 1.);
}
}
}
};
......
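For the soft-label backward pass, the Eigen path above and the CUDA SoftCrossEntropyGradientKernel compute the same expression; a NumPy sketch of it (illustrative only) is:

```python
import numpy as np

def soft_label_grad(softmax_out, labels, loss_grad):
    # dLogits = softmax * dLoss (broadcast over classes) - labels,
    # with loss_grad of shape [N x 1] and the other two of shape [N x K].
    return softmax_out * loss_grad - labels
```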
......@@ -6,22 +6,23 @@ from test_softmax_op import stable_softmax
class TestSoftmaxWithCrossEntropyOp(OpTest):
"""
Test softmax with cross entropy operator with discrete one-hot labels.
"""
def setUp(self):
self.op_type = "softmax_with_cross_entropy"
MAX_BATCH_SIZE = 23
MAX_CLASS_NUM = 17
batch_size = np.random.randint(1, MAX_BATCH_SIZE, 1)[0]
class_num = np.random.randint(2, MAX_CLASS_NUM, 1)[0]
batch_size = 3
class_num = 37
logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.randint(0, class_num, batch_size, dtype="int32")
labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int32")
cross_entropy = np.asmatrix(
[[-np.log(softmax[i][labels[i]])] for i in range(softmax.shape[0])],
[[-np.log(softmax[i][labels[i][0]])]
for i in range(softmax.shape[0])],
dtype="float32")
self.inputs = {"Logits": logits, "Label": labels}
......@@ -34,5 +35,36 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.check_grad(["Logits"], "Loss", max_relative_error=0.05)
class TestSoftmaxWithCrossEntropyOp2(OpTest):
"""
Test softmax with cross entropy operator with soft labels.
"""
def setUp(self):
self.op_type = "softmax_with_cross_entropy"
batch_size = 2
class_num = 17
logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
labels /= np.sum(labels, axis=1, keepdims=True)
cross_entropy = (-labels * np.log(softmax)).sum(
axis=1, keepdims=True).astype("float32")
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
self.attrs = {"softLabel": True}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["Logits"], "Loss", max_relative_error=0.05)
if __name__ == "__main__":
unittest.main()