diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h index 23457ff5fe1ec27094113ba0dde26adc64c716b5..9f504d14a8da116648483c0f64cb511b46e6a97e 100644 --- a/paddle/fluid/platform/cuda_device_function.h +++ b/paddle/fluid/platform/cuda_device_function.h @@ -36,7 +36,7 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val, #if CUDA_VERSION < 9000 return __shfl_down(val, delta, width); #else - return __shfl_down_sync(mask, val, delta, width); + return __shfl_down_sync(mask, val, static_cast(delta), width); #endif } @@ -46,9 +46,16 @@ template <> __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, float16 val, int delta, int width) { - half tmp = static_cast(val); - __shfl_down(tmp, static_cast(delta), width); - return float16(tmp); + return float16( + __shfl_down(static_cast(val), static_cast(delta), width)); +} +#else +template <> +__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, + float16 val, int delta, + int width) { + return float16(__shfl_down_sync(mask, static_cast(val), + static_cast(delta), width)); } #endif diff --git a/paddle/fluid/platform/cuda_helper_test.cu b/paddle/fluid/platform/cuda_helper_test.cu index ca5ca1caeb23f01c047feeccf9c276b2dcd1cb68..ee45afab93d079374aefe366425502890854c28d 100644 --- a/paddle/fluid/platform/cuda_helper_test.cu +++ b/paddle/fluid/platform/cuda_helper_test.cu @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include @@ -123,7 +124,7 @@ void TestUnalign(size_t num, const int shift_bit) { cudaMemcpy(out, d_in2, array_size, cudaMemcpyDeviceToHost); cudaDeviceSynchronize(); for (size_t i = 0; i < num / 2; ++i) { - // NOTE(dzhwinter): the float16 add has small underflow/overflow + // NOTE(dzhwinter): the float16 add has a small truncation error, // so we use EXPECT_NEAR to check the result. 
EXPECT_NEAR(static_cast(out[i]), static_cast(AddFunctor()(r_in1[i], r_in2[i])), @@ -151,3 +152,83 @@ TEST(CudaAtomic, float16Unalign) { TestUnalign(static_cast(1024), /*shift_bit*/ 3); TestUnalign(static_cast(1024 * 1024), /*shift_bit*/ 3); } + +// Shuffle-down based sum reduction; see https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ +template +static __forceinline__ __device__ T WarpReduceSum(T val) { + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, true); + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + val += paddle::platform::CudaShuffleDownSync(mask, val, offset); + } + return val; +} + +template +__forceinline__ __device__ T BlockReduce(T val) { + static __shared__ T shared[32]; // One partial-sum slot per warp (32 slots) + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + val = WarpReduceSum(val); // Each warp performs partial reduction + + if (lane == 0) shared[wid] = val; // Lane 0 of each warp publishes its warp's partial sum + + __syncthreads(); // Wait for all partial reductions + + // Read a partial sum only if its warp exists; all other threads contribute 0 + val = + (threadIdx.x < blockDim.x / warpSize) ? 
shared[lane] : static_cast(0); + + if (wid == 0) val = WarpReduceSum(val); // Final reduce within first warp (DeviceReduceSum reads it from thread 0) + + return val; +} + +template +__global__ void DeviceReduceSum(T* in, T* out, size_t N) { + T sum(0); + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { // NOTE(review): int index compared against size_t N; confirm N fits in int + sum += in[i]; + } + sum = BlockReduce(sum); + __syncthreads(); + if (threadIdx.x == 0) out[blockIdx.x] = sum; // Thread 0 stores one result per block +} + +template +void TestReduce(size_t num, float atol = 0.01) { + T* in1; + T *d_in1, *d_in2; + size_t size = sizeof(T) * num; + cudaMalloc(reinterpret_cast(&d_in1), size); // NOTE(review): CUDA return codes are not checked in this test + cudaMalloc(reinterpret_cast(&d_in2), sizeof(T)); + in1 = reinterpret_cast(malloc(size)); + std::minstd_rand engine; + std::uniform_real_distribution dist(0.0, 1.0); + for (size_t i = 0; i < num; ++i) { + in1[i] = static_cast(dist(engine)); + } + auto out = std::accumulate(in1, in1 + num, static_cast(0)); // Host-side reference sum + cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + DeviceReduceSum<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num); // Single block, so out[0] holds the whole sum + cudaMemcpy(in1, d_in2, sizeof(T), cudaMemcpyDeviceToHost); // in1[0] is reused to hold the device result + cudaDeviceSynchronize(); + // NOTE(dzhwinter): the float16 add has a small truncation error, + // so we use EXPECT_NEAR to check the result. + EXPECT_NEAR(static_cast(in1[0]), static_cast(out), atol); + free(in1); + cudaFree(d_in1); + cudaFree(d_in2); +} + +TEST(CudaShuffleSync, float16) { + TestReduce(10); + TestReduce(1000); + + // float16 may overflow or accumulate truncation errors at large sizes, so num stays small and atol is loosened. + TestReduce(10); + TestReduce(100, /*atol error*/ 1.0); +}