未验证 提交 a6794926 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

fix randperm out of bound bug (#42057)

上级 b20683c0
...@@ -36,26 +36,29 @@ DECLARE_bool(use_curand); ...@@ -36,26 +36,29 @@ DECLARE_bool(use_curand);
namespace phi { namespace phi {
template <typename T> template <typename keyT, typename dataT>
__global__ void SwapRepeatKernel( __global__ void SwapRepeatKernel(keyT* key_out_data,
int* key, T* data, int n, uint64_t seed, uint64_t offset) { dataT* out_data,
int n,
uint64_t seed,
uint64_t offset) {
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x); size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
if (idx < n) return; if (idx >= n - 1) return; // out of range
bool first_repeat = false; bool is_first_repeat = false;
if (data[idx] == data[idx + 1]) { if (key_out_data[idx] == key_out_data[idx + 1]) {
if (idx == 0) { if (idx == 0) {
first_repeat = true; is_first_repeat = true;
} else if (data[idx] != data[idx - 1]) { } else if (key_out_data[idx] != key_out_data[idx - 1]) {
first_repeat = true; is_first_repeat = true;
} }
} }
if (!first_repeat) return; if (!is_first_repeat) return;
int repeat_size = 1; int repeat_size = 1;
for (int i = idx; i < n; ++i) { for (int i = idx; i < n; ++i) {
if (data[i] == data[i + 1]) { if (key_out_data[i] == key_out_data[i + 1]) {
++repeat_size; ++repeat_size;
} else { } else {
break; break;
...@@ -74,9 +77,9 @@ __global__ void SwapRepeatKernel( ...@@ -74,9 +77,9 @@ __global__ void SwapRepeatKernel(
uint32_t r = hiprand(&state) % (i + 1); uint32_t r = hiprand(&state) % (i + 1);
#endif #endif
if (r != i) { if (r != i) {
T tmp = data[idx + i]; dataT tmp = out_data[idx + i];
data[idx + i] = data[idx + r]; out_data[idx + i] = out_data[idx + r];
data[idx + r] = tmp; out_data[idx + r] = tmp;
} }
} }
} }
...@@ -138,10 +141,10 @@ void RandpermRawKernel( ...@@ -138,10 +141,10 @@ void RandpermRawKernel(
auto seed_offset = gen_cuda->IncrementOffset(n); auto seed_offset = gen_cuda->IncrementOffset(n);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
SwapRepeatKernel<T><<<config.block_per_grid.x, SwapRepeatKernel<<<config.block_per_grid.x,
config.thread_per_block.x, config.thread_per_block.x,
0, 0,
dev_ctx.stream()>>>( dev_ctx.stream()>>>(
key_out.data<int>(), out_data, n, seed_offset.first, seed_offset.second); key_out.data<int>(), out_data, n, seed_offset.first, seed_offset.second);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册