From 58f40144c91f6dfe39f63a6fde89e8baa57f2423 Mon Sep 17 00:00:00 2001 From: wawltor Date: Fri, 6 May 2022 15:02:32 +0800 Subject: [PATCH] Fix the race condition in cumsum operator (#42205) (#42500) * Fix the race condition in cumsum operator * Optimize cumsum operator Co-authored-by: Leo Chen <39020268+leo0519@users.noreply.github.com> --- paddle/phi/kernels/gpu/cumsum_kernel.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/paddle/phi/kernels/gpu/cumsum_kernel.cu b/paddle/phi/kernels/gpu/cumsum_kernel.cu index a253e6f4ad2..9073bcb6c02 100644 --- a/paddle/phi/kernels/gpu/cumsum_kernel.cu +++ b/paddle/phi/kernels/gpu/cumsum_kernel.cu @@ -39,14 +39,12 @@ __device__ void BlockReverse( int tx = threadIdx.x; int offset = tx; - int in_index = src_base + offset; - if (offset >= valid_item) { - sh_mem[offset] = 0; - } else { - int sh_mem_index = BLOCK_SIZE - offset - 1; - T data = idata[in_index]; - sh_mem[sh_mem_index] = data; + T src_data = 0; + int src_offset = BLOCK_SIZE - offset - 1; + if (src_offset < valid_item) { + src_data = idata[src_base + src_offset]; } + sh_mem[offset] = src_data; __syncthreads(); int out_index = dst_base - offset; -- GitLab