From 90d73c79c3c6645abf1d737cedc43341e94ce45d Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Wed, 2 May 2018 11:16:44 +0800
Subject: [PATCH] fix shfl_sync for CUDA8.0

---
 paddle/cuda/include/hl_base.h           | 15 +++++++++++++++
 paddle/cuda/src/hl_cuda_lstm.cu         | 14 +++++++++-----
 paddle/cuda/src/hl_top_k.cu             |  5 ++++-
 paddle/fluid/platform/cuda_primitives.h |  4 ----
 4 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/paddle/cuda/include/hl_base.h b/paddle/cuda/include/hl_base.h
index 6c4f09dacb4..b979aa7723e 100644
--- a/paddle/cuda/include/hl_base.h
+++ b/paddle/cuda/include/hl_base.h
@@ -228,6 +228,21 @@ extern __thread cudaStream_t default_stream;
     << "CUDA error: " << hl_get_device_error_string((size_t)err); \
   }

+// __shfl has been deprecated as of CUDA 9.0.
+#if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T
+__shfl_sync(unsigned, T val, int src_line, int width) {
+  return __shfl(val, src_line, width);
+}
+
+#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
+#else
+#define FULL_WARP_MASK 0xFFFFFFFF
+#define CREATE_SHFL_MASK(mask, predicate) \
+  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
+
 #endif /* __NVCC__ */

 #endif /* HL_BASE_H_ */
diff --git a/paddle/cuda/src/hl_cuda_lstm.cu b/paddle/cuda/src/hl_cuda_lstm.cu
index 38371366f8e..e30fcddffdf 100644
--- a/paddle/cuda/src/hl_cuda_lstm.cu
+++ b/paddle/cuda/src/hl_cuda_lstm.cu
@@ -341,12 +341,15 @@ void hl_lstm_parallel_forward(real *gateValue,
 }

 __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
-  int addr = idx % 32;
+  const int warp_size = 32;
+  int addr = idx % warp_size;
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, addr < warp_size);
 #pragma unroll
   for (int k = 1; k < 32; k++) {
     // rSrc[k] = __shfl_sync(rSrc[k], (threadIdx.x + k) % 32, 32);
-    addr = __shfl_sync(addr, (idx + 1) % 32, 32);
-    a[k] = __shfl_sync(a[k], addr, 32);
+    addr = __shfl_sync(mask, addr, (idx + 1) % 32, 32);
+    a[k] = __shfl_sync(mask, a[k], addr, 32);
   }

 #pragma unroll
@@ -360,10 +363,11 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
   }

   addr = (32 - idx) % 32;
+  CREATE_SHFL_MASK(mask, idx % 32 < warp_size);
 #pragma unroll
   for (int k = 0; k < 32; k++) {
-    a[k] = __shfl_sync(a[k], addr, 32);
-    addr = __shfl_sync(addr, (idx + 31) % 32, 32);
+    a[k] = __shfl_sync(mask, a[k], addr, 32);
+    addr = __shfl_sync(mask, addr, (idx + 31) % 32, 32);
   }
 }

diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu
index 94c9cceb2c3..59ba552f560 100644
--- a/paddle/cuda/src/hl_top_k.cu
+++ b/paddle/cuda/src/hl_top_k.cu
@@ -244,13 +244,16 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
     if (--beamSize == 0) break;
     __syncthreads();

+    unsigned mask = 0u;
+    // CREATE_SHFL_MASK(mask, tid < len);
+
     if (tid == maxId[0]) {
       if (beam < maxLength) {
         shTopK[tid] = topK[beam];
       }
     }
     if (maxId[0] / 32 == warp) {
-      if (__shfl_sync(beam, (maxId[0]) % 32, 32) == maxLength) break;
+      if (__shfl_sync(mask, beam, (maxId[0]) % 32, 32) == maxLength) break;
     }
   }
 }
diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h
index 46b97043ab3..866ff30a8be 100644
--- a/paddle/fluid/platform/cuda_primitives.h
+++ b/paddle/fluid/platform/cuda_primitives.h
@@ -74,10 +74,6 @@ __forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
 }
 #define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
 #else
-template <typename T>
-__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) {
-  return __shfl_down(mask, val, delta);
-}
 #define FULL_WARP_MASK 0xFFFFFFFF
 #define CREATE_SHFL_MASK(mask, predicate) \
   mask = __ballot_sync(FULL_WARP_MASK, (predicate))
--
GitLab
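
Note (not part of the patch): a minimal, self-contained sketch of how the CREATE_SHFL_MASK / __shfl_sync shim added to hl_base.h is meant to be used. Only the #if CUDA_VERSION block is taken from the patch; the broadcast_lane0 kernel and the host driver are illustrative assumptions, not PaddlePaddle code.

// Hypothetical usage sketch. The #if block mirrors the shim from
// hl_base.h above; everything else is an assumption for illustration.
#include <cstdio>
#include <cuda.h>  // defines CUDA_VERSION
#include <cuda_runtime.h>

#if CUDA_VERSION < 9000
// CUDA 8.0 has no *_sync intrinsics: forward to the (later deprecated)
// __shfl and discard the mask argument.
template <typename T>
__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
                                         int width) {
  return __shfl(val, src_line, width);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
// CUDA 9.0+: build a real participation mask via __ballot_sync and call
// the native __shfl_sync.
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

// Broadcast lane 0's value to all 32 lanes of a warp; the same source
// compiles under both toolkits thanks to the shim above.
__global__ void broadcast_lane0(int *out) {
  int lane = threadIdx.x % 32;
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, lane < 32);       // every lane participates
  int v = __shfl_sync(mask, lane, 0, 32);  // v == 0 on all lanes
  out[threadIdx.x] = v;
}

int main() {
  int *d_out;
  int h_out[32];
  cudaMalloc(&d_out, sizeof(h_out));
  broadcast_lane0<<<1, 32>>>(d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("lane 5 sees %d (expected 0)\n", h_out[5]);
  cudaFree(d_out);
  return 0;
}

The *_sync variants exist because, starting with Volta's independent thread scheduling, warp intrinsics require an explicit mask of participating lanes; pre-Volta toolkits have no such notion, which is why the CUDA 8.0 branch of the shim can safely leave the mask as 0u and ignore it.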