未验证 提交 a6794926 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

Fix randperm out-of-bounds bug (#42057)

上级 b20683c0
......@@ -36,26 +36,29 @@ DECLARE_bool(use_curand);
namespace phi {
template <typename T>
__global__ void SwapRepeatKernel(
int* key, T* data, int n, uint64_t seed, uint64_t offset) {
template <typename keyT, typename dataT>
__global__ void SwapRepeatKernel(keyT* key_out_data,
dataT* out_data,
int n,
uint64_t seed,
uint64_t offset) {
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
if (idx < n) return;
if (idx >= n - 1) return; // out of range
bool first_repeat = false;
if (data[idx] == data[idx + 1]) {
bool is_first_repeat = false;
if (key_out_data[idx] == key_out_data[idx + 1]) {
if (idx == 0) {
first_repeat = true;
} else if (data[idx] != data[idx - 1]) {
first_repeat = true;
is_first_repeat = true;
} else if (key_out_data[idx] != key_out_data[idx - 1]) {
is_first_repeat = true;
}
}
if (!first_repeat) return;
if (!is_first_repeat) return;
int repeat_size = 1;
for (int i = idx; i < n; ++i) {
if (data[i] == data[i + 1]) {
if (key_out_data[i] == key_out_data[i + 1]) {
++repeat_size;
} else {
break;
......@@ -74,9 +77,9 @@ __global__ void SwapRepeatKernel(
uint32_t r = hiprand(&state) % (i + 1);
#endif
if (r != i) {
T tmp = data[idx + i];
data[idx + i] = data[idx + r];
data[idx + r] = tmp;
dataT tmp = out_data[idx + i];
out_data[idx + i] = out_data[idx + r];
out_data[idx + r] = tmp;
}
}
}
......@@ -138,10 +141,10 @@ void RandpermRawKernel(
auto seed_offset = gen_cuda->IncrementOffset(n);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
SwapRepeatKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
SwapRepeatKernel<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
dev_ctx.stream()>>>(
key_out.data<int>(), out_data, n, seed_offset.first, seed_offset.second);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册