diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index e0d7937ae2f3ce4bda12f3771727e2992d63cb9b..a6f68f8b0c0a9b07c326888e30c0c911e7861607 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -60,3 +60,7 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
 nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
 cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
+
+IF(WITH_GPU)
+  nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
+ENDIF()
diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h
index ecec4178f2d9937920e52eb74bf9068b84e741a0..23457ff5fe1ec27094113ba0dde26adc64c716b5 100644
--- a/paddle/fluid/platform/cuda_device_function.h
+++ b/paddle/fluid/platform/cuda_device_function.h
@@ -14,6 +14,10 @@ limitations under the License. */
 #pragma once
 #include <cuda.h>
+// NOTE(): support float16 to half in header file.
+#define PADDLE_CUDA_FP16
+#include <cuda_fp16.h>
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace platform {
@@ -36,6 +40,18 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
 #endif
 }
 
+// CUDA 9.0 have native compatible float16 shfl_down
+#if CUDA_VERSION < 9000
+template <>
+__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
+                                                       float16 val, int delta,
+                                                       int width) {
+  half tmp = static_cast<half>(val);
+  __shfl_down(tmp, static_cast<unsigned>(delta), width);
+  return float16(tmp);
+}
+#endif
+
 template <typename T>
 __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
                                              int width = 32) {
@@ -46,6 +62,11 @@ __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
 #endif
 }
 
+template <typename T>
+HOSTDEVICE T Infinity() {
+  return INFINITY;
+}
+
 template <typename T>
 __device__ T reduceSum(T val, int tid, int len) {
   // NOTE(zcd): The warp size should be taken from the
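For readers unfamiliar with the shuffle-down primitive this hunk specializes: the sketch below is a minimal, standalone illustration (not part of this PR; the names WarpReduceSum and SumKernel are made up) of the warp-level reduction pattern that CudaShuffleDownSync wraps, written against the CUDA 9+ __shfl_down_sync intrinsic.

// Standalone sketch: warp-wide sum via shuffle-down (assumes CUDA >= 9.0).
#include <cstdio>

__device__ float WarpReduceSum(float val) {
  // Each step folds in the value held by the lane `offset` positions above;
  // after log2(32) steps lane 0 holds the sum of the whole warp.
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffffu, val, offset, 32);
  }
  return val;
}

__global__ void SumKernel(const float* in, float* out) {
  float v = in[threadIdx.x];
  v = WarpReduceSum(v);
  if (threadIdx.x == 0) *out = v;  // only lane 0 has the full sum
}

int main() {
  float h_in[32], *d_in, *d_out, h_out = 0.f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;
  cudaMalloc(reinterpret_cast<void **>(&d_in), sizeof(h_in));
  cudaMalloc(reinterpret_cast<void **>(&d_out), sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  SumKernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);  // expected: 32.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}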
diff --git a/paddle/fluid/platform/cuda_helper_test.cu b/paddle/fluid/platform/cuda_helper_test.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4a47ba5ccad4de338844e60f6fcbd6b7c11e891b
--- /dev/null
+++ b/paddle/fluid/platform/cuda_helper_test.cu
@@ -0,0 +1,118 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <iostream>
+#include <random>
+
+#define PADDLE_CUDA_FP16
+#include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/float16.h"
+
+using paddle::platform::PADDLE_CUDA_NUM_THREADS;
+using paddle::platform::float16;
+
+#define CUDA_ATOMIC_KERNEL(op, T)                                       \
+  __global__ void op##Kernel(const T* data_a, T* data_b, size_t num) {  \
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;        \
+         i += blockDim.x * gridDim.x) {                                 \
+      paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]);          \
+    }                                                                   \
+  }
+
+template <typename T>
+struct AddFunctor {
+  T operator()(const T& a, const T& b) { return a + b; }
+};
+
+template <typename T>
+struct SubFunctor {
+  T operator()(const T& a, const T& b) { return a - b; }
+};
+
+// NOTE(dzhwinter): the float16 add has small underflow/overflow
+// so we use EXPECT_NEAR to check the result.
+#define ARITHMETIC_KERNEL_LAUNCH(op, T)                                  \
+  void Test##T##op(size_t num) {                                         \
+    T *in1, *in2, *out;                                                  \
+    T *d_in1, *d_in2;                                                    \
+    size_t size = sizeof(T) * num;                                       \
+    cudaMalloc(reinterpret_cast<void **>(&d_in1), size);                 \
+    cudaMalloc(reinterpret_cast<void **>(&d_in2), size);                 \
+    in1 = reinterpret_cast<T *>(malloc(size));                           \
+    in2 = reinterpret_cast<T *>(malloc(size));                           \
+    out = reinterpret_cast<T *>(malloc(size));                           \
+    std::minstd_rand engine;                                             \
+    std::uniform_real_distribution<double> dist(0.0, 1.0);               \
+    for (size_t i = 0; i < num; ++i) {                                   \
+      in1[i] = static_cast<T>(dist(engine));                             \
+      in2[i] = static_cast<T>(dist(engine));                             \
+    }                                                                    \
+    cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);                \
+    cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice);                \
+    op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);       \
+    cudaDeviceSynchronize();                                             \
+    cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost);                \
+    cudaDeviceSynchronize();                                             \
+    for (size_t i = 0; i < num; ++i) {                                   \
+      EXPECT_NEAR(static_cast<float>(out[i]),                            \
+                  static_cast<float>(op##Functor<T>()(in1[i], in2[i])),  \
+                  0.001);                                                \
+    }                                                                    \
+    free(in1);                                                           \
+    free(in2);                                                           \
+    free(out);                                                           \
+    cudaFree(d_in1);                                                     \
+    cudaFree(d_in2);                                                     \
+  }
+CUDA_ATOMIC_KERNEL(Add, float);
+CUDA_ATOMIC_KERNEL(Add, double);
+CUDA_ATOMIC_KERNEL(Add, float16);
+
+ARITHMETIC_KERNEL_LAUNCH(Add, float);
+ARITHMETIC_KERNEL_LAUNCH(Add, double);
+ARITHMETIC_KERNEL_LAUNCH(Add, float16);
+
+namespace paddle {
+namespace platform {
+USE_CUDA_ATOMIC(Sub, int);
+};
+};
+CUDA_ATOMIC_KERNEL(Sub, int);
+ARITHMETIC_KERNEL_LAUNCH(Sub, int);
+
+// cuda primitives
+TEST(CudaAtomic, Add) {
+  TestfloatAdd(static_cast<size_t>(10));
+  TestfloatAdd(static_cast<size_t>(1024 * 1024));
+  TestdoubleAdd(static_cast<size_t>(10));
+  TestdoubleAdd(static_cast<size_t>(1024 * 1024));
+}
+
+TEST(CudaAtomic, Sub) {
+  TestintSub(static_cast<size_t>(10));
+  TestintSub(static_cast<size_t>(1024 * 1024));
+}
+
+TEST(CudaAtomic, float16) {
+  using paddle::platform::float16;
+  Testfloat16Add(static_cast<size_t>(1));
+  Testfloat16Add(static_cast<size_t>(2));
+  Testfloat16Add(static_cast<size_t>(3));
+
+  Testfloat16Add(static_cast<size_t>(10));
+  Testfloat16Add(static_cast<size_t>(1024 * 1024));
+}
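The CUDA_ATOMIC_KERNEL macro in this new test expands to a grid-stride loop whose body is a single atomic update per element. A standalone sketch of the same pattern using the built-in atomicAdd for float (AtomicSumKernel and the surrounding names are hypothetical, not Paddle code):

// Standalone sketch: grid-stride loop accumulating into one value with atomicAdd.
#include <cstdio>

__global__ void AtomicSumKernel(const float* data, float* sum, size_t num) {
  // Each thread handles indices i, i + blockDim*gridDim, ... so any grid
  // size covers the whole input.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    atomicAdd(sum, data[i]);
  }
}

int main() {
  const size_t num = 1 << 20;
  float *d_data, *d_sum, h_sum = 0.f;
  float* h_data = new float[num];
  for (size_t i = 0; i < num; ++i) h_data[i] = 1.f;  // fill on the host
  cudaMalloc(reinterpret_cast<void **>(&d_data), num * sizeof(float));
  cudaMalloc(reinterpret_cast<void **>(&d_sum), sizeof(float));
  cudaMemset(d_sum, 0, sizeof(float));
  cudaMemcpy(d_data, h_data, num * sizeof(float), cudaMemcpyHostToDevice);
  AtomicSumKernel<<<64, 256>>>(d_data, d_sum, num);
  cudaMemcpy(&h_sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f (expected %zu)\n", h_sum, num);
  delete[] h_data;
  cudaFree(d_data);
  cudaFree(d_sum);
  return 0;
}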
diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h
index d535ed2f89df6a0b311ec068ecd92c8e3183cee7..94ce83975a7f13daa2b6a4d480cb22cc95811b9b 100644
--- a/paddle/fluid/platform/cuda_primitives.h
+++ b/paddle/fluid/platform/cuda_primitives.h
@@ -14,12 +14,14 @@ limitations under the License.
 */
 #pragma once
 #include <cuda.h>
+#include <stdio.h>
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace platform {
 
 #define CUDA_ATOMIC_WRAPPER(op, T) \
-  __device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
+  __device__ __forceinline__ T CudaAtomic##op(T *address, const T val)
 
 #define USE_CUDA_ATOMIC(op, T) \
   CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
@@ -42,17 +44,17 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
   static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                 "long long should be int64");
   return CudaAtomicAdd(
-      reinterpret_cast<unsigned long long int*>(address),  // NOLINT
-      static_cast<unsigned long long int>(val));           // NOLINT
+      reinterpret_cast<unsigned long long int *>(address),  // NOLINT
+      static_cast<unsigned long long int>(val));            // NOLINT
 }
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
 USE_CUDA_ATOMIC(Add, double);
 #else
 CUDA_ATOMIC_WRAPPER(Add, double) {
-  unsigned long long int* address_as_ull =                 // NOLINT
-      reinterpret_cast<unsigned long long int*>(address);  // NOLINT
-  unsigned long long int old = *address_as_ull, assumed;   // NOLINT
+  unsigned long long int *address_as_ull =                  // NOLINT
+      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
+  unsigned long long int old = *address_as_ull, assumed;    // NOLINT
 
   do {
     assumed = old;
@@ -64,6 +66,67 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
   return __longlong_as_double(old);
 }
+#endif
+
+#ifdef PADDLE_CUDA_FP16
+// NOTE(dzhwinter): cuda do not have atomicCAS for half.
+// Just use the half address as a unsigned value address and
+// do the atomicCAS. According to the value store at high 16 bits
+// or low 16 bits, then do a different sum and CAS.
+// Given most warp-threads will failed on the atomicCAS, so this
+// implemented should be avoided in high concurrency. It's will be
+// slower than the way convert value into 32bits and do a full atomicCAS.
+
+// convert the value into float and do the add arithmetic.
+// then store the result into a uint32.
+inline __device__ uint32_t add_to_low_half(uint32_t val, float x) {
+  float16 low_half;
+  // the float16 in lower 16bits
+  low_half.x = static_cast<uint16_t>(val & 0xffffu);
+  low_half = static_cast<float16>(static_cast<float>(low_half) + x);
+  return (val & 0xffff0000u) | low_half.x;
+}
+
+inline __device__ uint32_t add_to_high_half(uint32_t val, float x) {
+  float16 high_half;
+  // the float16 in higher 16bits
+  high_half.x = static_cast<uint16_t>(val >> 16);
+  high_half = static_cast<float16>(static_cast<float>(high_half) + x);
+  return (val & 0xffffu) | (static_cast<uint32_t>(high_half.x) << 16);
+}
+
+CUDA_ATOMIC_WRAPPER(Add, float16) {
+  // concrete packed float16 value may exsits in lower or higher 16bits
+  // of the 32bits address.
+  uint32_t *address_as_ui =
+      reinterpret_cast<uint32_t *>(reinterpret_cast<char *>(address) -
+                                   (reinterpret_cast<size_t>(address) & 2));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t sum;
+  uint32_t newval;
+  uint32_t assumed;
+  if (((size_t)address & 2) == 0) {
+    // the float16 value stay at lower 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    float16 ret;
+    ret.x = old & 0xffffu;
+    return ret;
+  } else {
+    // the float16 value stay at higher 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    float16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
+
 #endif
 }  // namespace platform
 }  // namespace paddle
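The add_to_low_half/add_to_high_half helpers above rely on a 4-byte-aligned uint32_t word holding two adjacent float16 payloads, with masking and shifting used to rewrite one 16-bit lane while preserving the other. A small host-side sketch of just that packing arithmetic (illustrative helper names, raw bit patterns only, no fp16 arithmetic), checkable on any CPU:

// Host-only sketch of the 16-bit lane packing used by the float16 atomicAdd.
#include <cassert>
#include <cstdint>
#include <cstdio>

uint32_t replace_low_half(uint32_t word, uint16_t bits) {
  return (word & 0xffff0000u) | bits;  // keep high lane, overwrite low lane
}

uint32_t replace_high_half(uint32_t word, uint16_t bits) {
  return (word & 0x0000ffffu) | (static_cast<uint32_t>(bits) << 16);
}

int main() {
  // Two fp16 payloads packed in one word: 0x3c00 is 1.0, 0x4000 is 2.0.
  uint32_t word = (0x4000u << 16) | 0x3c00u;
  // Overwrite the low lane with 3.0 (0x4200); the high lane is untouched.
  assert(replace_low_half(word, 0x4200u) == ((0x4000u << 16) | 0x4200u));
  // Overwrite the high lane with 4.0 (0x4400); the low lane is untouched.
  assert(replace_high_half(word, 0x4400u) == ((0x4400u << 16) | 0x3c00u));
  std::printf("packing ok\n");
  return 0;
}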
diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h
index ffd183af68514dbb1a8b3de39000c9ca3f56ddc3..efb021c838e3680ab2cdd1c4b298cf7ec2186478 100644
--- a/paddle/fluid/platform/float16.h
+++ b/paddle/fluid/platform/float16.h
@@ -67,8 +67,11 @@ struct float16;
 }  // namespace platform
 }  // namespace paddle
 
+// NOTE():
+// Do not move the eigen.h header, otherwise the eigen_vector will failed.
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/platform/hostdevice.h"
+#include "unsupported/Eigen/CXX11/Tensor"
 
 namespace paddle {
 namespace platform {
@@ -898,6 +901,30 @@ struct is_pod<paddle::platform::float16> {
       is_standard_layout<paddle::platform::float16>::value;
 };
 
+template <>
+struct is_floating_point<paddle::platform::float16>
+    : std::integral_constant<
+          bool, std::is_same<paddle::platform::float16,
+                             typename std::remove_cv<
+                                 paddle::platform::float16>::type>::value> {};
+template <>
+struct is_signed<paddle::platform::float16> {
+  static const bool value = true;
+};
+
+template <>
+struct is_unsigned<paddle::platform::float16> {
+  static const bool value = false;
+};
+
+inline bool isnan(const paddle::platform::float16& a) {
+  return paddle::platform::isnan(a);
+}
+
+inline bool isinf(const paddle::platform::float16& a) {
+  return paddle::platform::isinf(a);
+}
+
 template <>
 struct numeric_limits<paddle::platform::float16> {
   static const bool is_specialized = true;
diff --git a/paddle/fluid/platform/float16_test.cc b/paddle/fluid/platform/float16_test.cc
index ede294be1e2e26693bd3ead2ccd5e6a6c8a075bc..27e930e6e0a76982b3f27619f38a4a08d82cafa1 100644
--- a/paddle/fluid/platform/float16_test.cc
+++ b/paddle/fluid/platform/float16_test.cc
@@ -141,10 +141,36 @@ TEST(float16, lod_tensor_cpu) {
   }
 }
 
+TEST(float16, floating) {
+  // compile time assert.
+  PADDLE_ASSERT(std::is_floating_point<float16>::value);
+}
+
 TEST(float16, print) {
   float16 a = float16(1.0f);
   std::cout << a << std::endl;
 }
 
+// CPU test
+TEST(float16, isinf) {
+  float16 a;
+  a.x = 0x7c00;
+  float16 b = float16(INFINITY);
+  float16 c = static_cast<float16>(INFINITY);
+  EXPECT_EQ(std::isinf(a), true);
+  EXPECT_EQ(std::isinf(b), true);
+  EXPECT_EQ(std::isinf(c), true);
+}
+
+TEST(float16, isnan) {
+  float16 a;
+  a.x = 0x7fff;
+  float16 b = float16(NAN);
+  float16 c = static_cast<float16>(NAN);
+  EXPECT_EQ(std::isnan(a), true);
+  EXPECT_EQ(std::isnan(b), true);
+  EXPECT_EQ(std::isnan(c), true);
+}
+
 }  // namespace platform
 }  // namespace paddle
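The float16.h hunk above specializes std::is_floating_point, std::is_signed and std::is_unsigned so that trait-guarded generic code treats float16 like a built-in floating type. A toy, self-contained illustration of why such a specialization changes behavior (ToyHalf and Classify are made-up names, not Paddle code); note that specializing these std traits for a user-defined type is not strictly sanctioned by the C++ standard, though mainstream compilers accept it, which is what the PR relies on:

// Toy illustration of a std trait specialization guarding a template.
#include <cstdint>
#include <cstdio>
#include <type_traits>

struct ToyHalf {
  uint16_t x;  // raw IEEE fp16 bits
};

namespace std {
template <>
struct is_floating_point<ToyHalf> : true_type {};
}  // namespace std

template <typename T>
const char* Classify() {
  return std::is_floating_point<T>::value ? "floating point"
                                          : "not floating point";
}

int main() {
  std::printf("ToyHalf: %s\n", Classify<ToyHalf>());  // floating point
  std::printf("int:     %s\n", Classify<int>());      // not floating point
  return 0;
}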
diff --git a/paddle/fluid/platform/float16_test.cu b/paddle/fluid/platform/float16_test.cu
index 1b9cf9b5d3fa2121b588c31d7cf2f4c50cb951bc..e2b7ca9b03809113c31af8ff4d3ad3713748f330 100644
--- a/paddle/fluid/platform/float16_test.cu
+++ b/paddle/fluid/platform/float16_test.cu
@@ -11,11 +11,13 @@ limitations under the License.
 */
 #include "paddle/fluid/platform/float16.h"
 
+#include <glog/logging.h>
 #include <gtest/gtest.h>
+#include <bitset>
+#include <iostream>
 
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/legacy/utils/Logging.h"
 
 #define ARITHMETIC_KERNEL(op_type, sign)                                 \
   __global__ void op_type(const half* in1, const half* in2, half* out) { \
@@ -241,6 +243,72 @@ TEST(float16, lod_tensor_on_gpu) {
   }
 }
 
+template <typename T>
+struct Functor {
+  bool operator()(const T& val) {
+    return std::type_index(typeid(T)) ==
+           std::type_index(typeid(platform::float16));
+  }
+};
+
+TEST(float16, typeid) {
+  // the framework heavily used typeid hash
+  Functor<float16> functor;
+  float16 a = float16(.0f);
+  Functor<int> functor2;
+  int b(0);
+
+  // compile time assert
+  PADDLE_ASSERT(functor(a) == true);
+  PADDLE_ASSERT(functor2(b) == false);
+}
+
+// GPU test
+TEST(float16, isinf) {
+  float16 a;
+  a.x = 0x7c00;
+  float16 b = float16(INFINITY);
+  // underflow to 0
+  float16 native_a(5e-40f);
+  // overflow to inf
+  float16 native_b(5e40f);
+  EXPECT_EQ(std::isinf(a), true);
+  EXPECT_EQ(std::isinf(b), true);
+  EXPECT_EQ(std::isinf(native_b), true);
+  EXPECT_EQ(native_a, float16(0));
+}
+
+TEST(float16, isnan) {
+  float16 a;
+  a.x = 0x7fff;
+  float16 b = float16(NAN);
+  float16 c = float16(5e40);
+  // inf * +-0 will get a nan
+  float16 d = c * float16(0);
+  EXPECT_EQ(std::isnan(a), true);
+  EXPECT_EQ(std::isnan(b), true);
+  EXPECT_EQ(std::isnan(d), true);
+}
+
+TEST(float16, cast) {
+  float16 a;
+  a.x = 0x0070;
+  auto b = a;
+  {
+    // change semantic, keep the same value
+    float16 c = reinterpret_cast<float16&>(reinterpret_cast<unsigned&>(b));
+    EXPECT_EQ(b, c);
+  }
+
+  {
+    // use uint32 low 16 bit store float16
+    uint32_t c = reinterpret_cast<uint32_t&>(b);
+    float16 d;
+    d.x = c;
+    EXPECT_EQ(b, d);
+  }
+}
+
 }  // namespace platform
 }  // namespace paddle
 #endif  // PADDLE_CUDA_FP16
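The isinf/isnan tests above poke raw binary16 bit patterns (0x7c00 for infinity, 0x7fff for a NaN). A host-only sketch (IsFp16Inf/IsFp16Nan are hypothetical helpers, not Paddle code) of how those patterns decode under IEEE 754 binary16: sign(1) | exponent(5) | mantissa(10), where an all-ones exponent with a zero mantissa is +-inf and an all-ones exponent with a nonzero mantissa is NaN.

// Host-only sketch of the binary16 classification the tests rely on.
#include <cassert>
#include <cstdint>
#include <cstdio>

bool IsFp16Inf(uint16_t bits) {
  return (bits & 0x7fffu) == 0x7c00u;  // exponent 11111, mantissa 0
}

bool IsFp16Nan(uint16_t bits) {
  return (bits & 0x7c00u) == 0x7c00u && (bits & 0x03ffu) != 0;  // mantissa != 0
}

int main() {
  assert(IsFp16Inf(0x7c00u));   // +inf, as used in TEST(float16, isinf)
  assert(IsFp16Inf(0xfc00u));   // -inf
  assert(IsFp16Nan(0x7fffu));   // NaN, as used in TEST(float16, isnan)
  assert(!IsFp16Nan(0x7c00u));  // inf is not NaN
  assert(!IsFp16Inf(0x3c00u));  // 1.0 is finite
  std::printf("fp16 classification ok\n");
  return 0;
}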