// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include <algorithm>
#include <iostream>
#include <numeric>
#include <random>

#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using paddle::platform::float16;

template <typename T>
__global__ void AddKernel(const T* data_a, T* data_b, size_t num) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    paddle::platform::CudaAtomicAdd(&data_b[i], data_a[i]);
  }
}

template <typename T>
struct AddFunctor {
  T operator()(const T& a, const T& b) { return a + b; }
};

template <typename T>
void TestCase(size_t num) {
  T *in1, *in2, *out;
  T *d_in1, *d_in2;
  size_t size = sizeof(T) * num;
  cudaMalloc(reinterpret_cast<void**>(&d_in1), size);
  cudaMalloc(reinterpret_cast<void**>(&d_in2), size);
  in1 = reinterpret_cast<T*>(malloc(size));
  in2 = reinterpret_cast<T*>(malloc(size));
  out = reinterpret_cast<T*>(malloc(size));
  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num; ++i) {
    in1[i] = static_cast<T>(dist(engine));
    in2[i] = static_cast<T>(dist(engine));
  }
  cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);
  cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice);
  AddKernel<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);
  cudaDeviceSynchronize();
  cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost);
  cudaDeviceSynchronize();
  for (size_t i = 0; i < num; ++i) {
    // NOTE(dzhwinter): the float16 add has small underflow/overflow errors,
    // so we use EXPECT_NEAR to check the result.
    EXPECT_NEAR(static_cast<float>(out[i]),
                static_cast<float>(AddFunctor<T>()(in1[i], in2[i])), 0.001);
  }
  free(in1);
  free(in2);
  free(out);
  cudaFree(d_in1);
  cudaFree(d_in2);
}

// cuda primitives
TEST(CudaAtomic, Add) {
  TestCase<float>(static_cast<size_t>(10));
  TestCase<float>(static_cast<size_t>(1024 * 1024));

  TestCase<double>(static_cast<size_t>(10));
  TestCase<double>(static_cast<size_t>(1024 * 1024));
}

TEST(CudaAtomic, float16) {
  TestCase<float16>(static_cast<size_t>(1));
  TestCase<float16>(static_cast<size_t>(2));
  TestCase<float16>(static_cast<size_t>(3));

  TestCase<float16>(static_cast<size_t>(10));
  TestCase<float16>(static_cast<size_t>(1024 * 1024));
}

// unaligned access through uint8 offsets
void TestUnalign(size_t num, const int shift_bit) {
  PADDLE_ENFORCE(num % 2 == 0, "must be a multiple of 2");
  float16 *in1, *in2, *out;
  float16 *d_in1, *d_in2;
  size_t size = sizeof(uint8_t) * (num + shift_bit);
  size_t array_size = sizeof(float16) * (num / 2);

  cudaMalloc(reinterpret_cast<void**>(&d_in1), size);
  cudaMalloc(reinterpret_cast<void**>(&d_in2), size);
  in1 = reinterpret_cast<float16*>(malloc(size));
  in2 = reinterpret_cast<float16*>(malloc(size));
  out = reinterpret_cast<float16*>(malloc(size));

  // shift the pointers by shift_bit bytes to mimic an unaligned address
  float16* r_in1 =
      reinterpret_cast<float16*>(reinterpret_cast<uint8_t*>(in1) + shift_bit);
  float16* r_in2 =
      reinterpret_cast<float16*>(reinterpret_cast<uint8_t*>(in2) + shift_bit);

  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num / 2; ++i) {
    r_in1[i] = static_cast<float16>(dist(engine));
    r_in2[i] = static_cast<float16>(dist(engine));
  }
  cudaMemcpy(d_in1, r_in1, array_size, cudaMemcpyHostToDevice);
  cudaMemcpy(d_in2, r_in2, array_size, cudaMemcpyHostToDevice);
  AddKernel<float16><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num / 2);
  cudaDeviceSynchronize();
  cudaMemcpy(out, d_in2, array_size, cudaMemcpyDeviceToHost);
  cudaDeviceSynchronize();
  for (size_t i = 0; i < num / 2; ++i) {
    // NOTE(dzhwinter): the float16 add has a small truncation error,
    // so we use EXPECT_NEAR to check the result.
    EXPECT_NEAR(static_cast<float>(out[i]),
                static_cast<float>(AddFunctor<float16>()(r_in1[i], r_in2[i])),
                0.001);
  }
  free(in1);
  free(in2);
  free(out);
  cudaFree(d_in1);
  cudaFree(d_in2);
}

TEST(CudaAtomic, float16Unalign) {
  // same as the float16 test case
  TestUnalign(static_cast<size_t>(2), /*shift_bit*/ 2);
  TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 2);
  TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 2);

  // shift the address.
  TestUnalign(static_cast<size_t>(2), /*shift_bit*/ 1);
  TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 1);
  TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 1);

  TestUnalign(static_cast<size_t>(2), /*shift_bit*/ 3);
  TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 3);
  TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 3);
}

// https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
template <typename T>
static __forceinline__ __device__ T WarpReduceSum(T val) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, true);
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
  }
  return val;
}

template <typename T>
__forceinline__ __device__ T BlockReduce(T val) {
  static __shared__ T shared[32];  // Shared mem for 32 partial sums
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  val = WarpReduceSum(val);  // Each warp performs partial reduction

  if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory

  __syncthreads();  // Wait for all partial reductions

  // read from shared memory only if that warp existed
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane]
                                              : static_cast<T>(0);

  if (wid == 0) val = WarpReduceSum(val);  // Final reduce within first warp

  return val;
}

template <typename T>
__global__ void DeviceReduceSum(T* in, T* out, size_t N) {
  T sum(0);
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    sum += in[i];
  }
  sum = BlockReduce<T>(sum);
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = sum;
}

template <typename T>
void TestReduce(size_t num, float atol = 0.01) {
  T* in1;
  T *d_in1, *d_in2;
  size_t size = sizeof(T) * num;
  cudaMalloc(reinterpret_cast<void**>(&d_in1), size);
  cudaMalloc(reinterpret_cast<void**>(&d_in2), sizeof(T));
  in1 = reinterpret_cast<T*>(malloc(size));
  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num; ++i) {
    in1[i] = static_cast<T>(dist(engine));
  }
  auto out = std::accumulate(in1, in1 + num, static_cast<T>(0));
  cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);
  cudaDeviceSynchronize();
  DeviceReduceSum<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);
  cudaMemcpy(in1, d_in2, sizeof(T), cudaMemcpyDeviceToHost);
  cudaDeviceSynchronize();
  // NOTE(dzhwinter): the float16 add has small underflow/overflow errors,
  // so we use EXPECT_NEAR to check the result.
  EXPECT_NEAR(static_cast<float>(in1[0]), static_cast<float>(out), atol);
  free(in1);
  cudaFree(d_in1);
  cudaFree(d_in2);
}

TEST(CudaShuffleSync, float16) {
  TestReduce<float>(10);
  TestReduce<float>(1000);

  // float16 will overflow or accumulate truncation errors at larger sizes.
  TestReduce<float16>(10);
  TestReduce<float16>(100, /*atol error*/ 1.0);
}
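
// A minimal sketch (not part of the original tests) of how the same
// DeviceReduceSum kernel extends to a multi-block launch, following the
// two-pass pattern from the NVIDIA reduction post linked above: each block
// writes its partial sum to out[blockIdx.x], then a second single-block
// launch reduces those partials. The block count of 4 is an arbitrary
// choice for illustration.
template <typename T>
void TestMultiBlockReduce(size_t num, float atol = 0.01) {
  const int blocks = 4;
  T* in1;
  T *d_in, *d_partial, *d_out;
  size_t size = sizeof(T) * num;
  cudaMalloc(reinterpret_cast<void**>(&d_in), size);
  cudaMalloc(reinterpret_cast<void**>(&d_partial), sizeof(T) * blocks);
  cudaMalloc(reinterpret_cast<void**>(&d_out), sizeof(T));
  in1 = reinterpret_cast<T*>(malloc(size));
  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t i = 0; i < num; ++i) {
    in1[i] = static_cast<T>(dist(engine));
  }
  auto expected = std::accumulate(in1, in1 + num, static_cast<T>(0));
  cudaMemcpy(d_in, in1, size, cudaMemcpyHostToDevice);
  // First pass: each block reduces its grid-stride share of the input.
  DeviceReduceSum<T><<<blocks, PADDLE_CUDA_NUM_THREADS>>>(d_in, d_partial, num);
  // Second pass: a single block reduces the per-block partial sums.
  DeviceReduceSum<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_partial, d_out, blocks);
  T result;
  cudaMemcpy(&result, d_out, sizeof(T), cudaMemcpyDeviceToHost);
  cudaDeviceSynchronize();
  EXPECT_NEAR(static_cast<float>(result), static_cast<float>(expected), atol);
  free(in1);
  cudaFree(d_in);
  cudaFree(d_partial);
  cudaFree(d_out);
}

TEST(CudaShuffleSync, MultiBlockSketch) { TestMultiBlockReduce<float>(1000); }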