提交 201c2bcf 编写于 作者: C caoying03

delete redundant codes.

上级 6735585b
...@@ -42,9 +42,8 @@ __device__ __forceinline__ T sum_single_warp(T val) { ...@@ -42,9 +42,8 @@ __device__ __forceinline__ T sum_single_warp(T val) {
return val; return val;
} }
// This kernel is called when the class number is less than or equal to 512.
template <typename T> template <typename T>
__global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label, __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int class_num) { const int class_num) {
int tid = threadIdx.x; int tid = threadIdx.x;
extern __shared__ T d_sum[]; extern __shared__ T d_sum[];
...@@ -69,33 +68,6 @@ __global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label, ...@@ -69,33 +68,6 @@ __global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label,
if (tid == 0) Y[blockIdx.x] = -val; if (tid == 0) Y[blockIdx.x] = -val;
} }
// This kernel is called when the class number is larger than 512.
template <typename T, int BlockSize>
__global__ void SoftCrossEntropyKernel2(T* Y, const T* X, const T* label,
const int class_num) {
int tid = threadIdx.x;
__shared__ T d_sum[BlockSize];
int next_idx = blockIdx.x * class_num + tid;
d_sum[tid] = 0;
int cur_idx = tid;
while (cur_idx < class_num) {
d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
next_idx += BlockSize;
cur_idx += BlockSize;
}
__syncthreads();
for (unsigned int stride = BlockSize >> 1; stride >= 32; stride >>= 1) {
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
__syncthreads();
}
T val = d_sum[tid];
val = sum_single_warp<T>(val);
if (tid == 0) Y[blockIdx.x] = -val;
}
// TODO(qingqing): make zero setting a common function. // TODO(qingqing): make zero setting a common function.
template <typename T> template <typename T>
__global__ void zero(T* X, const int N) { __global__ void zero(T* X, const int N) {
...@@ -146,26 +118,19 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { ...@@ -146,26 +118,19 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
int batch_size = x->dims()[0]; int batch_size = x->dims()[0];
int class_num = x->dims()[1]; int class_num = x->dims()[1];
int block = 512;
if (ctx.Attr<bool>("soft_label")) { if (ctx.Attr<bool>("soft_label")) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
if (class_num > 512) { int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
SoftCrossEntropyKernel2<
T, 512><<<batch_size, block, 0, SoftCrossEntropyKernel<
reinterpret_cast<const platform::CUDADeviceContext&>( T><<<batch_size, block, block * sizeof(T),
ctx.device_context())
.stream()>>>(y_data, x_data, label_data, class_num);
} else {
int block_size = pow(2, int(std::log2(class_num)));
SoftCrossEntropyKernel1<
T><<<batch_size, block_size, block_size * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>( reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context()) ctx.device_context())
.stream()>>>(y_data, x_data, label_data, class_num); .stream()>>>(y_data, x_data, label_data, class_num);
}
} else { } else {
auto* label_data = ctx.Input<Tensor>("Label")->data<int>(); auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
int block = 512;
int grid = (batch_size + block - 1) / block; int grid = (batch_size + block - 1) / block;
CrossEntropyKernel<T><<< CrossEntropyKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>( grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
......
...@@ -4,19 +4,21 @@ from op_test import OpTest ...@@ -4,19 +4,21 @@ from op_test import OpTest
class TestCrossEntropyOp1(OpTest): class TestCrossEntropyOp1(OpTest):
"""Test standard cross-entropy, with index representation of labels. """Test cross-entropy with discrete one-hot labels.
""" """
def setUp(self): def setUp(self):
self.op_type = "cross_entropy" self.op_type = "cross_entropy"
batch_size = 30 batch_size = 30
class_num = 10 class_num = 10
X = np.random.uniform(0.1, 1.0, X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32") label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
cross_entropy = np.asmatrix( cross_entropy = np.asmatrix(
[[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])], [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
dtype="float32") dtype="float32")
self.inputs = {"X": X, "Label": label} self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": False} self.attrs = {"soft_label": False}
...@@ -29,14 +31,14 @@ class TestCrossEntropyOp1(OpTest): ...@@ -29,14 +31,14 @@ class TestCrossEntropyOp1(OpTest):
class TestCrossEntropyOp2(OpTest): class TestCrossEntropyOp2(OpTest):
"""Test soft-label cross-entropy, with vecterized soft labels. """Test cross-entropy with vectorized soft labels.
""" """
def setUp(self): def setUp(self):
self.op_type = "cross_entropy" self.op_type = "cross_entropy"
batch_size = 5 batch_size = 5
# this setting tests threads in more than one wrap.
class_num = 37 class_num = 37
X = np.random.uniform(0.1, 1.0, X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
label = np.random.uniform(0.1, 1.0, label = np.random.uniform(0.1, 1.0,
...@@ -44,6 +46,7 @@ class TestCrossEntropyOp2(OpTest): ...@@ -44,6 +46,7 @@ class TestCrossEntropyOp2(OpTest):
label /= label.sum(axis=1, keepdims=True) label /= label.sum(axis=1, keepdims=True)
cross_entropy = (-label * np.log(X)).sum( cross_entropy = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32") axis=1, keepdims=True).astype("float32")
self.inputs = {"X": X, "Label": label} self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": True} self.attrs = {"soft_label": True}
...@@ -56,15 +59,14 @@ class TestCrossEntropyOp2(OpTest): ...@@ -56,15 +59,14 @@ class TestCrossEntropyOp2(OpTest):
class TestCrossEntropyOp3(OpTest): class TestCrossEntropyOp3(OpTest):
"""Test one-hot cross-entropy, with vecterized one-hot representation of """Test cross-entropy with vectorized one-hot representation of labels.
labels.
""" """
def setUp(self): def setUp(self):
self.op_type = "cross_entropy" self.op_type = "cross_entropy"
batch_size = 5 batch_size = 5
# this setting tests all threads in one wrap.
class_num = 17 class_num = 17
X = np.random.uniform(0.1, 1.0, X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
label_index = np.random.randint( label_index = np.random.randint(
...@@ -76,33 +78,7 @@ class TestCrossEntropyOp3(OpTest): ...@@ -76,33 +78,7 @@ class TestCrossEntropyOp3(OpTest):
dtype="float32") dtype="float32")
cross_entropy2 = (-label * np.log(X)).sum( cross_entropy2 = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32") axis=1, keepdims=True).astype("float32")
self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": True}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.05)
class TestCrossEntropyOp4(OpTest):
"""Test soft-label cross-entropy.
This unittest tests the gpu kernel for layer size excesses 512.
"""
def setUp(self):
self.op_type = "cross_entropy"
batch_size = 2
class_num = 517
X = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label /= label.sum(axis=1, keepdims=True)
cross_entropy = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32")
self.inputs = {"X": X, "Label": label} self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
self.attrs = {"soft_label": True} self.attrs = {"soft_label": True}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册