float16.h 28.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <stdint.h>
18
#include <limits>
K
Kexin Zhao 已提交
19

K
Kexin Zhao 已提交
20
#ifdef PADDLE_WITH_CUDA
K
Kexin Zhao 已提交
21
#include <cuda.h>
K
Kexin Zhao 已提交
22 23
#endif  // PADDLE_WITH_CUDA

K
Kexin Zhao 已提交
24 25 26 27 28 29 30 31 32 33 34 35
// Compiler versions encoded as (major * 10 + minor), or 0 when that compiler
// is not the one in use; e.g. GCC 6.2 -> 62, Clang 3.7 -> 37.
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

// CUDA half support (cuda_fp16.h) is enabled only when compiling CUDA
// sources (__CUDACC__) against CUDA 7.5 or newer.
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif

#if defined(__arm__) || defined(__aarch64__)
#define PADDLE_ARM
#endif

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#include <arm_neon.h>
#endif

// Native ARM half arithmetic requires NEON, an explicit opt-in through
// PADDLE_ARM_FP16 (not defined in this header), and GCC >= 6.2 or
// Clang >= 3.7.
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 62 || PADDLE_CLANG_VER >= 37)
#define PADDLE_WITH_NATIVE_FP16
#endif

// <immintrin.h> supplies the x86 F16C intrinsics (_cvtss_sh / _cvtsh_ss).
// The header only exists for x86 targets, so it is included there (or
// whenever F16C is enabled) rather than on "everything that is not ARM",
// which broke non-x86, non-ARM builds such as PowerPC.
#if !defined(PADDLE_ARM) &&                                        \
    (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || \
     defined(_M_IX86) || defined(__F16C__))
#include <immintrin.h>
#endif

#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
K
Kexin Zhao 已提交
60 61

namespace paddle {
namespace platform {

// Only a declaration: eigen.h, included right after this, needs the name
// float16 before the full definition below is available.
struct float16;

}  // namespace platform
}  // namespace paddle

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace platform {

K
Kexin Zhao 已提交
76 77 78
// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
// with CUDA half, ARM float16_t, and Eigen::half data types.
//
// `x` holds the raw IEEE 754 binary16 bit pattern (1 sign bit, 5 exponent
// bits, 10 mantissa bits). Conversions from/to float use, in order of
// preference: CUDA intrinsics (device code, sm_30+), ARM NEON
// (PADDLE_WITH_NATIVE_FP16), x86 F16C, or a portable software routine.
// The CUDA and F16C float -> half paths round to nearest, ties to even, and
// the software fallback does the same so host builds without F16C produce
// the same bits as GPU code.
struct PADDLE_ALIGN(2) float16 {
 public:
  uint16_t x;

  // The following defaulted special class member functions
  // are added to make float16 pass the std::is_trivial test
  float16() = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o) = default;
  float16& operator=(float16&& o) = default;
  ~float16() = default;

// Constructors
#ifdef PADDLE_CUDA_FP16
  // Copies the bits of a CUDA half (read through __half_raw from CUDA 9 on).
  HOSTDEVICE inline explicit float16(const half& h) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  // Narrows a float to binary16.
  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);

#elif defined(__F16C__)
    // Rounding-control immediate 0 selects round to nearest, ties to even.
    x = _cvtss_sh(val, 0);

#else
    x = float_to_half_bits(val);

#endif
  }

  // true -> 1.0 (0x3c00), false -> +0.0.
  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  // Any other arithmetic type goes through float first.
  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

// Assignment operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

// Conversion operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  // Widens to float. Every binary16 value is exactly representable as a
  // float, so no rounding is involved.
  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);

#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    return half_bits_to_float(this->x);

