// NOTE(review): this span arrived as a whitespace-mangled unified diff; the
// code below reconstructs the four functions the patch adds to
// math_cuda_utils.h. The mangling stripped every angle-bracket token, so the
// template parameter lists are restored as `template <typename T>` to match
// the sibling *Max reductions referenced in the hunk context.

/* Calculate the minimum of all elements in a warp.
 * `lane_mask` selects the participating lanes (0xffffffff for a full warp);
 * every participating lane returns the warp-wide minimum. */
template <typename T>
__inline__ __device__ T warpReduceMin(T val, unsigned lane_mask) {
  for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
    val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
#else
    val = min(val, __shfl_xor(val, mask, warpSize));
#endif
  return val;
}

/* Calculate the minimum of all elements in a warp when actual quantity of
 * threads are less than warpSize. */
template <typename T>
__inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) {
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
  T warp_val = __shfl_sync(lane_mask, val, 0, warpSize);
#else
  T warp_val = __shfl(
      val, 0, warpSize);  // To fulfill the data in each thread of this warp.
#endif
  // NOTE(review): the broadcast result above is immediately overwritten, so
  // the shuffle appears to serve only as a lane-converging no-op. Preserved
  // as-is from the patch — confirm intent before simplifying.
  warp_val = val;

  for (int offset = HALF_WARP; offset > 0; offset >>= 1)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
    warp_val =
        min(warp_val, __shfl_down_sync(lane_mask, warp_val, offset, warpSize));
#else
    warp_val = min(warp_val, __shfl_down(warp_val, offset, warpSize));
#endif
  return warp_val;
}

/* Calculate the minimum of all elements in a block.
 * `mask` is forwarded to the warp-level shuffles. Per-warp partials are
 * staged in a WARP_SIZE-slot shared array, so this assumes at most
 * WARP_SIZE warps per block. */
template <typename T>
__inline__ __device__ T blockReduceMin(T val, unsigned mask) {
  static __shared__ T shared[WARP_SIZE];
  int lane = threadIdx.x & 0x1f;  // lane index within this thread's warp
  int wid = threadIdx.x >> 5;     // warp index within the block

  val = warpReduceMin(val, mask);
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // align block_span to warpSize
  int block_span = (blockDim.x + warpSize - 1) >> 5;
  // Lanes beyond the last stored partial read a large sentinel so they do
  // not affect the minimum. NOTE(review): 1e10f is a float-oriented identity;
  // it is not a safe "infinity" for integer T or values above 1e10 — confirm
  // the intended element types before reusing this helper more widely.
  val = (lane < block_span) ? shared[lane] : 1e10f;
  val = warpReduceMin(val, mask);

  return val;
}

/* Calculate the minimum of all elements in a block when actual quantity of
 * threads are less than warpSize. */
template <typename T>
__inline__ __device__ T PartialBlockReduceMin(T val, unsigned mask) {
  static __shared__ T shared[WARP_SIZE];
  int lane = threadIdx.x & 0x1f;  // lane index within this thread's warp
  int wid = threadIdx.x >> 5;     // warp index within the block

  val = PartialWarpReduceMin(val, mask);
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // NOTE(review): lanes whose index is >= the number of resident warps read
  // a shared slot no warp has written — preserved from the patch; confirm
  // all WARP_SIZE slots are initialized for the intended launch configs.
  shared[lane] = PartialWarpReduceMin(shared[lane], mask);
  __syncwarp();

#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
  val = __shfl_sync(mask, shared[lane], 0, warpSize);
#else
  val = __shfl(shared[lane], 0, warpSize);
#endif
  return val;
}