Commit a3a8a090 authored by caoying03

Optimize the cross entropy kernel by using a block-level reduction.

Parent 414a7a1e
@@ -32,16 +32,33 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   }
 }
 
-template <typename T>
+template <typename T, int blockSize>
 __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                        const int N, const int D) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    T sum = static_cast<T>(0);
-    for (int j = 0; j < D; j++) {
-      sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
-    }
-    Y[i] = -sum;
-  }
+  int tid = threadIdx.x;
+  __shared__ T d_sum[blockSize];
+  int next_idx = blockIdx.x * D + tid;
+
+  d_sum[tid] = 0;
+  int cur_idx = tid;
+  while (cur_idx < D) {
+    d_sum[tid] += tolerable_value(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockSize;
+    cur_idx += blockSize;
+  }
+  __syncthreads();
+
+  for (int stride = blockSize >> 1; stride > 0; stride >>= 1) {
+    __syncthreads();
+    if (tid < stride) {
+      next_idx = tid + stride;
+      d_sum[tid] += d_sum[next_idx];
+    }
+  }
+
+  __syncthreads();
+  if (tid == 0) {
+    Y[blockIdx.x] = -d_sum[0];
+  }
 }
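The old kernel had a single thread walk an entire row of length D serially, so per-row parallelism was zero. The new kernel assigns one thread block per row: each thread accumulates a strided slice of the row's label * log(X) products into shared memory, and a tree reduction then folds the blockSize partial sums in log2(blockSize) steps. Below is a minimal standalone sketch of that same pattern, not the commit's code: the kernel name rowNegDot, the sizes, and the plain logf (standing in for tolerable_value(std::log(...)), whose clamping helper is defined elsewhere in the file) are all illustrative.

#include <cstdio>
#include <cuda_runtime.h>

// Each block reduces one row: out[row] = -sum_j w[row][j] * log(x[row][j]).
template <typename T, int blockSize>
__global__ void rowNegDot(T* out, const T* x, const T* w, const int D) {
  __shared__ T partial[blockSize];
  const int tid = threadIdx.x;
  const int base = blockIdx.x * D;  // first element of this block's row

  // Phase 1: strided pass; thread tid handles columns tid, tid+blockSize, ...
  T sum = 0;
  for (int j = tid; j < D; j += blockSize) {
    sum += w[base + j] * logf(x[base + j]);
  }
  partial[tid] = sum;
  __syncthreads();

  // Phase 2: tree reduction; each step folds the upper half of the live
  // region onto the lower half, halving the number of active threads.
  for (int stride = blockSize >> 1; stride > 0; stride >>= 1) {
    if (tid < stride) partial[tid] += partial[tid + stride];
    __syncthreads();
  }
  if (tid == 0) out[blockIdx.x] = -partial[0];
}

int main() {
  const int kRows = 4, kCols = 1000;  // illustrative sizes
  float *x, *w, *y;
  cudaMallocManaged(&x, kRows * kCols * sizeof(float));
  cudaMallocManaged(&w, kRows * kCols * sizeof(float));
  cudaMallocManaged(&y, kRows * sizeof(float));
  for (int i = 0; i < kRows * kCols; ++i) {
    x[i] = 0.5f;
    w[i] = 1.0f / kCols;  // each row of w sums to 1, like a soft label
  }

  rowNegDot<float, 128><<<kRows, 128>>>(y, x, w, kCols);  // one block per row
  cudaDeviceSynchronize();

  for (int i = 0; i < kRows; ++i) {
    printf("y[%d] = %f (expect -log(0.5) = 0.6931)\n", i, y[i]);
  }
  cudaFree(x);
  cudaFree(w);
  cudaFree(y);
  return 0;
}

The commit's kernel follows the same two phases but accumulates directly into d_sum[tid] and is instantiated with blockSize = 512; because each block writes Y[blockIdx.x], every block produces exactly one output element.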
@@ -104,8 +121,9 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
     // base on ExecutionContext.
     if (ctx.Attr<int>("soft_label") == 1) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
-                                                 d);
+      grid = d;
+      SoftCrossEntropyKernel<T, 512><<<grid, block>>>(y_data, x_data,
+                                                      label_data, n, d);
     } else {
       auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
       CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
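For reference, both the removed loop and the new reduction compute, for each sample $i$, the soft-label cross entropy

$$Y_i = -\sum_{j=1}^{D} \mathrm{label}_{ij} \log X_{ij}, \qquad i = 1, \dots, N,$$

so there is exactly one output per row, which is why the rewrite uses one block per row. Two launch constraints follow from the new code: the blockSize template argument (512 here) must equal the number of threads per block, and since each block writes Y[blockIdx.x], the grid must supply one block per sample.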
@@ -19,7 +19,7 @@ class TestCrossEntropyOp1(OpTest):
             dtype="float32")
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
-        self.attrs = {'soft_label': 0}
+        self.attrs = {"soft_label": 0}

     def test_check_output(self):
         self.check_output()
@@ -34,8 +34,8 @@ class TestCrossEntropyOp2(OpTest):

     def setUp(self):
         self.op_type = "cross_entropy"
-        batch_size = 10
-        class_num = 5
+        batch_size = 13
+        class_num = 37
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
         label = np.random.uniform(0.1, 1.0,
@@ -43,15 +43,16 @@ class TestCrossEntropyOp2(OpTest):
         label /= label.sum(axis=1, keepdims=True)
         cross_entropy = (-label * np.log(X)).sum(
             axis=1, keepdims=True).astype("float32")
-        self.inputs = {'X': X, 'Label': label}
-        self.outputs = {'Y': cross_entropy}
-        self.attrs = {'soft_label': 1}
+        self.inputs = {"X": X, "Label": label}
+        self.outputs = {"Y": cross_entropy}
+        self.attrs = {"soft_label": 1}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Y')
+        self.check_grad(["X"], "Y", max_relative_error=0.05)


 class TestCrossEntropyOp3(OpTest):
@@ -61,8 +62,8 @@ class TestCrossEntropyOp3(OpTest):

     def setUp(self):
         self.op_type = "cross_entropy"
-        batch_size = 30
-        class_num = 10
+        batch_size = 13
+        class_num = 37
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
         label_index = np.random.randint(
@@ -74,15 +75,15 @@ class TestCrossEntropyOp3(OpTest):
             dtype="float32")
         cross_entropy2 = (-label * np.log(X)).sum(
             axis=1, keepdims=True).astype("float32")
-        self.inputs = {'X': X, 'Label': label}
-        self.outputs = {'Y': cross_entropy}
-        self.attrs = {'soft_label': 1}
+        self.inputs = {"X": X, "Label": label}
+        self.outputs = {"Y": cross_entropy}
+        self.attrs = {"soft_label": 1}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Y')
+        self.check_grad(["X"], "Y", max_relative_error=0.05)


 if __name__ == "__main__":
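The loosened check_grad tolerance lines up with the analytic gradient of the soft-label loss above: differentiating per element gives

$$\frac{\partial Y_i}{\partial X_{ij}} = -\frac{\mathrm{label}_{ij}}{X_{ij}},$$

which grows large as $X_{ij}$ approaches the 0.1 lower bound of the test data, so a 5% relative-error budget for the float32 numeric-vs-analytic comparison is a plausible reading of the change (the commit itself does not state the rationale).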