From 22b9ab052d157d63ef2dd6029da76ecf41870d4b Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 24 Feb 2018 14:32:49 +0800 Subject: [PATCH] refine Sum --- .../fluid/operators/elementwise_op_function.h | 46 ++++++++++++++----- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index 5c783035309..a6c73598e00 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -357,6 +357,14 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, } } #ifdef __NVCC__ +// __shfl_down has been deprecated as of CUDA 9.0 +#if CUDA_VERSION < 9000 +template +__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { + return __shfl_down(val, delta); +} +#endif + template static __global__ void ElemwiseGradBroadcast1CUDAKernel( const T* x, const T* y, const T* out, const T* dout, int h, int w, @@ -381,15 +389,22 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( } while (i < h); if (dy) { + T val = shm[threadIdx.x]; + int warpSize = 32; + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += __shfl_down_sync(0, val, offset); + __syncthreads(); + shm[tid] = 0; + if (threadIdx.x % 32 == 0) { + shm[threadIdx.x / 32] = val; + } - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = shm[threadIdx.x]; + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += __shfl_down_sync(0, val, offset); - // Sum, could be optimized if (threadIdx.x == 0) { - for (int k = 1; k < h; ++k) { - shm[0] += shm[k]; - } dy[j] = shm[0]; } } @@ -468,15 +483,22 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( } if (dy) { + T val = shm[threadIdx.x]; + int warpSize = 32; + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += __shfl_down_sync(0, val, offset); + __syncthreads(); - int h = pre * post; - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + shm[tid] = 0; + if (threadIdx.x % 32 == 0) { + shm[threadIdx.x / 32] = val; + } - // Sum, could be optimized - if (tid == 0) { - for (int i = 1; i < h; ++i) { - shm[0] += shm[i]; - } + val = shm[threadIdx.x]; + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += __shfl_down_sync(0, val, offset); + + if (threadIdx.x == 0) { dy[j] = shm[0]; } } -- GitLab