Commit 6735585b authored by caoying03

Fix the CPU cross-entropy kernel with soft labels.

Parent 30bfaab3
......@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return;
}
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS>>>(
num_samples, infer_width, inference_data, label_data, accuracy_data);
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
1, PADDLE_CUDA_NUM_THREADS, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(num_samples, infer_width, inference_data, label_data,
accuracy_data);
}
};
......
......@@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
"Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
"Output(Y) should be not null.");
auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2.");
"Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
if (ctx.Attr<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == true, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
"If Attr(soft_label) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == false, The 2nd dimension of "
"Input(Label) must be 1.");
"If Attr(soft_label) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
......@@ -57,35 +58,36 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) must not be null.");
"Input(Label) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) must not be null.");
"Input(Y@GRAD) shoudl be not null.");
auto x = ctx.Input<Tensor>("X");
auto label = ctx.Input<Tensor>("Label");
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
"Input(Y@Grad)'s rank should be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2.");
"Input(Label)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must "
"The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
"The 1st dimension of Input(X) and Input(Y@Grad) must "
"The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal.");
PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
"The 2nd dimension of Input(Y@Grad) must be 1.");
"The 2nd dimension of Input(Y@Grad) should be 1.");
if (ctx.Attr<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == true, The 2nd dimension of "
"Input(X) and Input(Label) must be equal.");
"When Attr(soft_label) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
} else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == false, The 2nd dimension of "
"Input(Label) must be 1.");
"When Attr(soft_label) == false, the 2nd dimension of "
"Input(Label) should be 1.");
}
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
......@@ -98,12 +100,26 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
CrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of CrossEntropyOp");
AddInput("Label", "The second input of CrossEntropyOp");
AddOutput("Y", "The output of CrossEntropyOp");
AddAttr<bool>("soft_label", "Is soft label. Default zero.")
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
"where N is the batch size and D is the number of classes. "
"This input is a probability computed by the previous operator, "
"which is almost always the result of a softmax operator.");
AddInput("Label",
"(Tensor, default Tensor<int>), the ground truth which is "
"a 1-D or 2-D tensor. "
"When soft_label is set to 0, `Label` is a Tensor<int> with shape "
"[N x 1]. "
"When soft_label is set to 1, `Label` is a Tensor<float/double> "
"with shape [N x K].");
AddOutput("Y",
"(Tensor, default Tensor<float>), a 1-D tensor "
"with shape [N x 1]. The cross entropy loss.");
AddAttr<bool>(
"soft_label",
"(bool, default false), a flag to indicate whether to interpretate "
"the given labels as soft labels.")
.SetDefault(false);
AddComment(R"DOC(
CrossEntropy Operator.
......
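As a quick reference for the input formats described above, here is a minimal NumPy sketch of the two label layouts; the array names are purely illustrative and follow the shapes used by the Python unit tests further down.

import numpy as np

batch_size, class_num = 5, 37
# Hard labels (soft_label == false): one integer class index per sample, shape [N x 1].
hard_label = np.random.randint(0, class_num, (batch_size, 1)).astype("int32")
# Soft labels (soft_label == true): one probability distribution per sample, shape [N x D].
soft_label = np.random.uniform(0.1, 1.0, (batch_size, class_num)).astype("float32")
soft_label /= soft_label.sum(axis=1, keepdims=True)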
......@@ -32,37 +32,71 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
}
}
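// sum_single_warp reduces `val` across the 32 lanes of a warp with shuffle-down
// operations; after it returns, lane 0 holds the warp-wide sum.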
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
val += __shfl_down(val, 16);
val += __shfl_down(val, 8);
val += __shfl_down(val, 4);
val += __shfl_down(val, 2);
val += __shfl_down(val, 1);
return val;
}
// This kernel is called when the class number is less than or equal to 512.
template <typename T>
__global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label,
const int class_num) {
int tid = threadIdx.x;
extern __shared__ T d_sum[];
d_sum[tid] = 0;
int cur_idx = tid;
int next_idx = blockIdx.x * class_num + tid;
while (cur_idx < class_num) {
d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
next_idx += blockDim.x;
cur_idx += blockDim.x;
}
__syncthreads();
for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
__syncthreads();
}
T val = d_sum[tid];
val = sum_single_warp<T>(val);
if (tid == 0) Y[blockIdx.x] = -val;
}
// This kernel is called when the class number is larger than 512.
template <typename T, int BlockSize>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int N, const int D) {
__global__ void SoftCrossEntropyKernel2(T* Y, const T* X, const T* label,
const int class_num) {
int tid = threadIdx.x;
__shared__ T d_sum[BlockSize];
int next_idx = blockIdx.x * D + tid;
int next_idx = blockIdx.x * class_num + tid;
d_sum[tid] = 0;
int cur_idx = tid;
while (cur_idx < D) {
while (cur_idx < class_num) {
d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
next_idx += BlockSize;
cur_idx += BlockSize;
}
__syncthreads();
for (int stride = BlockSize >> 1; stride > 0; stride >>= 1) {
for (unsigned int stride = BlockSize >> 1; stride >= 32; stride >>= 1) {
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
__syncthreads();
if (tid < stride) {
next_idx = tid + stride;
d_sum[tid] += d_sum[next_idx];
}
}
__syncthreads();
if (tid == 0) {
Y[blockIdx.x] = -d_sum[0];
}
T val = d_sum[tid];
val = sum_single_warp<T>(val);
if (tid == 0) Y[blockIdx.x] = -val;
}
// TODO(qingqing): make zero setting an common function.
// TODO(qingqing): make zero setting a common function.
template <typename T>
__global__ void zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
......@@ -88,11 +122,9 @@ template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const T* label, const int N,
const int D) {
int row_ids = blockIdx.x * blockDim.x + threadIdx.x;
int col_ids = blockIdx.y * blockDim.y + threadIdx.y;
int ids = row_ids * D + col_ids;
int ids = blockIdx.x * blockDim.x + threadIdx.x;
if (ids < N * D) {
int row_ids = ids / D;
dX[ids] = -label[ids] * dY[row_ids] / X[ids];
}
}
......@@ -112,20 +144,34 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
y->mutable_data<T>(ctx.GetPlace());
auto* y_data = y->data<T>();
int n = x->dims()[0];
int d = x->dims()[1];
int batch_size = x->dims()[0];
int class_num = x->dims()[1];
int block = 512;
int grid = (n + block - 1) / block;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
if (ctx.Attr<bool>("soft_label")) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
grid = d;
SoftCrossEntropyKernel<T, 512><<<grid, block>>>(y_data, x_data,
label_data, n, d);
if (class_num > 512) {
SoftCrossEntropyKernel2<
T, 512><<<batch_size, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(y_data, x_data, label_data, class_num);
} else {
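// block_size is the largest power of two not exceeding class_num, so the
// tree reduction in SoftCrossEntropyKernel1 pairs threads cleanly.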
int block_size = pow(2, int(std::log2(class_num)));
SoftCrossEntropyKernel1<
T><<<batch_size, block_size, block_size * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(y_data, x_data, label_data, class_num);
}
} else {
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
int grid = (batch_size + block - 1) / block;
CrossEntropyKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(y_data, x_data, label_data,
batch_size, class_num);
}
}
};
......@@ -148,25 +194,27 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
int n = x->dims()[0];
int d = x->dims()[1];
int block = 512;
int grid = (n * d + block - 1) / block;
zero<T><<<grid, block>>>(dx_data, n * d);
grid = (n + block - 1) / block;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
zero<T><<<grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, n * d);
if (ctx.Attr<bool>("soft_label")) {
int block_x = 32;
int block_y = 32;
dim3 block(block_x, block_y);
dim3 grid((n + block_x - 1) / block_x, (d + block_y - 1) / block_y);
auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
dx_data, dy_data, x_data, label_data, n, d);
SoftCrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, dy_data, x_data, label_data,
n, d);
} else {
auto* label_data = label->data<int>();
CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
label_data, n, d);
CrossEntropyGradientKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, dy_data, x_data, label_data,
n, d);
}
}
};
......
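The soft-label branch of the gradient kernel above computes dX = -label * dY / X element-wise, with dY broadcast over the classes of each row. A minimal NumPy sketch of the same computation (the function name is illustrative only):

import numpy as np

def soft_label_cross_entropy_grad(x, label, dy):
    # x, label: shape [N, D]; dy: shape [N, 1], broadcast across the class dimension.
    # Mirrors dX[ids] = -label[ids] * dY[row_ids] / X[ids] in the CUDA kernel.
    return -label * dy / x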
......@@ -31,12 +31,8 @@ struct TolerableValue {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
if (x == INFINITY) {
return kApproInf;
}
if (x == -INFINITY) {
return -kApproInf;
}
if (x == INFINITY) return kApproInf;
if (x == -INFINITY) return -kApproInf;
return x;
}
};
......@@ -58,11 +54,8 @@ class CrossEntropyOpKernel : public framework::OpKernel {
auto lbl_mat = EigenMatrix<T>::From(*labels);
auto loss = EigenMatrix<T>::From(*y);
// loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
// prob.log().unaryExpr(TolerableValue<T>());
loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
-((lbl_mat * prob.log())
-((lbl_mat * prob.log().unaryExpr(TolerableValue<T>()))
.sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
} else {
......
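The CPU fix above applies TolerableValue to each log value before it is multiplied by the soft labels, so +/-INFINITY never enters the reduction. A minimal NumPy sketch of the resulting forward computation, assuming the same 1e20 clamp as kApproInf (the helper name is illustrative only):

import numpy as np

def soft_label_cross_entropy(x, label, appro_inf=1e20):
    # x, label: shape [N, D]; returns the per-sample loss with shape [N, 1].
    log_x = np.clip(np.log(x), -appro_inf, appro_inf)  # stands in for TolerableValue
    return (-label * log_x).sum(axis=1, keepdims=True)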
......@@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel {
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTable<T, 128, 8, 8><<<grids, threads>>>(output, table, ids, N, K, D);
LookupTable<T, 128, 8, 8><<<
grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(output, table, ids, N, K, D);
}
};
......@@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids, N,
K, D);
LookupTableGrad<T, 128, 8, 8><<<
grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
context.device_context())
.stream()>>>(d_table, d_output, ids, N, K, D);
}
};
......
......@@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel {
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
// TODO(typhoonzero): launch kernel on specified stream.
// TODO(typhoonzero): refine this kernel.
dim3 threads(256, 1);
dim3 grid(input_height, 1);
KeMatrixTopK<T, 5, 256><<<grid, threads>>>(
output_data, output->dims()[1], indices_data, input_data, input_width,
input_width, int(k));
KeMatrixTopK<T, 5, 256><<<
grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(output_data, output->dims()[1],
indices_data, input_data,
input_width, input_width, int(k));
}
};
......
......@@ -19,7 +19,7 @@ class TestCrossEntropyOp1(OpTest):
dtype="float32")
self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy}
self.attrs = {'soft_label': False}
self.attrs = {"soft_label": False}
def test_check_output(self):
self.check_output()
......@@ -34,7 +34,8 @@ class TestCrossEntropyOp2(OpTest):
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 13
batch_size = 5
# This setting tests threads in more than one warp.
class_num = 37
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
......@@ -43,9 +44,9 @@ class TestCrossEntropyOp2(OpTest):
label /= label.sum(axis=1, keepdims=True)
cross_entropy = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': True}
self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": True}
def test_check_output(self):
self.check_output()
......@@ -61,8 +62,9 @@ class TestCrossEntropyOp3(OpTest):
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 13
class_num = 37
batch_size = 5
# This setting tests all threads in one warp.
class_num = 17
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label_index = np.random.randint(
......@@ -74,9 +76,36 @@ class TestCrossEntropyOp3(OpTest):
dtype="float32")
cross_entropy2 = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': True}
self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": True}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.05)
class TestCrossEntropyOp4(OpTest):
"""Test soft-label cross-entropy.
This unittest tests the GPU kernel when the layer size exceeds 512.
"""
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 2
class_num = 517
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label /= label.sum(axis=1, keepdims=True)
cross_entropy = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32")
self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": True}
def test_check_output(self):
self.check_output()
......