From a82883922eefd80a0f7139cfe317a592ffd24645 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 24 Feb 2018 16:14:12 +0800 Subject: [PATCH] follow comments --- .../fluid/operators/elementwise_op_function.h | 71 ++++--------------- paddle/fluid/platform/cuda_helper.h | 39 ++++++++++ 2 files changed, 52 insertions(+), 58 deletions(-) diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index a6c73598e00..89050ec27af 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -20,6 +20,7 @@ limitations under the License. */ #ifdef __NVCC__ #include +#include "paddle/fluid/platform/cuda_helper.h" constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; #endif @@ -357,25 +358,14 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, } } #ifdef __NVCC__ -// __shfl_down has been deprecated as of CUDA 9.0 -#if CUDA_VERSION < 9000 -template -__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { - return __shfl_down(val, delta); -} -#endif - template static __global__ void ElemwiseGradBroadcast1CUDAKernel( const T* x, const T* y, const T* out, const T* dout, int h, int w, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { - extern __shared__ char shm_buffer[]; - T* shm = reinterpret_cast(shm_buffer); - int j = blockIdx.x; int i = threadIdx.x; int tid = threadIdx.x; - shm[tid] = 0; + T val = 0; do { int x_offset = i * w + j; @@ -383,29 +373,15 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); } if (dy) { - shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); } i += ELEMWISE_MAX_BLOCK_DIM; } while (i < h); if (dy) { - T val = shm[threadIdx.x]; - int warpSize = 32; - for (int offset = warpSize / 2; offset > 0; offset /= 2) - val += __shfl_down_sync(0, val, offset); - - __syncthreads(); - shm[tid] = 0; - if (threadIdx.x % 32 == 0) { - shm[threadIdx.x / 32] = val; - } - - val = shm[threadIdx.x]; - for (int offset = warpSize / 2; offset > 0; offset /= 2) - val += __shfl_down_sync(0, val, offset); - + val = platform::ReduceSum(val, tid); if (threadIdx.x == 0) { - dy[j] = shm[0]; + dy[j] = val; } } } @@ -417,10 +393,8 @@ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x, T* dx, T* dy) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); int gird_size = w; - int shared_mem_size = block_size * sizeof(T); - ElemwiseGradBroadcast1CUDAKernel<<>>(x, y, out, dout, h, w, dx_op, - dy_op, dx, dy); + ElemwiseGradBroadcast1CUDAKernel<<>>( + x, y, out, dout, h, w, dx_op, dy_op, dx, dy); } #endif @@ -451,7 +425,6 @@ static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out, } #ifdef __NVCC__ - template static __global__ void ElemwiseGradBroadcast2CUDAKernel( const T* x, const T* y, const T* out, const T* dout, int pre, int n, @@ -459,9 +432,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( int tid = threadIdx.x; int j = blockIdx.x; - extern __shared__ char shm_buffer[]; - T* shm = reinterpret_cast(shm_buffer); - shm[tid] = 0; + T val = 0; int ttid = tid; while (true) { @@ -476,30 +447,16 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( } if (dy != nullptr) { - shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); } ttid += ELEMWISE_MAX_BLOCK_DIM; } if (dy) { - T val = shm[threadIdx.x]; - int warpSize = 32; - for (int offset = warpSize / 2; offset > 0; offset /= 2) - val += __shfl_down_sync(0, val, offset); - - __syncthreads(); - shm[tid] = 0; - if (threadIdx.x % 32 == 0) { - shm[threadIdx.x / 32] = val; - } - - val = shm[threadIdx.x]; - for (int offset = warpSize / 2; offset > 0; offset /= 2) - val += __shfl_down_sync(0, val, offset); - + val = platform::ReduceSum(val, threadIdx.x); if (threadIdx.x == 0) { - dy[j] = shm[0]; + dy[j] = val; } } } @@ -511,10 +468,8 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x, DY_OP dy_op, T* dx, T* dy) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); int gird_size = n; - int shared_mem_size = block_size * sizeof(T); - ElemwiseGradBroadcast2CUDAKernel<<>>(x, y, out, dout, pre, n, post, - dx_op, dy_op, dx, dy); + ElemwiseGradBroadcast2CUDAKernel<<>>( + x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy); } #endif diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h index 881d611d4ac..7b6ad1eb205 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_helper.h @@ -62,5 +62,44 @@ CUDA_ATOMIC_WRAPPER(Add, double) { } #endif +// __shfl_down has been deprecated as of CUDA 9.0. +#if CUDA_VERSION < 9000 +template +__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { + return __shfl_down(val, delta); +} +#define CREATE_SHFL_MASK(mask, predicate) unsigned mask = 0u; +#else +#define FULL_WARP_MASK 0xFFFFFFFF +#define CREATE_SHFL_MASK(mask, predicate) \ + unsigned mask = __ballot_sync(FULL_WARP_MASK, (predicate)) +#endif + +template +__device__ T ReduceSum(T val, int tid) { + __shared__ T shm[32]; + const int warpSize = 32; + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += __shfl_down_sync(-1U, val, offset); + + if (tid < warpSize) shm[tid] = 0; + + __syncthreads(); + + if (tid % warpSize == 0) { + shm[tid / warpSize] = val; + } + + CREATE_SHFL_MASK(mask, tid < warpSize); + + if (tid < warpSize) { + val = shm[tid]; + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += __shfl_down_sync(mask, val, offset); + } + + return val; +} + } // namespace platform } // namespace paddle -- GitLab