#endif
  }

  // +0 and -0 are false; every other bit pattern (including NaN) is true.
  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(static_cast<float>(*this));
  }

 private:
  // Bit-level view of a 32-bit float used by the software conversions.
  union Bits {
    float f;
    uint32_t ui;
  };

  // Portable float -> binary16 conversion, rounding to nearest, ties to even
  // (the previous software routine truncated toward zero, disagreeing with
  // the CUDA and F16C paths). Magnitudes that round past 65504 become +-inf;
  // every NaN becomes the quiet NaN 0x7e00 with the input's sign.
  static HOSTDEVICE inline uint16_t float_to_half_bits(float val) {
    Bits v;
    v.f = val;
    const uint32_t sign = (v.ui >> 16) & 0x8000u;
    uint32_t mag = v.ui & 0x7fffffffu;  // |val| as float bits

    uint32_t res;
    if (mag >= 0x47800000u) {
      // |val| >= 65536.0f, inf, or NaN (0x47800000 is 65536.0f and
      // 0x7f800000 is +inf).
      res = mag > 0x7f800000u ? 0x7e00u : 0x7c00u;
    } else if (mag < 0x38800000u) {
      // |val| < 2^-14: the result is a half subnormal or zero. 0.5f has an
      // ulp of 2^-24 (the half subnormal step), so the float addition rounds
      // |val| to a multiple of 2^-24 (nearest, ties to even) and the low
      // mantissa bits of the sum are exactly the half mantissa.
      Bits t;
      t.ui = mag;
      t.f += 0.5f;
      res = t.ui - 0x3f000000u;  // 0x3f000000 == bit pattern of 0.5f
    } else {
      // Normal half range: rebias the exponent from 127 to 15, then round
      // away the 13 dropped mantissa bits (nearest, ties to even). A carry
      // out of the mantissa correctly bumps the exponent, up to inf.
      const uint32_t mant_odd = (mag >> 13) & 1u;
      mag -= 0x38000000u;  // (127 - 15) << 23
      mag += 0x0fffu + mant_odd;
      res = mag >> 13;
    }
    return static_cast<uint16_t>(res | sign);
  }

  // Portable, exact binary16 -> float conversion using only unsigned
  // arithmetic (no signed shifts into the sign bit). The sign bit and NaN
  // payloads are preserved.
  static HOSTDEVICE inline float half_bits_to_float(uint16_t h) {
    const uint32_t kShiftedExp = 0x7c00u << 13;  // half exponent, float position
    Bits o;
    o.ui = static_cast<uint32_t>(h & 0x7fffu) << 13;  // exponent + mantissa
    const uint32_t exp = o.ui & kShiftedExp;
    o.ui += (127u - 15u) << 23;  // rebias exponent
    if (exp == kShiftedExp) {
      // inf / NaN: push the exponent to all ones.
      o.ui += (128u - 16u) << 23;
    } else if (exp == 0) {
      // Zero / subnormal: add an implicit 2^-14 and subtract it again in
      // float arithmetic, which renormalizes the value exactly.
      Bits magic;
      magic.ui = 113u << 23;  // 2^-14
      o.ui += 1u << 23;
      o.f -= magic.f;
    }
    o.ui |= static_cast<uint32_t>(h & 0x8000u) << 16;  // sign
    return o.f;
  }
};

K
Kexin Zhao 已提交
363 364 365 366 367
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
//
// On sm_53+ each operator maps to the matching half intrinsic. Otherwise the
// operands are widened to float through float16, the operation runs in
// float, and arithmetic results are narrowed back to half through float16.
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000

DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs + rhs));
#endif
}

DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs - rhs));
#endif
}

DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs * rhs));
#endif
}

DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // No half division intrinsic is used here; divide in float.
  const float num = __half2float(a);
  const float denom = __half2float(b);
  return __float2half(num / denom);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs / rhs));
#endif
}

DEVICE inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  const float negated = -static_cast<float>(float16(a));
  return half(float16(negated));
#endif
}

DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
  return a = a + b;
}

DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
  return a = a - b;
}

DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
  return a = a * b;
}

DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
  return a = a / b;
}

DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs == rhs;
#endif
}

DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs != rhs;
#endif
}

DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs < rhs;
#endif
}

DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs <= rhs;
#endif
}

DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs > rhs;
#endif
}

DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs >= rhs;
#endif
}

#endif  // PADDLE_CUDA_FP16 && CUDA_VERSION < 9000
K
Kexin Zhao 已提交
486

487
// Arithmetic operators for float16 on GPU
//
// These are HOSTDEVICE. In device code for sm_53+ they use the half
// intrinsics; elsewhere (host code, older architectures) the operands are
// widened to float and arithmetic results are narrowed back to float16.
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half sum = __hadd(half(a), half(b));
  return float16(sum);
#else
  const float sum = static_cast<float>(a) + static_cast<float>(b);
  return float16(sum);
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half diff = __hsub(half(a), half(b));
  return float16(diff);
#else
  const float diff = static_cast<float>(a) - static_cast<float>(b);
  return float16(diff);
