float16.h 28.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
K
Kexin Zhao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <stdint.h>
18
#include <limits>
K
Kexin Zhao 已提交
19

K
Kexin Zhao 已提交
20
#ifdef PADDLE_WITH_CUDA
K
Kexin Zhao 已提交
21
#include <cuda.h>
K
Kexin Zhao 已提交
22 23
#endif  // PADDLE_WITH_CUDA

K
Kexin Zhao 已提交
24 25 26 27 28 29 30 31 32 33 34 35
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

K
Kexin Zhao 已提交
36
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
K
Kexin Zhao 已提交
37 38
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
K
Kexin Zhao 已提交
39 40
#endif

D
dzhwinter 已提交
41
#if !defined(_WIN32)
K
Kexin Zhao 已提交
42
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
D
dzhwinter 已提交
43
#else
P
peizhilin 已提交
44
#define PADDLE_ALIGN(x) __declspec(align(x))
D
dzhwinter 已提交
45
#endif
K
Kexin Zhao 已提交
46 47

namespace paddle {
K
kexinzhao 已提交
48
namespace platform {
K
Kexin Zhao 已提交
49

50 51 52 53 54 55 56
// Forward declare float16 for eigen.h
struct float16;

}  // namespace platform
}  // namespace paddle

#include "paddle/fluid/platform/hostdevice.h"
57
#include "unsupported/Eigen/CXX11/Tensor"
58 59 60 61

namespace paddle {
namespace platform {

K
Kexin Zhao 已提交
62 63 64
// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
K
Kexin Zhao 已提交
65
// with CUDA half, ARM float16_t, and Eigen::half data types.
K
Kexin Zhao 已提交
66
struct PADDLE_ALIGN(2) float16 {
 public:
  // Raw 16-bit storage of the half-precision value. Bit 15 is the sign
  // (see the 0x8000 masks used below); 0x3c00 encodes 1.0.
  uint16_t x;

  // The following defaulted special class member functions
  // are added to make float16 pass the std::is_trivial test
  float16() = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o) = default;
  float16& operator=(float16&& o) = default;
  ~float16() = default;

// Constructors
#ifdef PADDLE_CUDA_FP16
  // Copies the bit pattern of a CUDA half.
  HOSTDEVICE inline explicit float16(const half& h) {
#if CUDA_VERSION >= 9000
    // From CUDA 9.0 the bits are read through __half_raw; a const-qualified
    // view is enough because the value is only read.
    x = reinterpret_cast<const __half_raw*>(&h)->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  // Copies the bit pattern of an Eigen::half.
  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  // Converts a float to half precision. The conversion path is chosen at
  // compile time: CUDA intrinsic, ARM NEON, x86 F16C, or a software routine.
  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);

#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v, s;
    v.f = val;
    // Isolate the float32 sign bit, clear it from v, and move it to bit 15.
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    // Each line below is a branch-free select: the mask is all ones when the
    // comparison holds, replacing v.si with the other operand of the xor.
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;

#endif
  }

  HOSTDEVICE inline float16(int32_t val) : float16(static_cast<float>(val)) {}

  HOSTDEVICE inline float16(uint32_t val) : float16(static_cast<float>(val)) {}

  // true maps to 1.0 (0x3c00), false to +0.
  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  // Any other type is first converted to float and then to half.
  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

// Assignment operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    // Read-only access to the bits; no need to cast away const.
    x = reinterpret_cast<const __half_raw*>(&rhs)->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  // The numeric assignments below all build a temporary float16 from the
  // value and copy its bit pattern.
  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

// Conversion operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  // Converts the stored half to float, using the same compile-time choice of
  // conversion path as the float constructor.
  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);

#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v;
    v.ui = this->x;
    // Isolate the half sign bit, clear it from v, and move it to bit 31.
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    // mask is all ones when v.si is below norC; s is then selected instead
    // of the shifted pattern.
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
  }

  // Any pattern other than +0 / -0 is true.
  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  // The remaining numeric conversions all go through float.
  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(static_cast<float>(*this));
  }

 private:
  // Lets the software conversion routines view the same 32 bits as a float,
  // a signed integer, or an unsigned integer.
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  // Constants used only by the software conversion routines above.
  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
};

K
Kexin Zhao 已提交
353 354 355 356 357
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
358 359
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000

K
Kexin Zhao 已提交
360
DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  // Fallback: widen both operands to float, add, and narrow the sum.
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs + rhs));
#endif
}

DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  // Fallback: widen both operands to float, subtract, and narrow.
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs - rhs));
#endif
}

DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  // Fallback: widen both operands to float, multiply, and narrow.
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs * rhs));
#endif
}

DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // Divide in single precision, then narrow the quotient to half.
  const float quotient = __half2float(a) / __half2float(b);
  return __float2half(quotient);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs / rhs));
#endif
}

398 399 400 401
DEVICE inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  // Fallback: negate in float and narrow the result.
  const float negated = -static_cast<float>(float16(a));
  return half(float16(negated));
#endif
}
K
Kexin Zhao 已提交
406

407
DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
  // Delegates to the binary operator and stores the sum back into a.
  const half sum = a + b;
  a = sum;
  return a;
}

412
DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
  // Delegates to the binary operator and stores the difference back into a.
  const half difference = a - b;
  a = difference;
  return a;
}

417
DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
  // Delegates to the binary operator and stores the product back into a.
  const half product = a * b;
  a = product;
  return a;
}

422
DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
  // Delegates to the binary operator and stores the quotient back into a.
  const half quotient = a / b;
  a = quotient;
  return a;
}

DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs == rhs;
#endif
}

DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs != rhs;
#endif
}

DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs < rhs;
#endif
}

DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs <= rhs;
#endif
}

DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs > rhs;
#endif
}

DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs >= rhs;
#endif
}

475
#endif  // PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
476

477
// Arithmetic operators for float16 on GPU
K
Kexin Zhao 已提交
478 479 480
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Half-precision add intrinsic on the converted operands.
  const half sum = __hadd(half(a), half(b));
  return float16(sum);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs + rhs);
#endif
}

K
Kexin Zhao 已提交
487 488
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Half-precision subtract intrinsic on the converted operands.
  const half difference = __hsub(half(a), half(b));
  return float16(difference);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs - rhs);
#endif
}

K
Kexin Zhao 已提交
495 496
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Half-precision multiply intrinsic on the converted operands.
  const half product = __hmul(half(a), half(b));
  return float16(product);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs * rhs);
#endif
}

K
Kexin Zhao 已提交
503 504 505
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  // Until then, divide in single precision and narrow the quotient.
  const float quotient = __half2float(half(a)) / __half2float(half(b));
  return float16(quotient);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs / rhs);
#endif
}

K
Kexin Zhao 已提交
514 515
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half negated = __hneg(half(a));
  return float16(negated);
#else
  // Flip the sign bit (bit 15) of the raw pattern.
  float16 negated = a;
  negated.x ^= 0x8000;
  return negated;
#endif
}

524
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  // Delegates to the binary operator and stores the sum back into a.
  const float16 sum = a + b;
  a = sum;
  return a;
}

529
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  // Delegates to the binary operator and stores the difference back into a.
  const float16 difference = a - b;
  a = difference;
  return a;
}

534
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  // Delegates to the binary operator and stores the product back into a.
  const float16 product = a * b;
  a = product;
  return a;
}

539
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  // Delegates to the binary operator and stores the quotient back into a.
  const float16 quotient = a / b;
  a = quotient;
  return a;
}

K
Kexin Zhao 已提交
544 545
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(half(a), half(b));
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
#endif
}

K
Kexin Zhao 已提交
552 553
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(half(a), half(b));
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
#endif
}

K
Kexin Zhao 已提交
560 561
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(half(a), half(b));
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
#endif
}

K
Kexin Zhao 已提交
568 569
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(half(a), half(b));
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
#endif
}

K
Kexin Zhao 已提交
576 577
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(half(a), half(b));
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
#endif
}

K
Kexin Zhao 已提交
584 585
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(half(a), half(b));
#else
  // Fallback: compare the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
#endif
}

// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
594
inline float16 operator+(const float16& a, const float16& b) {
  float16 res;
  // The asm loads each operand's 16 bits into lane 0 of v0 / v1, adds them
  // with a scalar half-precision fadd, and stores lane 0 of v0 into res.x.
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &(res.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

610
inline float16 operator-(const float16& a, const float16& b) {
  float16 res;
  // The asm loads each operand's 16 bits into lane 0 of v0 / v1, subtracts
  // with a scalar half-precision fsub, and stores lane 0 of v0 into res.x.
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &(res.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

626
inline float16 operator*(const float16& a, const float16& b) {
  float16 res;
  // The asm loads each operand's 16 bits into lane 0 of v0 / v1, multiplies
  // with a scalar half-precision fmul, and stores lane 0 of v0 into res.x.
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &(res.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

642
inline float16 operator/(const float16& a, const float16& b) {
  float16 res;
  // The asm loads each operand's 16 bits into lane 0 of v0 / v1, divides
  // with a scalar half-precision fdiv, and stores lane 0 of v0 into res.x.
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &(res.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}
K
Kexin Zhao 已提交
657

658
inline float16 operator-(const float16& a) {
  float16 res;
  // The asm loads the operand's 16 bits into lane 0 of v0, negates it with a
  // scalar half-precision fneg, and stores lane 0 of v0 into res.x.
  const uint16_t* src_bits = &(a.x);
  uint16_t* out_bits = &(res.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(src_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0");
  return res;
}

673
inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  // Delegates to the binary operator and stores the sum back into a.
  const float16 sum = a + b;
  a = sum;
  return a;
}

678
inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  // Delegates to the binary operator and stores the difference back into a.
  const float16 difference = a - b;
  a = difference;
  return a;
}

683
inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  // Delegates to the binary operator and stores the product back into a.
  const float16 product = a * b;
  a = product;
  return a;
}

688
inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  // Delegates to the binary operator and stores the quotient back into a.
  const float16 quotient = a / b;
  a = quotient;
  return a;
}

693
inline bool operator==(const float16& a, const float16& b) {
  // fcmeq writes its 16-bit result into h0, which is stored to cmp_bits;
  // any non-zero bit is reported as true.
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0", "v1");
  return (cmp_bits & 0xffff) != 0;
}

709
inline bool operator!=(const float16& a, const float16& b) {
  // Defined as the negation of operator==.
  const bool equal = (a == b);
  return !equal;
}
K
Kexin Zhao 已提交
710

711
inline bool operator<(const float16& a, const float16& b) {
  // Note the swapped registers: a is loaded into v1 and b into v0, so the
  // fcmgt below tests b > a.
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0", "v1");
  return (cmp_bits & 0xffff) != 0;
}

727
inline bool operator<=(const float16& a, const float16& b) {
  // Note the swapped registers: a is loaded into v1 and b into v0, so the
  // fcmge below tests b >= a.
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0", "v1");
  return (cmp_bits & 0xffff) != 0;
}

743
inline bool operator>(const float16& a, const float16& b) {
  // a is loaded into v0 and b into v1; fcmgt tests a > b and its 16-bit
  // result is stored to cmp_bits.
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0", "v1");
  return (cmp_bits & 0xffff) != 0;
}

759
inline bool operator>=(const float16& a, const float16& b) {
  // a is loaded into v0 and b into v1; fcmge tests a >= b and its 16-bit
  // result is stored to cmp_bits.
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      "memory", "v0", "v1");
  return (cmp_bits & 0xffff) != 0;
}

K
Kexin Zhao 已提交
775
// Arithmetic operators for float16, software emulated on other CPU
K
Kexin Zhao 已提交
776
#else
777
inline float16 operator+(const float16& a, const float16& b) {
  // Emulated: widen to float, add, narrow the sum.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs + rhs);
}

781
inline float16 operator-(const float16& a, const float16& b) {
  // Emulated: widen to float, subtract, narrow the difference.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs - rhs);
}

785
inline float16 operator*(const float16& a, const float16& b) {
  // Emulated: widen to float, multiply, narrow the product.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs * rhs);
}

789
inline float16 operator/(const float16& a, const float16& b) {
  // Emulated: widen to float, divide, narrow the quotient.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs / rhs);
}

793
inline float16 operator-(const float16& a) {
  // Flip the sign bit (bit 15) of the raw pattern.
  float16 negated = a;
  negated.x ^= 0x8000;
  return negated;
}

799 800
inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  // Emulated in float; the narrowed sum is stored back into a.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  a = float16(lhs + rhs);
  return a;
}

804 805
inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  // Emulated in float; the narrowed difference is stored back into a.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  a = float16(lhs - rhs);
  return a;
}

809 810
inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  // Emulated in float; the narrowed product is stored back into a.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  a = float16(lhs * rhs);
  return a;
}

814 815
inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  // Emulated in float; the narrowed quotient is stored back into a.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  a = float16(lhs / rhs);
  return a;
}

819
inline bool operator==(const float16& a, const float16& b) {
  // Compared through the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
}

823
inline bool operator!=(const float16& a, const float16& b) {
  // Compared through the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
}

827
inline bool operator<(const float16& a, const float16& b) {
  // Compared through the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
}

831
inline bool operator<=(const float16& a, const float16& b) {
  // Compared through the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
}

835
inline bool operator>(const float16& a, const float16& b) {
  // Compared through the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
}

