From a679492639ab6ade4f2f69fd26443d7d883c37da Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Wed, 27 Apr 2022 10:08:20 +0800 Subject: [PATCH] fix randperm out of bound bug (#42057) --- paddle/phi/kernels/gpu/randperm_kernel.cu | 39 ++++++++++++----------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/paddle/phi/kernels/gpu/randperm_kernel.cu b/paddle/phi/kernels/gpu/randperm_kernel.cu index 4e488ed470d..94f063512c0 100644 --- a/paddle/phi/kernels/gpu/randperm_kernel.cu +++ b/paddle/phi/kernels/gpu/randperm_kernel.cu @@ -36,26 +36,29 @@ DECLARE_bool(use_curand); namespace phi { -template -__global__ void SwapRepeatKernel( - int* key, T* data, int n, uint64_t seed, uint64_t offset) { +template +__global__ void SwapRepeatKernel(keyT* key_out_data, + dataT* out_data, + int n, + uint64_t seed, + uint64_t offset) { size_t idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx < n) return; + if (idx >= n - 1) return; // out of range - bool first_repeat = false; - if (data[idx] == data[idx + 1]) { + bool is_first_repeat = false; + if (key_out_data[idx] == key_out_data[idx + 1]) { if (idx == 0) { - first_repeat = true; - } else if (data[idx] != data[idx - 1]) { - first_repeat = true; + is_first_repeat = true; + } else if (key_out_data[idx] != key_out_data[idx - 1]) { + is_first_repeat = true; } } - if (!first_repeat) return; + if (!is_first_repeat) return; int repeat_size = 1; for (int i = idx; i < n; ++i) { - if (data[i] == data[i + 1]) { + if (key_out_data[i] == key_out_data[i + 1]) { ++repeat_size; } else { break; @@ -74,9 +77,9 @@ __global__ void SwapRepeatKernel( uint32_t r = hiprand(&state) % (i + 1); #endif if (r != i) { - T tmp = data[idx + i]; - data[idx + i] = data[idx + r]; - data[idx + r] = tmp; + dataT tmp = out_data[idx + i]; + out_data[idx + i] = out_data[idx + r]; + out_data[idx + r] = tmp; } } } @@ -138,10 +141,10 @@ void RandpermRawKernel( auto seed_offset = gen_cuda->IncrementOffset(n); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n); - SwapRepeatKernel<<>>( + SwapRepeatKernel<<>>( key_out.data(), out_data, n, seed_offset.first, seed_offset.second); } -- GitLab