#endif
}

HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half prod = __hmul(half(a), half(b));
  return float16(prod);
#else
  const float prod = static_cast<float>(a) * static_cast<float>(b);
  return float16(prod);
#endif
}

HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  const float num = __half2float(half(a));
  const float denom = __half2float(half(b));
  return float16(num / denom);
#else
  const float quotient = static_cast<float>(a) / static_cast<float>(b);
  return float16(quotient);
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hneg(half(a)));
#else
  // Flip the sign bit directly; no rounding is involved.
  float16 negated = a;
  negated.x ^= 0x8000;
  return negated;
#endif
}

HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  return a = a + b;
}

HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  return a = a - b;
}

HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  return a = a * b;
}

HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  return a = a / b;
}

HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
#endif
}

HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
#endif
}

HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
#endif
}

HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
#endif
}

HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
#endif
}

HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
#endif
}

// Arithmetic operators for float16 on ARMv8.2-A CPU
//
// Each operator loads the operands' raw half bits into lane 0 of v0/v1, runs
// one scalar half-precision instruction on h0/h1, and stores lane 0 of v0
// back to memory. The fcm* comparisons write a 16-bit lane mask, so any
// non-zero stored value means "true". For < and <= the operands are loaded
// swapped (a -> v1, b -> v0), so fcmgt/fcmge evaluate b > a and b >= a.
#elif defined(PADDLE_WITH_NATIVE_FP16)
inline float16 operator+(const float16& a, const float16& b) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(out.x))
      :  // clobbers
      "memory", "v0", "v1");
  return out;
}

inline float16 operator-(const float16& a, const float16& b) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(out.x))
      :  // clobbers
      "memory", "v0", "v1");
  return out;
}

inline float16 operator*(const float16& a, const float16& b) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(out.x))
      :  // clobbers
      "memory", "v0", "v1");
  return out;
}

inline float16 operator/(const float16& a, const float16& b) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(out.x))
      :  // clobbers
      "memory", "v0", "v1");
  return out;
}

inline float16 operator-(const float16& a) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [res_ptr] "r"(&(out.x))
      :  // clobbers
      "memory", "v0");
  return out;
}

inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  return a = a + b;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  return a = a - b;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  return a = a * b;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  return a = a / b;
}

inline bool operator==(const float16& a, const float16& b) {
  uint16_t mask;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&mask)
      :  // clobbers
      "memory", "v0", "v1");
  return mask != 0;
}

inline bool operator!=(const float16& a, const float16& b) {
  return !(a == b);
}

inline bool operator<(const float16& a, const float16& b) {
  uint16_t mask;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&mask)
      :  // clobbers
      "memory", "v0", "v1");
  return mask != 0;
}

inline bool operator<=(const float16& a, const float16& b) {
  uint16_t mask;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&mask)
      :  // clobbers
      "memory", "v0", "v1");
  return mask != 0;
}

inline bool operator>(const float16& a, const float16& b) {
  uint16_t mask;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&mask)
      :  // clobbers
      "memory", "v0", "v1");
  return mask != 0;
}

inline bool operator>=(const float16& a, const float16& b) {
  uint16_t mask;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&mask)
      :  // clobbers
      "memory", "v0", "v1");
  return mask != 0;
}

K
Kexin Zhao 已提交
785
// Arithmetic operators for float16, software emulated on other CPU
//
// Every operation widens both operands to float and computes in float;
// arithmetic results are narrowed back with float16's float constructor.
#else
inline float16 operator+(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs + rhs);
}

inline float16 operator-(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs - rhs);
}

inline float16 operator*(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs * rhs);
}

inline float16 operator/(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs / rhs);
}

inline float16 operator-(const float16& a) {
  // Negation only flips the sign bit; no float round trip is needed.
  float16 negated = a;
  negated.x ^= 0x8000;
  return negated;
}

// The compound forms reuse the binary operators above, which perform exactly
// float16(float(a) op float(b)).
inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  return a = a + b;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  return a = a - b;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  return a = a * b;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  return a = a / b;
}

inline bool operator==(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
}

inline bool operator!=(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
}

inline bool operator<(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
}

inline bool operator<=(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
}

inline bool operator>(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
}

inline bool operator>=(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
}
#endif
K
kexinzhao 已提交
853