839
inline bool operator>=(const float16& a, const float16& b) {
  // Compared through the float values of both operands.
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
}
K
Kexin Zhao 已提交
842
#endif
K
kexinzhao 已提交
843

844 845 846 847 848 849
// Builds a float16 whose raw bit pattern is exactly `a`; no numeric
// conversion is performed.
HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 out;
  out.x = a;
  return out;
}

850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865
// Returns true when `a` is NaN. The name is parenthesized in the declaration.
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hisnan(half(a));
#else
  // With the sign bit masked off, NaN patterns lie strictly above 0x7c00.
  const uint16_t magnitude = a.x & 0x7fff;
  return magnitude > 0x7c00;
#endif
}

// Returns true when `a` is +inf or -inf: with the sign bit masked off the
// pattern equals 0x7c00.
HOSTDEVICE inline bool(isinf)(const float16& a) {
  const uint16_t magnitude = a.x & 0x7fff;
  return magnitude == 0x7c00;
}

// Returns true when `a` is neither NaN nor infinite.
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return !((isnan)(a) || (isinf)(a));
}

866 867 868 869 870
// Streams the value after converting it to float.
inline std::ostream& operator<<(std::ostream& os, const float16& a) {
  return os << static_cast<float>(a);
}

K
kexinzhao 已提交
871
}  // namespace platform
K
Kexin Zhao 已提交
872
}  // namespace paddle
K
kexinzhao 已提交
873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890

namespace std {

// Override the std::is_pod::value for float16
// The reason is that different compilers implemented std::is_pod based on
// different C++ standards. float16 class is a plain old data in C++11 given
// that it is both trivial and standard_layout.
// However, std::is_pod in nvcc 8.0 host c++ compiler follows C++0x and is
// more restricted in that you cannot provide any customized
// constructor in float16. Hence, we override is_pod here following C++11
// so that .cu files can be successfully compiled by nvcc.
template <>
struct is_pod<paddle::platform::float16> {
  static const bool value =
      is_trivial<paddle::platform::float16>::value &&
      is_standard_layout<paddle::platform::float16>::value;
};

891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914
// float16 is treated as a floating-point type. This specialization is for
// the cv-unqualified type only.
template <>
struct is_floating_point<paddle::platform::float16> : std::true_type {};
// float16 is a signed type. Deriving from std::integral_constant gives the
// specialization the full type-trait interface (`value`, `value_type`,
// `type`) and makes `value` safe to ODR-use.
template <>
struct is_signed<paddle::platform::float16>
    : std::integral_constant<bool, true> {};

// float16 is not an unsigned type. Deriving from std::integral_constant gives
// the specialization the full type-trait interface (`value`, `value_type`,
// `type`) and makes `value` safe to ODR-use.
template <>
struct is_unsigned<paddle::platform::float16>
    : std::integral_constant<bool, false> {};

// std::isnan overload for float16; forwards to paddle::platform::isnan.
inline bool(isnan)(const paddle::platform::float16& a) {
  return (paddle::platform::isnan)(a);
}

// std::isinf overload for float16; forwards to paddle::platform::isinf.
inline bool(isinf)(const paddle::platform::float16& a) {
  return (paddle::platform::isinf)(a);
}

915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969
// std::numeric_limits specialization for float16.
//
// The constants describe a 16-bit binary floating-point format with an
// 11-bit significand (`digits`). The member functions build their results
// directly from raw bit patterns via raw_uint16_to_float16.
template <>
struct numeric_limits<paddle::platform::float16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  static const bool is_iec559 = false;
  // The set of representable values is finite, so the type is bounded.
  static const bool is_bounded = true;
  static const bool is_modulo = false;
  static const int digits = 11;
  static const int digits10 = 3;
  static const int max_digits10 = 5;
  static const int radix = 2;
  static const int min_exponent = -13;
  static const int min_exponent10 = -4;
  static const int max_exponent = 16;
  static const int max_exponent10 = 4;
  static const bool traps = true;
  static const bool tinyness_before = false;

