/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stdint.h>

#include <cmath>
#include <limits>
#include <ostream>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif  // PADDLE_WITH_CUDA

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif

#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

#if defined(__CUDACC__) && CUDA_VERSION >= 7050
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif

#ifdef __HIPCC__
#define PADDLE_CUDA_FP16
#include <hip/hip_fp16.h>
#endif

#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) __declspec(align(x))
#endif

#define CUDA_ARCH_FP16_SUPPORTED(CUDA_ARCH) (CUDA_ARCH >= 600)

namespace paddle {
namespace platform {

// Forward declare float16 for eigen.h
struct float16;

}  // namespace platform
}  // namespace paddle

#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace platform {

// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
// with CUDA half, ARM float16_t, and Eigen::half data types.
struct PADDLE_ALIGN(2) float16 {
 public:
  uint16_t x;

  // The following defaulted special class member functions
  // are added to make float16 pass the std::is_trivial test
  float16() = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o) = default;
  float16& operator=(float16&& o) = default;
  ~float16() = default;

// Constructors
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit float16(const half& h) {
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP))
#if defined(PADDLE_WITH_HIP) || CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
#endif
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && \
    (defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300))
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);

#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
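    // (Added explanatory note.) The branch-free idea: rescale by mulN so
    // that subnormal halves survive the shift, clamp out-of-range magnitudes
    // to the max/inf/NaN encodings with masked XORs, shift the 32-bit
    // pattern right by 13 bits (8-bit -> 5-bit exponent, 23-bit -> 10-bit
    // mantissa), and re-attach the sign bit at the end.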
    Bits v, s;
    v.f = val;
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;

#endif
  }
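
  // Usage note (an illustrative comment added here, not original
  // documentation):
  //
  //   float16 h(3.14159f);              // fp32 -> fp16, round to nearest
  //   float f = static_cast<float>(h);  // 3.140625f, precision is lost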

  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

// Assignment operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if defined(PADDLE_WITH_HIP) || CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

// Conversion operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit operator half() const {
#if defined(PADDLE_WITH_HIP) || CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && \
    (defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300))
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);

#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
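    // (Added explanatory note.) Inverse of the fp32 -> fp16 routine above:
    // adjust for the exponent-bias difference, rescale subnormals through
    // the mulC multiply, shift the pattern left by 13 bits, and re-attach
    // the sign bit.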
    Bits v;
    v.ui = this->x;
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
  }

  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(static_cast<float>(*this));
  }

 private:
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
};
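
// A small compile-time sanity check (an added sketch, not part of the
// original header): the defaulted special members above keep float16
// trivial, and PADDLE_ALIGN(2) keeps the struct exactly two bytes wide,
// matching CUDA half, ARM float16_t, and Eigen::half.
static_assert(sizeof(float16) == 2, "float16 must be 2 bytes wide");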

// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
// xuan[TODO] change for rocm
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000
DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hadd(a, b);
#else
  float res = static_cast<float>(float16(a)) + static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hsub(a, b);
#else
  float res = static_cast<float>(float16(a)) - static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hmul(a, b);
#else
  float res = static_cast<float>(float16(a)) * static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  float num = __half2float(a);
  float denom = __half2float(b);
  return __float2half(num / denom);
#else
  float res = static_cast<float>(float16(a)) / static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hneg(a);
#else
  float res = -static_cast<float>(float16(a));
  return half(float16(res));
#endif
}

#ifndef PADDLE_WITH_HIP  // not defined __HIP_NO_HALF_OPERATORS__
DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
  a = a + b;
  return a;
}

DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
  a = a - b;
  return a;
}

DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
  a = a * b;
  return a;
}

DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
  a = a / b;
  return a;
}
#endif

DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __heq(a, b);
#else
  return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hne(a, b);
#else
  return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hlt(a, b);
#else
  return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hle(a, b);
#else
  return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hgt(a, b);
#else
  return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hge(a, b);
#else
  return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
#endif
}

#endif  // PADDLE_CUDA_FP16

// Arithmetic operators for float16 on GPU
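// Usage sketch (an illustrative comment; the kernel below is hypothetical
// and not part of this header):
//
//   __global__ void Saxpy(float16* y, const float16* x, float16 a, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) y[i] = a * x[i] + y[i];
//   }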
#if defined(PADDLE_CUDA_FP16)

// HIPCC has compile error if call __device__ function __hadd in __host__
// __device__ function
#if defined(__HIPCC__)
DEVICE inline float16 operator+(const float16& a, const float16& b) {
  return float16(__hadd(half(a), half(b)));
}
HOST inline float16 operator+(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) + static_cast<float>(b));
}
#else
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hadd(half(a), half(b)));
#else
  return float16(static_cast<float>(a) + static_cast<float>(b));
#endif
}
#endif

// HIPCC has compile error if call __device__ function __hsub in __host__
// __device__ function
#if defined(__HIPCC__)
DEVICE inline float16 operator-(const float16& a, const float16& b) {
  return float16(__hsub(half(a), half(b)));
}
HOST inline float16 operator-(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) - static_cast<float>(b));
}
#else
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hsub(half(a), half(b)));
#else
  return float16(static_cast<float>(a) - static_cast<float>(b));
#endif
}
#endif

// HIPCC has compile error if call __device__ function __hmul in __host__
// __device__ function
#if defined(__HIPCC__)
DEVICE inline float16 operator*(const float16& a, const float16& b) {
  return float16(__hmul(half(a), half(b)));
}
HOST inline float16 operator*(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) * static_cast<float>(b));
}
#else
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hmul(half(a), half(b)));
#else
  return float16(static_cast<float>(a) * static_cast<float>(b));
#endif
}
#endif

HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
#else
  return float16(static_cast<float>(a) / static_cast<float>(b));
#endif
}

// HIPCC has compile error if call __device__ function __hneg in __host__
// __device__ function
#if defined(__HIPCC__)
DEVICE inline float16 operator-(const float16& a) {
  return float16(__hneg(half(a)));
}
HOST inline float16 operator-(const float16& a) {
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
}
#else
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hneg(half(a)));
#else
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
#endif
}
#endif

HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

// HIPCC has compile error if call __device__ function __heq in __host__
// __device__ function
#if defined(__HIPCC__)
DEVICE inline bool operator==(const float16& a, const float16& b) {
  return __heq(half(a), half(b));
}
HOST inline bool operator==(const float16& a, const float16& b) {
  return static_cast<float>(a) == static_cast<float>(b);
}
#else  // CUDA
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(half(a), half(b));
#else
  return static_cast<float>(a) == static_cast<float>(b);
#endif
}
#endif

HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hne(half(a), half(b));
#else
  return static_cast<float>(a) != static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hlt(half(a), half(b));
#else
  return static_cast<float>(a) < static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hle(half(a), half(b));
#else
  return static_cast<float>(a) <= static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hgt(half(a), half(b));
#else
  return static_cast<float>(a) > static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return __hge(half(a), half(b));
#else
  return static_cast<float>(a) >= static_cast<float>(b);
#endif
}

// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
inline float16 operator+(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator*(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator/(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0");
  return res;
}

inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }

inline bool operator<(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator<=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Arithmetic operators for float16, software emulated on other CPU
#else
inline float16 operator+(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) + static_cast<float>(b));
}

inline float16 operator-(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) - static_cast<float>(b));
}

inline float16 operator*(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) * static_cast<float>(b));
}

inline float16 operator/(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) / static_cast<float>(b));
}

inline float16 operator-(const float16& a) {
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
}

inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) + static_cast<float>(b));
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) - static_cast<float>(b));
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) * static_cast<float>(b));
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) / static_cast<float>(b));
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  return static_cast<float>(a) == static_cast<float>(b);
}

inline bool operator!=(const float16& a, const float16& b) {
  return static_cast<float>(a) != static_cast<float>(b);
}

inline bool operator<(const float16& a, const float16& b) {
  return static_cast<float>(a) < static_cast<float>(b);
}

inline bool operator<=(const float16& a, const float16& b) {
  return static_cast<float>(a) <= static_cast<float>(b);
}

inline bool operator>(const float16& a, const float16& b) {
  return static_cast<float>(a) > static_cast<float>(b);
}

inline bool operator>=(const float16& a, const float16& b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
#endif

HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 res;
  res.x = a;
  return res;
}
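
// For illustration (an added comment): raw_uint16_to_float16 reinterprets an
// IEEE 754 binary16 bit pattern directly, e.g.
//   raw_uint16_to_float16(0x3c00)  // 1.0
//   raw_uint16_to_float16(0x7bff)  // 65504, the largest finite value
//   raw_uint16_to_float16(0x7c00)  // +infinity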

// HIPCC has compile error if call __device__ function __hisnan in __host__
// __device__ function
#if defined(PADDLE_CUDA_FP16) && defined(__HIPCC__)
DEVICE inline bool(isnan)(const float16& a) { return __hisnan(half(a)); }
HOST inline bool(isnan)(const float16& a) { return (a.x & 0x7fff) > 0x7c00; }
#else
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hisnan(half(a));
#else
  return (a.x & 0x7fff) > 0x7c00;
#endif
}
#endif

HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (a.x & 0x7fff) == 0x7c00;
}

HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return !((isnan)(a)) && !((isinf)(a));
}
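
// Note (added for clarity): binary16 has 1 sign, 5 exponent, and 10 mantissa
// bits. Masking off the sign with 0x7fff isolates the all-ones exponent
// field 0x7c00: a masked value equal to 0x7c00 is infinity (zero mantissa),
// and one greater than 0x7c00 is NaN (nonzero mantissa), which is exactly
// what the checks above test.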

HOSTDEVICE inline float16(abs)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && \
    (defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530))
  return float16(::fabs(static_cast<float>(a)));
#else
  return float16(std::abs(static_cast<float>(a)));
#endif
}

inline std::ostream& operator<<(std::ostream& os, const float16& a) {
  os << static_cast<float>(a);
  return os;
}

}  // namespace platform
}  // namespace paddle

namespace std {

// Override the std::is_pod::value for float16
// The reason is that different compilers implemented std::is_pod based on
// different C++ standards. The float16 class is plain old data under C++11,
// given that it is both trivial and standard-layout.
// However, std::is_pod in the nvcc 8.0 host C++ compiler follows C++0x and
// is more restrictive: it does not allow float16 to provide any customized
// constructor. Hence, we override is_pod here following C++11 so that .cu
// files can be successfully compiled by nvcc.
template <>
struct is_pod<paddle::platform::float16> {
  static const bool value =
      is_trivial<paddle::platform::float16>::value &&
      is_standard_layout<paddle::platform::float16>::value;
};

template <>
struct is_floating_point<paddle::platform::float16>
    : std::integral_constant<
          bool, std::is_same<paddle::platform::float16,
                             typename std::remove_cv<
                                 paddle::platform::float16>::type>::value> {};
template <>
struct is_signed<paddle::platform::float16> {
  static const bool value = true;
};

template <>
struct is_unsigned<paddle::platform::float16> {
  static const bool value = false;
};

inline bool isnan(const paddle::platform::float16& a) {
  return paddle::platform::isnan(a);
}

inline bool isinf(const paddle::platform::float16& a) {
  return paddle::platform::isinf(a);
}

template <>
struct numeric_limits<paddle::platform::float16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  static const bool is_iec559 = false;
  static const bool is_bounded = false;
  static const bool is_modulo = false;
  static const int digits = 11;
  static const int digits10 = 3;
  static const int max_digits10 = 5;
  static const int radix = 2;
  static const int min_exponent = -13;
  static const int min_exponent10 = -4;
  static const int max_exponent = 16;
  static const int max_exponent10 = 4;
  static const bool traps = true;
  static const bool tinyness_before = false;

  HOSTDEVICE static paddle::platform::float16(min)() {
    return paddle::platform::raw_uint16_to_float16(0x400);
  }
  HOSTDEVICE static paddle::platform::float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  HOSTDEVICE static paddle::platform::float16(max)() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  HOSTDEVICE static paddle::platform::float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  HOSTDEVICE static paddle::platform::float16 round_error() {
    return paddle::platform::float16(0.5);
  }
  HOSTDEVICE static paddle::platform::float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  HOSTDEVICE static paddle::platform::float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  HOSTDEVICE static paddle::platform::float16 signaling_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  HOSTDEVICE static paddle::platform::float16 denorm_min() {
    return paddle::platform::raw_uint16_to_float16(0x1);
  }
};
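
// For reference (an added note): these bit patterns follow IEEE 754
// binary16, e.g. (max)() == 0x7bff == 65504, (min)() == 0x400 == 2^-14
// (the smallest positive normal), and denorm_min() == 0x1 == 2^-24.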

HOSTDEVICE inline paddle::platform::float16 abs(
    const paddle::platform::float16& a) {
  return paddle::platform::abs(a);
}

}  // namespace std

namespace Eigen {

using float16 = paddle::platform::float16;

template <>
struct NumTraits<float16> : GenericNumTraits<float16> {
  enum {
    IsSigned = true,
    IsInteger = false,
    IsComplex = false,
    RequireInitialization = false
  };

  HOSTDEVICE static inline float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
  HOSTDEVICE static inline float16 highest() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  HOSTDEVICE static inline float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  HOSTDEVICE static inline float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  HOSTDEVICE static inline float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7c01);
  }
};

namespace numext {

template <>
HOSTDEVICE inline bool(isnan)(const float16& a) {
  return (paddle::platform::isnan)(a);
}

template <>
HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (paddle::platform::isinf)(a);
}

template <>
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return (paddle::platform::isfinite)(a);
}

template <>
HOSTDEVICE inline float16 exp(const float16& a) {
  return float16(::expf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 erf(const float16& a) {
  return float16(::erff(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 log(const float16& a) {
  return float16(::logf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
  return float16(::tanhf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
  return float16(::sqrtf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
  return float16(::ceilf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 floor(const float16& a) {
  return float16(::floorf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 round(const float16& a) {
  return float16(::roundf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
  return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}

template <>
HOSTDEVICE inline float16 abs(const float16& a) {
  return float16(::fabs(static_cast<float>(a)));
}

}  // namespace numext

}  // namespace Eigen