854 855 856 857 858 859
// Builds a float16 straight from a 16-bit pattern, with no numeric conversion.
HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 value;
  value.x = a;
  return value;
}

// NaN: exponent bits all ones and a non-zero mantissa, i.e. magnitude bits
// strictly above the +inf pattern 0x7c00.
// (The parenthesized names keep C library macros from expanding here.)
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hisnan(half(a));
#else
  const uint16_t magnitude = a.x & 0x7fff;
  return magnitude > 0x7c00;
#endif
}

// +/-inf: magnitude bits exactly equal to 0x7c00.
HOSTDEVICE inline bool(isinf)(const float16& a) {
  const uint16_t magnitude = a.x & 0x7fff;
  return magnitude == 0x7c00;
}

// Finite: neither NaN nor infinite.
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return !((isnan)(a) || (isinf)(a));
}

K
kexinzhao 已提交
876
}  // namespace platform
K
Kexin Zhao 已提交
877
}  // namespace paddle
K
kexinzhao 已提交
878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895

namespace std {

// Override the std::is_pod::value for float16
// The reason is that different compilers implemented std::is_pod based on
// different C++ standards. float16 class is a plain old data in C++11 given
// that it is both trivial and standard_layout.
// However, std::is_pod in nvcc 8.0 host c++ compiler follows C++0x and is
// more restricted in that you cannot provide any customized
// constructor in float16. Hence, we override is_pod here following C++11
// so that .cu files can be successfully compiled by nvcc.
//
// Deriving from integral_constant (as the primary template does) keeps the
// full trait interface, i.e. `type`, `value_type`, and the bool/call
// operators, and uses the library-defined `value`, so odr-using it
// (e.g. binding it to a const bool&) links.
// NOTE(review): user specializations of <type_traits> templates are formally
// undefined behavior, and std::is_pod is deprecated in C++20; this relies on
// the supported toolchains accepting it.
template <>
struct is_pod<paddle::platform::float16>
    : integral_constant<
          bool, is_trivial<paddle::platform::float16>::value &&
                    is_standard_layout<paddle::platform::float16>::value> {};

896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950
// std::numeric_limits for float16 (IEEE 754 binary16: 1 sign bit, 5 exponent
// bits with bias 15, 10 stored mantissa bits). Every special value is built
// from its exact bit pattern via raw_uint16_to_float16, so no conversion
// rounding is involved.
template <>
struct numeric_limits<paddle::platform::float16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  // NOTE(review): reported as false conservatively; revisit if float16
  // arithmetic is confirmed to meet full IEC 559 semantics.
  static const bool is_iec559 = false;
  // float16 has a finite set of representable values (largest finite value
  // is max() == 65504), so the type is bounded.
  static const bool is_bounded = true;
  static const bool is_modulo = false;
  // 11 significand bits: 10 stored mantissa bits plus the implicit leading 1.
  static const int digits = 11;
  static const int digits10 = 3;
  static const int max_digits10 = 5;
  static const int radix = 2;
  // Smallest normal value is 2^(min_exponent - 1) = 2^-14. Finite values
  // stay below 2^max_exponent = 2^16.
  static const int min_exponent = -13;
  static const int min_exponent10 = -4;
  static const int max_exponent = 16;
  static const int max_exponent10 = 4;
  static const bool traps = true;
  static const bool tinyness_before = false;

  // Smallest positive normal value: 2^-14 (exponent field 1, mantissa 0).
  static paddle::platform::float16(min)() {
    return paddle::platform::raw_uint16_to_float16(0x400);
  }
  // Most negative finite value: -65504.
  static paddle::platform::float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  // Largest finite value: 65504.
  static paddle::platform::float16(max)() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  // Machine epsilon: the gap between 1.0 (0x3c00) and the next representable
  // value (0x3c01), i.e. 2^-10. Encoded as exponent field 5 (5 - 15 = -10)
  // with a zero mantissa. (0x0800 would be 2^-13, eight times too small.)
  static paddle::platform::float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x1400);
  }
  // Maximum rounding error for round-to-nearest: half an ulp.
  static paddle::platform::float16 round_error() {
    return paddle::platform::float16(0.5f);
  }
  // +infinity: all-ones exponent, zero mantissa.
  static paddle::platform::float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  // Quiet NaN: all-ones exponent with the quiet bit (mantissa MSB, 0x0200) set.
  static paddle::platform::float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  // Signaling NaN: all-ones exponent, quiet bit clear, non-zero mantissa.
  // (0x7e00 has the quiet bit set and would be a quiet NaN.)
  static paddle::platform::float16 signaling_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7d00);
  }
  // Smallest positive subnormal value: 2^-24.
  static paddle::platform::float16 denorm_min() {
    return paddle::platform::raw_uint16_to_float16(0x1);
  }
};