  // Smallest positive normalized value (bit pattern 0x0400).
  static paddle::platform::float16(min)() {
    return paddle::platform::raw_uint16_to_float16(0x400);
  }
  // Most negative finite value (bit pattern 0xfbff); same magnitude as max().
  static paddle::platform::float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  // Largest finite value (bit pattern 0x7bff).
  static paddle::platform::float16(max)() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  // NOTE(review): 0x0800 encodes 2^-13, while the gap between 1.0 and the
  // next representable value with an 11-bit significand is 2^-10 (0x1400).
  // Kept unchanged because Eigen::NumTraits<float16>::epsilon() in this file
  // returns the same pattern -- confirm which value callers expect.
  static paddle::platform::float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  // Largest rounding error under round-to-nearest.
  static paddle::platform::float16 round_error() {
    return paddle::platform::float16(0.5);
  }
  // Positive infinity (bit pattern 0x7c00, the value isinf() tests for).
  static paddle::platform::float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  // Quiet NaN: exponent all ones and the top significand bit (0x0200) set.
  static paddle::platform::float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  // Signaling NaN: exponent all ones, top significand bit clear and a
  // non-zero payload. This must differ from quiet_NaN(); the former value
  // 0x7e00 was the quiet-NaN pattern.
  static paddle::platform::float16 signaling_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7d00);
  }
  // Smallest positive subnormal value (bit pattern 0x0001).
  static paddle::platform::float16 denorm_min() {
    return paddle::platform::raw_uint16_to_float16(0x1);
  }
};

K
kexinzhao 已提交
970
}  // namespace std
971 972

namespace Eigen {
973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002

// Alias so that code in namespace Eigen can refer to
// paddle::platform::float16 simply as `float16`.
using float16 = paddle::platform::float16;

// Eigen::NumTraits specialization describing float16 to Eigen as a signed,
// non-integer, non-complex scalar that needs no explicit initialization.
// Special values are built from raw bit patterns.
template <>
struct NumTraits<float16> : GenericNumTraits<float16> {
  enum {
    IsSigned = true,
    IsInteger = false,
    IsComplex = false,
    RequireInitialization = false
  };

  // Same bit pattern as std::numeric_limits<float16>::epsilon().
  HOSTDEVICE static inline float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  // Tolerance reported to Eigen for approximate comparisons.
  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
  // Largest finite value (bit pattern 0x7bff).
  HOSTDEVICE static inline float16 highest() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  // Most negative finite value (bit pattern 0xfbff).
  HOSTDEVICE static inline float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  // Positive infinity (bit pattern 0x7c00).
  HOSTDEVICE static inline float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  // Quiet NaN. Uses 0x7e00 (top significand bit set) so the value matches
  // std::numeric_limits<float16>::quiet_NaN(); the former pattern 0x7c01 has
  // that bit clear, which is the signaling-NaN encoding.
  HOSTDEVICE static inline float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
};

1003 1004 1005
namespace numext {

template <>
1006
HOSTDEVICE inline bool(isnan)(const float16& a) {
1007 1008 1009 1010
  return (paddle::platform::isnan)(a);
}

template <>
1011
HOSTDEVICE inline bool(isinf)(const float16& a) {
1012 1013 1014 1015
  return (paddle::platform::isinf)(a);
}

template <>
1016
HOSTDEVICE inline bool(isfinite)(const float16& a) {
1017 1018 1019
  return (paddle::platform::isfinite)(a);
}

1020 1021 1022
// numext::exp for float16: widen to float, apply expf, convert back.
template <>
HOSTDEVICE inline float16 exp(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(::expf(widened));
}

// numext::erf for float16: widen to float, apply erff, convert back.
template <>
HOSTDEVICE inline float16 erf(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(::erff(widened));
}

1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069
// numext::log for float16: widen to float, apply logf, convert back.
template <>
HOSTDEVICE inline float16 log(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(::logf(widened));
}

// numext::tanh for float16: widen to float, apply tanhf, convert back.
template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(::tanhf(widened));
}

// numext::sqrt for float16: widen to float, apply sqrtf, convert back.
template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(::sqrtf(widened));
}

// numext::ceil for float16: widen to float, apply ceilf, convert back.
template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(::ceilf(widened));
}

// numext::floor for float16: widen to float, apply floorf, convert back.
template <>
HOSTDEVICE inline float16 floor(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(::floorf(widened));
}

// numext::round for float16: widen to float, apply roundf, convert back.
template <>
HOSTDEVICE inline float16 round(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(::roundf(widened));
}

// numext::pow for float16: widen both operands to float, apply powf to
// (base = a, exponent = b), and convert the result back.
template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
  const float base = static_cast<float>(a);
  const float exponent = static_cast<float>(b);
  return float16(::powf(base, exponent));
}

// numext::abs for float16: widen to float, take the absolute value, and
// convert back. Uses the single-precision fabsf, like the other float16 math
// wrappers in this namespace, so the computation never goes through double.
// The result is unchanged because taking an absolute value is exact.
template <>
HOSTDEVICE inline float16 abs(const float16& a) {
  return float16(::fabsf(static_cast<float>(a)));
}

1070
}  // namespace numext
1071

1072
}  // namespace Eigen