From 6d3da458a77101e2bbbb8142db32e4d81be53ca2 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Tue, 31 Jul 2018 12:20:40 +0800 Subject: [PATCH] Fix/float16 style (#12446) * "rewrite the test case" * "follow comment" --- paddle/fluid/platform/cuda_helper_test.cu | 183 +++++++++++++--------- paddle/fluid/platform/cuda_primitives.h | 20 +-- 2 files changed, 119 insertions(+), 84 deletions(-) diff --git a/paddle/fluid/platform/cuda_helper_test.cu b/paddle/fluid/platform/cuda_helper_test.cu index 4a47ba5cca..ca5ca1caeb 100644 --- a/paddle/fluid/platform/cuda_helper_test.cu +++ b/paddle/fluid/platform/cuda_helper_test.cu @@ -13,7 +13,6 @@ // limitations under the License. #include -#include #include #include @@ -25,13 +24,13 @@ using paddle::platform::PADDLE_CUDA_NUM_THREADS; using paddle::platform::float16; -#define CUDA_ATOMIC_KERNEL(op, T) \ - __global__ void op##Kernel(const T* data_a, T* data_b, size_t num) { \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; \ - i += blockDim.x * gridDim.x) { \ - paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]); \ - } \ +template +__global__ void AddKernel(const T* data_a, T* data_b, size_t num) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; + i += blockDim.x * gridDim.x) { + paddle::platform::CudaAtomicAdd(&data_b[i], data_a[i]); } +} template struct AddFunctor { @@ -39,80 +38,116 @@ struct AddFunctor { }; template -struct SubFunctor { - T operator()(const T& a, const T& b) { return a - b; } -}; - -// NOTE(dzhwinter): the float16 add has small underflow/overflow -// so we use EXPECT_NEAR to check the result. -#define ARITHMETIC_KERNEL_LAUNCH(op, T) \ - void Test##T##op(size_t num) { \ - T *in1, *in2, *out; \ - T *d_in1, *d_in2; \ - size_t size = sizeof(T) * num; \ - cudaMalloc(reinterpret_cast(&d_in1), size); \ - cudaMalloc(reinterpret_cast(&d_in2), size); \ - in1 = reinterpret_cast(malloc(size)); \ - in2 = reinterpret_cast(malloc(size)); \ - out = reinterpret_cast(malloc(size)); \ - std::minstd_rand engine; \ - std::uniform_real_distribution dist(0.0, 1.0); \ - for (size_t i = 0; i < num; ++i) { \ - in1[i] = static_cast(dist(engine)); \ - in2[i] = static_cast(dist(engine)); \ - } \ - cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \ - cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \ - op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num); \ - cudaDeviceSynchronize(); \ - cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost); \ - cudaDeviceSynchronize(); \ - for (size_t i = 0; i < num; ++i) { \ - EXPECT_NEAR(static_cast(out[i]), \ - static_cast(op##Functor()(in1[i], in2[i])), \ - 0.001); \ - } \ - free(in1); \ - free(in2); \ - free(out); \ - cudaFree(d_in1); \ - cudaFree(d_in2); \ +void TestCase(size_t num) { + T *in1, *in2, *out; + T *d_in1, *d_in2; + size_t size = sizeof(T) * num; + cudaMalloc(reinterpret_cast(&d_in1), size); + cudaMalloc(reinterpret_cast(&d_in2), size); + in1 = reinterpret_cast(malloc(size)); + in2 = reinterpret_cast(malloc(size)); + out = reinterpret_cast(malloc(size)); + std::minstd_rand engine; + std::uniform_real_distribution dist(0.0, 1.0); + for (size_t i = 0; i < num; ++i) { + in1[i] = static_cast(dist(engine)); + in2[i] = static_cast(dist(engine)); } -CUDA_ATOMIC_KERNEL(Add, float); -CUDA_ATOMIC_KERNEL(Add, double); -CUDA_ATOMIC_KERNEL(Add, float16); - -ARITHMETIC_KERNEL_LAUNCH(Add, float); -ARITHMETIC_KERNEL_LAUNCH(Add, double); -ARITHMETIC_KERNEL_LAUNCH(Add, float16); - -namespace paddle { -namespace platform { -USE_CUDA_ATOMIC(Sub, int); -}; -}; -CUDA_ATOMIC_KERNEL(Sub, int); -ARITHMETIC_KERNEL_LAUNCH(Sub, int); + cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); + cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); + AddKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num); + cudaDeviceSynchronize(); + cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + for (size_t i = 0; i < num; ++i) { + // NOTE(dzhwinter): the float16 add has small underflow/overflow + // so we use EXPECT_NEAR to check the result. + EXPECT_NEAR(static_cast(out[i]), + static_cast(AddFunctor()(in1[i], in2[i])), 0.001); + } + free(in1); + free(in2); + free(out); + cudaFree(d_in1); + cudaFree(d_in2); +} // cuda primitives TEST(CudaAtomic, Add) { - TestfloatAdd(static_cast(10)); - TestfloatAdd(static_cast(1024 * 1024)); - TestdoubleAdd(static_cast(10)); - TestdoubleAdd(static_cast(1024 * 1024)); -} + TestCase(static_cast(10)); + TestCase(static_cast(1024 * 1024)); -TEST(CudaAtomic, Sub) { - TestintSub(static_cast(10)); - TestintSub(static_cast(1024 * 1024)); + TestCase(static_cast(10)); + TestCase(static_cast(1024 * 1024)); } TEST(CudaAtomic, float16) { - using paddle::platform::float16; - Testfloat16Add(static_cast(1)); - Testfloat16Add(static_cast(2)); - Testfloat16Add(static_cast(3)); + TestCase(static_cast(1)); + TestCase(static_cast(2)); + TestCase(static_cast(3)); + + TestCase(static_cast(10)); + TestCase(static_cast(1024 * 1024)); +} + +// unalignment of uint8 +void TestUnalign(size_t num, const int shift_bit) { + PADDLE_ENFORCE(num % 2 == 0, "must be a multiple of 2"); + float16 *in1, *in2, *out; + float16 *d_in1, *d_in2; + size_t size = sizeof(uint8_t) * (num + shift_bit); + size_t array_size = sizeof(float16) * (num / 2); + + cudaMalloc(reinterpret_cast(&d_in1), size); + cudaMalloc(reinterpret_cast(&d_in2), size); + in1 = reinterpret_cast(malloc(size)); + in2 = reinterpret_cast(malloc(size)); + out = reinterpret_cast(malloc(size)); + + // right shift 1, mimic the unalignment of address + float16* r_in1 = + reinterpret_cast(reinterpret_cast(in1) + shift_bit); + float16* r_in2 = + reinterpret_cast(reinterpret_cast(in2) + shift_bit); + + std::minstd_rand engine; + std::uniform_real_distribution dist(0.0, 1.0); + for (size_t i = 0; i < num / 2; ++i) { + r_in1[i] = static_cast(dist(engine)); + r_in2[i] = static_cast(dist(engine)); + } + cudaMemcpy(d_in1, r_in1, array_size, cudaMemcpyHostToDevice); + cudaMemcpy(d_in2, r_in2, array_size, cudaMemcpyHostToDevice); + AddKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num / 2); + cudaDeviceSynchronize(); + cudaMemcpy(out, d_in2, array_size, cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + for (size_t i = 0; i < num / 2; ++i) { + // NOTE(dzhwinter): the float16 add has small underflow/overflow + // so we use EXPECT_NEAR to check the result. + EXPECT_NEAR(static_cast(out[i]), + static_cast(AddFunctor()(r_in1[i], r_in2[i])), + 0.001); + } + free(in1); + free(in2); + free(out); + cudaFree(d_in1); + cudaFree(d_in2); +} + +TEST(CudaAtomic, float16Unalign) { + // same with float16 testcase + TestUnalign(static_cast(2), /*shift_bit*/ 2); + TestUnalign(static_cast(1024), /*shift_bit*/ 2); + TestUnalign(static_cast(1024 * 1024), /*shift_bit*/ 2); + + // shift the address. + TestUnalign(static_cast(2), /*shift_bit*/ 1); + TestUnalign(static_cast(1024), /*shift_bit*/ 1); + TestUnalign(static_cast(1024 * 1024), /*shift_bit*/ 1); - Testfloat16Add(static_cast(10)); - Testfloat16Add(static_cast(1024 * 1024)); + TestUnalign(static_cast(2), /*shift_bit*/ 3); + TestUnalign(static_cast(1024), /*shift_bit*/ 3); + TestUnalign(static_cast(1024 * 1024), /*shift_bit*/ 3); } diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h index 94ce83975a..67ea64833d 100644 --- a/paddle/fluid/platform/cuda_primitives.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -79,41 +79,41 @@ CUDA_ATOMIC_WRAPPER(Add, double) { // convert the value into float and do the add arithmetic. // then store the result into a uint32. -inline __device__ uint32_t add_to_low_half(uint32_t val, float x) { +inline static __device__ uint32_t add_to_low_half(uint32_t val, float x) { float16 low_half; // the float16 in lower 16bits - low_half.x = static_cast(val & 0xffffu); + low_half.x = static_cast(val & 0xFFFFu); low_half = static_cast(static_cast(low_half) + x); - return (val & 0xffff0000u) | low_half.x; + return (val & 0xFFFF0000u) | low_half.x; } -inline __device__ uint32_t add_to_high_half(uint32_t val, float x) { +inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) { float16 high_half; // the float16 in higher 16bits high_half.x = static_cast(val >> 16); high_half = static_cast(static_cast(high_half) + x); - return (val & 0xffffu) | (static_cast(high_half.x) << 16); + return (val & 0xFFFFu) | (static_cast(high_half.x) << 16); } CUDA_ATOMIC_WRAPPER(Add, float16) { // concrete packed float16 value may exsits in lower or higher 16bits // of the 32bits address. - uint32_t *address_as_ui = - reinterpret_cast(reinterpret_cast(address) - - (reinterpret_cast(address) & 2)); + uint32_t *address_as_ui = reinterpret_cast( + reinterpret_cast(address) - + (reinterpret_cast(address) & 0x02)); float val_f = static_cast(val); uint32_t old = *address_as_ui; uint32_t sum; uint32_t newval; uint32_t assumed; - if (((size_t)address & 2) == 0) { + if (((uintptr_t)address & 0x02) == 0) { // the float16 value stay at lower 16 bits of the address. do { assumed = old; old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f)); } while (old != assumed); float16 ret; - ret.x = old & 0xffffu; + ret.x = old & 0xFFFFu; return ret; } else { // the float16 value stay at higher 16 bits of the address. -- GitLab