/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stdint.h>
#include <limits>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif  // PADDLE_WITH_CUDA

#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

#if defined(__CUDACC__) && CUDA_VERSION >= 7050
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif

#if defined(__arm__) || defined(__aarch64__)
#define PADDLE_ARM
#endif

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#include <arm_neon.h>
#endif

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 62 || PADDLE_CLANG_VER >= 37)
#define PADDLE_WITH_NATIVE_FP16
#endif

#ifndef PADDLE_ARM
#include <immintrin.h>
#endif  // PADDLE_ARM

#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) /*do nothing*/
#endif

namespace paddle {
namespace platform {

// Forward declare float16 for eigen.h
struct float16;

}  // namespace platform
}  // namespace paddle

// NOTE:
// Do not move the eigen.h header, otherwise the eigen_vector<bool> will fail.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace platform {

// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
// with CUDA half, ARM float16_t, and Eigen::half data types.
struct PADDLE_ALIGN(2) float16 {
 public:
  uint16_t x;

  // The following defaulted special class member functions
  // are added to make float16 pass the std::is_trivial test
  float16() = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o) = default;
  float16& operator=(float16&& o) = default;
  ~float16() = default;

// Constructors
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit float16(const half& h) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);

#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v, s;
    v.f = val;
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;

#endif
  }

  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

// Assignment operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

// Conversion operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);

#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v;
    v.ui = this->x;
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
  }

  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(static_cast<float>(*this));
  }

 private:
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
};
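
// A minimal usage sketch (illustrative only; it relies solely on the
// constructors, assignment operators, and conversion operators declared
// above):
//
//   float16 h(3.14159f);                      // explicit float -> float16
//   float f = static_cast<float>(h);          // explicit float16 -> float
//   h = 42;                                   // integral assignment goes through float
//   Eigen::half eh = static_cast<Eigen::half>(h);  // same underlying bit pattern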

// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half, while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here let
// users write the same half-precision CUDA code on CUDA 7.5 and 8.0
// as on CUDA 9.0 (a usage sketch follows this block).
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000

DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  float res = static_cast<float>(float16(a)) + static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  float res = static_cast<float>(float16(a)) - static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  float res = static_cast<float>(float16(a)) * static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  float num = __half2float(a);
  float denom = __half2float(b);
  return __float2half(num / denom);
#else
  float res = static_cast<float>(float16(a)) / static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  float res = -static_cast<float>(float16(a));
  return half(float16(res));
#endif
}

DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
  a = a + b;
  return a;
}

DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
  a = a - b;
  return a;
}

DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
  a = a * b;
  return a;
}

DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
  a = a / b;
  return a;
}

DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
#endif
}

#endif  // PADDLE_CUDA_FP16
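
// Sketch of the kind of device code the operators above enable on CUDA 7.5
// and 8.0 (hypothetical user kernel, not defined in this header):
//
//   __global__ void axpy(half a, const half* x, half* y, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) y[i] += a * x[i];  // uses operator* and operator+= above
//   }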

// Arithmetic operators for float16 on GPU
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hadd(half(a), half(b)));
#else
  return float16(static_cast<float>(a) + static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hsub(half(a), half(b)));
#else
  return float16(static_cast<float>(a) - static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hmul(half(a), half(b)));
#else
  return float16(static_cast<float>(a) * static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
#else
  return float16(static_cast<float>(a) / static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hneg(half(a)));
#else
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
#endif
}

HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(half(a), half(b));
#else
  return static_cast<float>(a) == static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(half(a), half(b));
#else
  return static_cast<float>(a) != static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(half(a), half(b));
#else
  return static_cast<float>(a) < static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(half(a), half(b));
#else
  return static_cast<float>(a) <= static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(half(a), half(b));
#else
  return static_cast<float>(a) > static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(half(a), half(b));
#else
  return static_cast<float>(a) >= static_cast<float>(b);
#endif
}

// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
inline float16 operator+(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator*(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator/(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0");
  return res;
}

inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }

inline bool operator<(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator<=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Arithmetic operators for float16, software emulated on other CPUs
#else
inline float16 operator+(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) + static_cast<float>(b));
}

inline float16 operator-(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) - static_cast<float>(b));
}

inline float16 operator*(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) * static_cast<float>(b));
}

inline float16 operator/(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) / static_cast<float>(b));
}

inline float16 operator-(const float16& a) {
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
}

inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) + static_cast<float>(b));
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) - static_cast<float>(b));
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) * static_cast<float>(b));
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) / static_cast<float>(b));
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  return static_cast<float>(a) == static_cast<float>(b);
}

inline bool operator!=(const float16& a, const float16& b) {
  return static_cast<float>(a) != static_cast<float>(b);
}

inline bool operator<(const float16& a, const float16& b) {
  return static_cast<float>(a) < static_cast<float>(b);
}

inline bool operator<=(const float16& a, const float16& b) {
  return static_cast<float>(a) <= static_cast<float>(b);
}

inline bool operator>(const float16& a, const float16& b) {
  return static_cast<float>(a) > static_cast<float>(b);
}

inline bool operator>=(const float16& a, const float16& b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
#endif

HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 res;
  res.x = a;
  return res;
}

HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hisnan(half(a));
#else
  return (a.x & 0x7fff) > 0x7c00;
#endif
}

HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (a.x & 0x7fff) == 0x7c00;
}

HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return !((isnan)(a)) && !((isinf)(a));
}

inline std::ostream& operator<<(std::ostream& os, const float16& a) {
  os << static_cast<float>(a);
  return os;
}

}  // namespace platform
}  // namespace paddle

namespace std {

// Override std::is_pod::value for float16.
// The reason is that different compilers implement std::is_pod based on
// different C++ standards. The float16 class is plain old data under C++11,
// given that it is both trivial and standard layout.
// However, std::is_pod in the nvcc 8.0 host C++ compiler follows C++0x and is
// more restrictive: it does not allow float16 to provide any customized
// constructors. Hence, we override is_pod here following C++11
// so that .cu files can be successfully compiled by nvcc.
template <>
struct is_pod<paddle::platform::float16> {
  static const bool value =
      is_trivial<paddle::platform::float16>::value &&
      is_standard_layout<paddle::platform::float16>::value;
};
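
// With this specialization in place, a check such as the following is
// expected to hold (illustrative only; the header itself does not assert it):
//
//   static_assert(std::is_pod<paddle::platform::float16>::value,
//                 "float16 should be treated as plain old data");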

template <>
struct is_floating_point<paddle::platform::float16>
    : std::integral_constant<
          bool, std::is_same<paddle::platform::float16,
                             typename std::remove_cv<
                                 paddle::platform::float16>::type>::value> {};
template <>
struct is_signed<paddle::platform::float16> {
  static const bool value = true;
};

template <>
struct is_unsigned<paddle::platform::float16> {
  static const bool value = false;
};

inline bool isnan(const paddle::platform::float16& a) {
  return paddle::platform::isnan(a);
}

inline bool isinf(const paddle::platform::float16& a) {
  return paddle::platform::isinf(a);
}

template <>
struct numeric_limits<paddle::platform::float16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  static const bool is_iec559 = false;
  static const bool is_bounded = false;
  static const bool is_modulo = false;
  static const int digits = 11;
  static const int digits10 = 3;
  static const int max_digits10 = 5;
  static const int radix = 2;
  static const int min_exponent = -13;
  static const int min_exponent10 = -4;
  static const int max_exponent = 16;
  static const int max_exponent10 = 4;
  static const bool traps = true;
  static const bool tinyness_before = false;

  static paddle::platform::float16(min)() {
    return paddle::platform::raw_uint16_to_float16(0x400);
  }
  static paddle::platform::float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  static paddle::platform::float16(max)() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  static paddle::platform::float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  static paddle::platform::float16 round_error() {
    return paddle::platform::float16(0.5);
  }
  static paddle::platform::float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  static paddle::platform::float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  static paddle::platform::float16 signaling_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  static paddle::platform::float16 denorm_min() {
    return paddle::platform::raw_uint16_to_float16(0x1);
  }
};

}  // namespace std

namespace Eigen {

using float16 = paddle::platform::float16;

template <>
struct NumTraits<float16> : GenericNumTraits<float16> {
  enum {
    IsSigned = true,
    IsInteger = false,
    IsComplex = false,
    RequireInitialization = false
  };

  HOSTDEVICE static inline float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
  HOSTDEVICE static inline float16 highest() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  HOSTDEVICE static inline float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  HOSTDEVICE static inline float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  HOSTDEVICE static inline float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7c01);
  }
};

namespace numext {

template <>
HOSTDEVICE inline bool(isnan)(const float16& a) {
  return (paddle::platform::isnan)(a);
}

template <>
HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (paddle::platform::isinf)(a);
}

template <>
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return (paddle::platform::isfinite)(a);
}

template <>
HOSTDEVICE inline float16 exp(const float16& a) {
  return float16(::expf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 erf(const float16& a) {
  return float16(::erff(static_cast<float>(a)));
1045 1046
}

1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086
template <>
HOSTDEVICE inline float16 log(const float16& a) {
  return float16(::logf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
  return float16(::tanhf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
  return float16(::sqrtf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
  return float16(::ceilf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 floor(const float16& a) {
  return float16(::floorf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 round(const float16& a) {
  return float16(::roundf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
  return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}

template <>
HOSTDEVICE inline float16 abs(const float16& a) {
  return float16(::fabs(static_cast<float>(a)));
}

}  // namespace numext

}  // namespace Eigen