From b8938b448c29c8d9de938890726d2b84884eb56e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 24 Feb 2018 18:37:07 +0800 Subject: [PATCH] refine Sum --- paddle/fluid/operators/elementwise_op_function.h | 7 +++++-- paddle/fluid/platform/cuda_helper.h | 11 +++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index 89050ec27..600524936 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -379,7 +379,8 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( } while (i < h); if (dy) { - val = platform::ReduceSum(val, tid); + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = platform::reduceSum(val, tid, h); if (threadIdx.x == 0) { dy[j] = val; } @@ -454,7 +455,9 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( } if (dy) { - val = platform::ReduceSum(val, threadIdx.x); + int h = pre * post; + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = platform::reduceSum(val, tid, h); if (threadIdx.x == 0) { dy[j] = val; } diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h index 7b6ad1eb2..029ca609a 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_helper.h @@ -68,19 +68,22 @@ template __forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { return __shfl_down(val, delta); } -#define CREATE_SHFL_MASK(mask, predicate) unsigned mask = 0u; +#define CREATE_SHFL_MASK(mask, predicate) mask = 0u; #else #define FULL_WARP_MASK 0xFFFFFFFF #define CREATE_SHFL_MASK(mask, predicate) \ - unsigned mask = __ballot_sync(FULL_WARP_MASK, (predicate)) + mask = __ballot_sync(FULL_WARP_MASK, (predicate)) #endif template -__device__ T ReduceSum(T val, int tid) { +__device__ T reduceSum(T val, int tid, int len) { __shared__ T shm[32]; const int warpSize = 32; + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, tid < len); + for (int offset = warpSize / 2; offset > 0; offset /= 2) - val += __shfl_down_sync(-1U, val, offset); + val += __shfl_down_sync(mask, val, offset); if (tid < warpSize) shm[tid] = 0; -- GitLab