未验证 提交 39ac9e39 编写于 作者: D dzhwinter 提交者: GitHub

float16 type support enhance (#12181)

* cherry picked

* "cherry picked platform"

* "add comment"

* "fix ci"
上级 19ef4bab
......@@ -60,3 +60,7 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
IF(WITH_GPU)
nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
ENDIF()
......@@ -14,6 +14,10 @@ limitations under the License. */
#pragma once
#include <cuda.h>
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace platform {
......@@ -36,6 +40,18 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
#endif
}
// CUDA 9.0 have native compatible float16 shfl_down
#if CUDA_VERSION < 9000
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
float16 val, int delta,
int width) {
half tmp = static_cast<half>(val);
__shfl_down(tmp, static_cast<unsigned>(delta), width);
return float16(tmp);
}
#endif
template <typename T>
__forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
int width = 32) {
......@@ -46,6 +62,11 @@ __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
#endif
}
template <typename T>
HOSTDEVICE T Infinity() {
return INFINITY;
}
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include <random>
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using paddle::platform::float16;
#define CUDA_ATOMIC_KERNEL(op, T) \
__global__ void op##Kernel(const T* data_a, T* data_b, size_t num) { \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; \
i += blockDim.x * gridDim.x) { \
paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]); \
} \
}
template <typename T>
struct AddFunctor {
T operator()(const T& a, const T& b) { return a + b; }
};
template <typename T>
struct SubFunctor {
T operator()(const T& a, const T& b) { return a - b; }
};
// NOTE(dzhwinter): the float16 add has small underflow/overflow
// so we use EXPECT_NEAR to check the result.
#define ARITHMETIC_KERNEL_LAUNCH(op, T) \
void Test##T##op(size_t num) { \
T *in1, *in2, *out; \
T *d_in1, *d_in2; \
size_t size = sizeof(T) * num; \
cudaMalloc(reinterpret_cast<void**>(&d_in1), size); \
cudaMalloc(reinterpret_cast<void**>(&d_in2), size); \
in1 = reinterpret_cast<T*>(malloc(size)); \
in2 = reinterpret_cast<T*>(malloc(size)); \
out = reinterpret_cast<T*>(malloc(size)); \
std::minstd_rand engine; \
std::uniform_real_distribution<double> dist(0.0, 1.0); \
for (size_t i = 0; i < num; ++i) { \
in1[i] = static_cast<T>(dist(engine)); \
in2[i] = static_cast<T>(dist(engine)); \
} \
cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num); \
cudaDeviceSynchronize(); \
cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost); \
cudaDeviceSynchronize(); \
for (size_t i = 0; i < num; ++i) { \
EXPECT_NEAR(static_cast<float>(out[i]), \
static_cast<float>(op##Functor<T>()(in1[i], in2[i])), \
0.001); \
} \
free(in1); \
free(in2); \
free(out); \
cudaFree(d_in1); \
cudaFree(d_in2); \
}
CUDA_ATOMIC_KERNEL(Add, float);
CUDA_ATOMIC_KERNEL(Add, double);
CUDA_ATOMIC_KERNEL(Add, float16);
ARITHMETIC_KERNEL_LAUNCH(Add, float);
ARITHMETIC_KERNEL_LAUNCH(Add, double);
ARITHMETIC_KERNEL_LAUNCH(Add, float16);
namespace paddle {
namespace platform {
USE_CUDA_ATOMIC(Sub, int);
};
};
CUDA_ATOMIC_KERNEL(Sub, int);
ARITHMETIC_KERNEL_LAUNCH(Sub, int);
// cuda primitives
TEST(CudaAtomic, Add) {
TestfloatAdd(static_cast<size_t>(10));
TestfloatAdd(static_cast<size_t>(1024 * 1024));
TestdoubleAdd(static_cast<size_t>(10));
TestdoubleAdd(static_cast<size_t>(1024 * 1024));
}
TEST(CudaAtomic, Sub) {
TestintSub(static_cast<size_t>(10));
TestintSub(static_cast<size_t>(1024 * 1024));
}
TEST(CudaAtomic, float16) {
using paddle::platform::float16;
Testfloat16Add(static_cast<size_t>(1));
Testfloat16Add(static_cast<size_t>(2));
Testfloat16Add(static_cast<size_t>(3));
Testfloat16Add(static_cast<size_t>(10));
Testfloat16Add(static_cast<size_t>(1024 * 1024));
}
......@@ -14,12 +14,14 @@ limitations under the License. */
#pragma once
#include <cuda.h>
#include <stdio.h>
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace platform {
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
__device__ __forceinline__ T CudaAtomic##op(T *address, const T val)
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
......@@ -42,17 +44,17 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
return CudaAtomicAdd(
reinterpret_cast<unsigned long long int*>(address), // NOLINT
static_cast<unsigned long long int>(val)); // NOLINT
reinterpret_cast<unsigned long long int *>(address), // NOLINT
static_cast<unsigned long long int>(val)); // NOLINT
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
USE_CUDA_ATOMIC(Add, double);
#else
CUDA_ATOMIC_WRAPPER(Add, double) {
unsigned long long int* address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int*>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT
unsigned long long int *address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT
do {
assumed = old;
......@@ -64,6 +66,67 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
return __longlong_as_double(old);
}
#endif
#ifdef PADDLE_CUDA_FP16
// NOTE(dzhwinter): cuda do not have atomicCAS for half.
// Just use the half address as a unsigned value address and
// do the atomicCAS. According to the value store at high 16 bits
// or low 16 bits, then do a different sum and CAS.
// Given most warp-threads will failed on the atomicCAS, so this
// implemented should be avoided in high concurrency. It's will be
// slower than the way convert value into 32bits and do a full atomicCAS.
// convert the value into float and do the add arithmetic.
// then store the result into a uint32.
inline __device__ uint32_t add_to_low_half(uint32_t val, float x) {
float16 low_half;
// the float16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xffffu);
low_half = static_cast<float16>(static_cast<float>(low_half) + x);
return (val & 0xffff0000u) | low_half.x;
}
inline __device__ uint32_t add_to_high_half(uint32_t val, float x) {
float16 high_half;
// the float16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<float16>(static_cast<float>(high_half) + x);
return (val & 0xffffu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Add, float16) {
// concrete packed float16 value may exsits in lower or higher 16bits
// of the 32bits address.
uint32_t *address_as_ui =
reinterpret_cast<uint32_t *>(reinterpret_cast<char *>(address) -
(reinterpret_cast<size_t>(address) & 2));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t sum;
uint32_t newval;
uint32_t assumed;
if (((size_t)address & 2) == 0) {
// the float16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old & 0xffffu;
return ret;
} else {
// the float16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f));
} while (old != assumed);
float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif
} // namespace platform
} // namespace paddle
......@@ -67,8 +67,11 @@ struct float16;
} // namespace platform
} // namespace paddle
// NOTE():
// Do not move the eigen.h header, otherwise the eigen_vector<bool> will failed.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace platform {
......@@ -898,6 +901,30 @@ struct is_pod<paddle::platform::float16> {
is_standard_layout<paddle::platform::float16>::value;
};
template <>
struct is_floating_point<paddle::platform::float16>
: std::integral_constant<
bool, std::is_same<paddle::platform::float16,
typename std::remove_cv<
paddle::platform::float16>::type>::value> {};
template <>
struct is_signed<paddle::platform::float16> {
static const bool value = true;
};
template <>
struct is_unsigned<paddle::platform::float16> {
static const bool value = false;
};
inline bool isnan(const paddle::platform::float16& a) {
return paddle::platform::isnan(a);
}
inline bool isinf(const paddle::platform::float16& a) {
return paddle::platform::isinf(a);
}
template <>
struct numeric_limits<paddle::platform::float16> {
static const bool is_specialized = true;
......
......@@ -141,10 +141,36 @@ TEST(float16, lod_tensor_cpu) {
}
}
TEST(float16, floating) {
// compile time assert.
PADDLE_ASSERT(std::is_floating_point<float16>::value);
}
TEST(float16, print) {
float16 a = float16(1.0f);
std::cout << a << std::endl;
}
// CPU test
TEST(float16, isinf) {
float16 a;
a.x = 0x7c00;
float16 b = float16(INFINITY);
float16 c = static_cast<float16>(INFINITY);
EXPECT_EQ(std::isinf(a), true);
EXPECT_EQ(std::isinf(b), true);
EXPECT_EQ(std::isinf(c), true);
}
TEST(float16, isnan) {
float16 a;
a.x = 0x7fff;
float16 b = float16(NAN);
float16 c = static_cast<float16>(NAN);
EXPECT_EQ(std::isnan(a), true);
EXPECT_EQ(std::isnan(b), true);
EXPECT_EQ(std::isnan(c), true);
}
} // namespace platform
} // namespace paddle
......@@ -11,11 +11,13 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/legacy/utils/Logging.h"
#define ARITHMETIC_KERNEL(op_type, sign) \
__global__ void op_type(const half* in1, const half* in2, half* out) { \
......@@ -241,6 +243,72 @@ TEST(float16, lod_tensor_on_gpu) {
}
}
template <typename T>
struct Functor {
bool operator()(const T& val) {
return std::type_index(typeid(T)) ==
std::type_index(typeid(platform::float16));
}
};
TEST(float16, typeid) {
// the framework heavily used typeid hash
Functor<float16> functor;
float16 a = float16(.0f);
Functor<int> functor2;
int b(0);
// compile time assert
PADDLE_ASSERT(functor(a) == true);
PADDLE_ASSERT(functor2(b) == false);
}
// GPU test
TEST(float16, isinf) {
float16 a;
a.x = 0x7c00;
float16 b = float16(INFINITY);
// underflow to 0
float16 native_a(5e-40f);
// overflow to inf
float16 native_b(5e40f);
EXPECT_EQ(std::isinf(a), true);
EXPECT_EQ(std::isinf(b), true);
EXPECT_EQ(std::isinf(native_b), true);
EXPECT_EQ(native_a, float16(0));
}
TEST(float16, isnan) {
float16 a;
a.x = 0x7fff;
float16 b = float16(NAN);
float16 c = float16(5e40);
// inf * +-0 will get a nan
float16 d = c * float16(0);
EXPECT_EQ(std::isnan(a), true);
EXPECT_EQ(std::isnan(b), true);
EXPECT_EQ(std::isnan(d), true);
}
TEST(float16, cast) {
float16 a;
a.x = 0x0070;
auto b = a;
{
// change semantic, keep the same value
float16 c = reinterpret_cast<float16&>(reinterpret_cast<unsigned&>(b));
EXPECT_EQ(b, c);
}
{
// use uint32 low 16 bit store float16
uint32_t c = reinterpret_cast<uint32_t&>(b);
float16 d;
d.x = c;
EXPECT_EQ(b, d);
}
}
} // namespace platform
} // namespace paddle
#endif // PADDLE_CUDA_FP16
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册