K
kexinzhao 已提交
951
}  // namespace std
952 953

namespace Eigen {
954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983

using float16 = paddle::platform::float16;

// Eigen NumTraits for float16, so Eigen can treat it as a real, signed,
// non-integer scalar. Limits and special values come from exact bit patterns.
template <>
struct NumTraits<float16> : GenericNumTraits<float16> {
  enum {
    IsSigned = true,
    IsInteger = false,
    IsComplex = false,
    RequireInitialization = false
  };

  // Machine epsilon: 2^-10, the gap between 1.0 and the next float16.
  // 0x1400 encodes exponent field 5 (5 - 15 = -10) with a zero mantissa.
  HOSTDEVICE static inline float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x1400);
  }
  // Default tolerance for Eigen's fuzzy comparisons on float16.
  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
  // Largest finite value: 65504.
  HOSTDEVICE static inline float16 highest() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  // Most negative finite value: -65504.
  HOSTDEVICE static inline float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  // +infinity: all-ones exponent, zero mantissa.
  HOSTDEVICE static inline float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  // Quiet NaN: all-ones exponent with the quiet bit (mantissa MSB, 0x0200)
  // set. A pattern such as 0x7c01 has the quiet bit clear and therefore
  // encodes a signaling NaN.
  HOSTDEVICE static inline float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
};

984 985 986
namespace numext {

template <>
HOSTDEVICE inline bool(isnan)(const float16& a) {
  // Delegate to paddle::platform::isnan so Eigen and Paddle share one NaN test.
  const bool is_nan_value = (paddle::platform::isnan)(a);
  return is_nan_value;
}

template <>
HOSTDEVICE inline bool(isinf)(const float16& a) {
  // Delegate to paddle::platform::isinf so Eigen and Paddle share one test.
  const bool is_inf_value = (paddle::platform::isinf)(a);
  return is_inf_value;
}

template <>
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  // Delegate to paddle::platform::isfinite so Eigen and Paddle agree.
  const bool is_finite_value = (paddle::platform::isfinite)(a);
  return is_finite_value;
}

1001 1002 1003 1004 1005
template <>
HOSTDEVICE inline float16 exp(const float16& a) {
  // Evaluate in single precision, then convert the result back to float16.
  const float x = static_cast<float>(a);
  return float16(::expf(x));
}

1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045
template <>
HOSTDEVICE inline float16 log(const float16& a) {
  // Natural logarithm computed in single precision, converted back to float16.
  const float x = static_cast<float>(a);
  return float16(::logf(x));
}

template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
  // Hyperbolic tangent computed in single precision, converted to float16.
  const float x = static_cast<float>(a);
  return float16(::tanhf(x));
}

template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
  // Square root computed in single precision, converted back to float16.
  const float x = static_cast<float>(a);
  return float16(::sqrtf(x));
}

template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
  // Round toward +infinity in single precision, then convert to float16.
  const float x = static_cast<float>(a);
  return float16(::ceilf(x));
}

template <>
HOSTDEVICE inline float16 floor(const float16& a) {
  // Round toward -infinity in single precision, then convert to float16.
  const float x = static_cast<float>(a);
  return float16(::floorf(x));
}

template <>
HOSTDEVICE inline float16 round(const float16& a) {
  // roundf semantics (halfway cases away from zero), evaluated in float.
  const float x = static_cast<float>(a);
  return float16(::roundf(x));
}

template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
  // Raise a to the power b in single precision, then convert to float16.
  const float base = static_cast<float>(a);
  const float exponent = static_cast<float>(b);
  return float16(::powf(base, exponent));
}

template <>
HOSTDEVICE inline float16 abs(const float16& a) {
  // Absolute value computed with the single-precision fabsf. Unqualified
  // ::fabs may resolve to the double overload, forcing a float->double->float
  // round trip that is costly in device code.
  return float16(::fabsf(static_cast<float>(a)));
}

1046
}  // namespace numext
1047

1048
}  // namespace Eigen