/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stdint.h>
#include <limits>
#include <ostream>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif  // PADDLE_WITH_CUDA

#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

#if defined(__CUDACC__) && CUDA_VERSION >= 7050
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif

#if defined(__arm__) || defined(__aarch64__)
#define PADDLE_ARM
#endif

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#include <arm_neon.h>
#endif

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 62 || PADDLE_CLANG_VER >= 37)
#define PADDLE_WITH_NATIVE_FP16
#endif

#ifndef PADDLE_ARM
#include <immintrin.h>
#endif  // PADDLE_ARM

#define PADDLE_ALIGN(x) __attribute__((aligned(x)))

namespace paddle {
namespace platform {

// Forward declare float16 for eigen.h
struct float16;

}  // namespace platform
}  // namespace paddle

// NOTE:
// Do not move the eigen.h header, otherwise the eigen_vector<bool> will fail.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace platform {

// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
// with CUDA half, ARM float16_t, and Eigen::half data types.
struct PADDLE_ALIGN(2) float16 {
 public:
  uint16_t x;

  // The following defaulted special member functions
  // are added to make float16 pass the std::is_trivial test
  float16() = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o) = default;
  float16& operator=(float16&& o) = default;
  ~float16() = default;

// Constructors
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit float16(const half& h) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half-precision data type for ARM CPUs;
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);

#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
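    // Summary of the branchless bit manipulation below: strip the sign,
    // rescale values below the half-precision normal range so they round
    // to subnormals, clamp values above the representable range to the
    // inf/NaN encodings, shift the exponent and mantissa down by 13 bits,
    // and re-attach the sign bit.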
    Bits v, s;
    v.f = val;
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;

#endif
  }

  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

// Assignment operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

// Conversion operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);

#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
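    // Branchless inverse of the float -> half routine above: strip the
    // sign, undo the exponent re-bias, shift the payload up by 13 bits,
    // rescale subnormal halves back into the float range, and re-attach
    // the sign bit.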
    Bits v;
    v.ui = this->x;
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
  }

  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(static_cast<float>(*this));
  }

 private:
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
};
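
// A minimal usage sketch (host-side, software-emulated path); float16 here
// refers to paddle::platform::float16:
//
//   float16 a(1.5f);                  // construct from float
//   float f = static_cast<float>(a);  // convert back to float
//   float16 b = a + float16(2.0f);    // uses the operators defined below
//   bool ok = (isfinite)(a);          // NaN/Inf helpers defined below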

// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
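// On devices with compute capability below 5.3, where these intrinsics are
// unavailable, the fallback paths below convert to float, compute in float,
// and convert the result back to half.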
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000

DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  float res = static_cast<float>(float16(a)) + static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  float res = static_cast<float>(float16(a)) - static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  float res = static_cast<float>(float16(a)) * static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  float num = __half2float(a);
  float denom = __half2float(b);
  return __float2half(num / denom);
#else
  float res = static_cast<float>(float16(a)) / static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  float res = -static_cast<float>(float16(a));
  return half(float16(res));
#endif
}

DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
  a = a + b;
  return a;
}

DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
  a = a - b;
  return a;
}

DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
  a = a * b;
  return a;
}

DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
  a = a / b;
  return a;
}

DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
#endif
}

#endif  // PADDLE_CUDA_FP16

// Arithmetic operators for float16 on GPU
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hadd(half(a), half(b)));
#else
  return float16(static_cast<float>(a) + static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hsub(half(a), half(b)));
#else
  return float16(static_cast<float>(a) - static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hmul(half(a), half(b)));
#else
  return float16(static_cast<float>(a) * static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
#else
  return float16(static_cast<float>(a) / static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hneg(half(a)));
#else
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
#endif
}

HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(half(a), half(b));
#else
  return static_cast<float>(a) == static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(half(a), half(b));
#else
  return static_cast<float>(a) != static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(half(a), half(b));
#else
  return static_cast<float>(a) < static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(half(a), half(b));
#else
  return static_cast<float>(a) <= static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(half(a), half(b));
#else
  return static_cast<float>(a) > static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(half(a), half(b));
#else
  return static_cast<float>(a) >= static_cast<float>(b);
#endif
}

// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
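// Each operator below uses inline assembly: it loads the raw 16-bit payloads
// into NEON registers, applies the corresponding scalar ARMv8.2-A
// half-precision instruction (fadd, fsub, fmul, fdiv, fneg, fcmeq, fcmgt,
// fcmge), and stores the result back.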
inline float16 operator+(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator*(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator/(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0");
  return res;
}

inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }

inline bool operator<(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator<=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Arithmetic operators for float16, software emulated on other CPU
#else
inline float16 operator+(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) + static_cast<float>(b));
}

inline float16 operator-(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) - static_cast<float>(b));
}

inline float16 operator*(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) * static_cast<float>(b));
}

inline float16 operator/(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) / static_cast<float>(b));
}

inline float16 operator-(const float16& a) {
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
}

inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) + static_cast<float>(b));
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) - static_cast<float>(b));
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) * static_cast<float>(b));
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) / static_cast<float>(b));
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  return static_cast<float>(a) == static_cast<float>(b);
}

inline bool operator!=(const float16& a, const float16& b) {
  return static_cast<float>(a) != static_cast<float>(b);
}

inline bool operator<(const float16& a, const float16& b) {
  return static_cast<float>(a) < static_cast<float>(b);
}

inline bool operator<=(const float16& a, const float16& b) {
  return static_cast<float>(a) <= static_cast<float>(b);
}

inline bool operator>(const float16& a, const float16& b) {
  return static_cast<float>(a) > static_cast<float>(b);
}

inline bool operator>=(const float16& a, const float16& b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
#endif

HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 res;
  res.x = a;
  return res;
}

HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hisnan(half(a));
#else
  return (a.x & 0x7fff) > 0x7c00;
#endif
}

HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (a.x & 0x7fff) == 0x7c00;
}

HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return !((isnan)(a)) && !((isinf)(a));
}

inline std::ostream& operator<<(std::ostream& os, const float16& a) {
  os << static_cast<float>(a);
  return os;
}

}  // namespace platform
}  // namespace paddle

namespace std {

// Override std::is_pod::value for float16.
// The reason is that different compilers implement std::is_pod based on
// different C++ standards. The float16 class is plain old data under C++11,
// given that it is both trivial and standard-layout.
// However, std::is_pod in the nvcc 8.0 host C++ compiler follows C++0x, which
// is more restrictive and does not allow any user-provided constructor in
// float16. Hence, we override is_pod here following C++11 so that .cu files
// can be successfully compiled by nvcc.
template <>
struct is_pod<paddle::platform::float16> {
  static const bool value =
      is_trivial<paddle::platform::float16>::value &&
      is_standard_layout<paddle::platform::float16>::value;
};

template <>
struct is_floating_point<paddle::platform::float16>
    : std::integral_constant<
          bool, std::is_same<paddle::platform::float16,
                             typename std::remove_cv<
                                 paddle::platform::float16>::type>::value> {};
template <>
struct is_signed<paddle::platform::float16> {
  static const bool value = true;
};

template <>
struct is_unsigned<paddle::platform::float16> {
  static const bool value = false;
};

inline bool isnan(const paddle::platform::float16& a) {
  return paddle::platform::isnan(a);
}

inline bool isinf(const paddle::platform::float16& a) {
  return paddle::platform::isinf(a);
}

template <>
struct numeric_limits<paddle::platform::float16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  static const bool is_iec559 = false;
  static const bool is_bounded = false;
  static const bool is_modulo = false;
  static const int digits = 11;
  static const int digits10 = 3;
  static const int max_digits10 = 5;
  static const int radix = 2;
  static const int min_exponent = -13;
  static const int min_exponent10 = -4;
  static const int max_exponent = 16;
  static const int max_exponent10 = 4;
  static const bool traps = true;
  static const bool tinyness_before = false;

  static paddle::platform::float16(min)() {
    return paddle::platform::raw_uint16_to_float16(0x400);
  }
  static paddle::platform::float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  static paddle::platform::float16(max)() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  static paddle::platform::float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  static paddle::platform::float16 round_error() {
    return paddle::platform::float16(0.5);
  }
  static paddle::platform::float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  static paddle::platform::float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  static paddle::platform::float16 signaling_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  static paddle::platform::float16 denorm_min() {
    return paddle::platform::raw_uint16_to_float16(0x1);
  }
};

}  // namespace std

namespace Eigen {

using float16 = paddle::platform::float16;

template <>
struct NumTraits<float16> : GenericNumTraits<float16> {
  enum {
    IsSigned = true,
    IsInteger = false,
    IsComplex = false,
    RequireInitialization = false
  };

  HOSTDEVICE static inline float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
  HOSTDEVICE static inline float16 highest() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  HOSTDEVICE static inline float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  HOSTDEVICE static inline float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  HOSTDEVICE static inline float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7c01);
  }
};

namespace numext {

template <>
HOSTDEVICE inline bool(isnan)(const float16& a) {
  return (paddle::platform::isnan)(a);
}

template <>
HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (paddle::platform::isinf)(a);
}

template <>
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return (paddle::platform::isfinite)(a);
}

template <>
HOSTDEVICE inline float16 exp(const float16& a) {
  return float16(::expf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 log(const float16& a) {
  return float16(::logf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
  return float16(::tanhf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
  return float16(::sqrtf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
  return float16(::ceilf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 floor(const float16& a) {
  return float16(::floorf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 round(const float16& a) {
  return float16(::roundf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
  return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}

template <>
HOSTDEVICE inline float16 abs(const float16& a) {
  return float16(::fabs(static_cast<float>(a)));
}

}  // namespace numext

}  // namespace Eigen