From 10c29185587230cab15e187c4b69975b74360d9c Mon Sep 17 00:00:00 2001 From: TFbunny Date: Sat, 22 Aug 2020 09:25:48 -0400 Subject: [PATCH] fix speed bottleneck for SrandInit and Shuffle in GPU-RandomChoiceWithMask --- .../cuda_impl/random_choice_with_mask_impl.cu | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu index 436f96213..20478bcb8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cu @@ -155,7 +155,7 @@ __global__ void Sort(const int ceil_power2, T *rank_buff) { __global__ void SrandInit(const int ceil_power2, curandState *globalState, const int seedc) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < ceil_power2; i += blockDim.x * gridDim.x) { - curand_init(seedc, i, 0, &globalState[i]); + curand_init(seedc, threadIdx.x, 0, &globalState[i]); } } @@ -163,21 +163,20 @@ template <typename T> __global__ void Shuffle(const int ceil_power2, curandState *globalState, T *rank_buff) { int limit = ceil_power2 + 1; int value; - for (size_t i = 2; i <= ceil_power2; i <<= 1) { - for (size_t j = (i >> 1); j > 0; j >>= 1) { - for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { - size_t tid_comp = tid ^ j; - if (tid_comp > tid) { - value = static_cast<int>(curand(&globalState[tid])); - if (value & 1) { - if (rank_buff[tid] != limit && rank_buff[tid_comp] != limit) { - Swap(&rank_buff[tid], &rank_buff[tid_comp]); - } + size_t i = ceil_power2; + for (size_t j = (i >> 1); j > 0; j >>= 1) { + for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { + size_t tid_comp = tid ^ j; + if (tid_comp > tid) { + value = static_cast<int>(curand(&globalState[tid])); + if (value & 1) { + if (rank_buff[tid] != limit && 
rank_buff[tid_comp] != limit) { + Swap(&rank_buff[tid], &rank_buff[tid_comp]); } } } - __syncthreads(); } + __syncthreads(); } } -- GitLab