float16.h 30.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
K
Kexin Zhao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <stdint.h>
18
#include <limits>
K
Kexin Zhao 已提交
19

K
Kexin Zhao 已提交
20
#ifdef PADDLE_WITH_CUDA
K
Kexin Zhao 已提交
21
#include <cuda.h>
K
Kexin Zhao 已提交
22
#endif  // PADDLE_WITH_CUDA
Y
Y_Xuan 已提交
23 24 25 26
#ifdef PADDLE_WITH_HIP
#define CUDA_VERSION 10000
#include <hip/hip_runtime.h>
#endif
K
Kexin Zhao 已提交
27

K
Kexin Zhao 已提交
28 29 30 31 32 33 34 35 36 37 38 39
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

K
Kexin Zhao 已提交
40
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
K
Kexin Zhao 已提交
41 42
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
K
Kexin Zhao 已提交
43
#endif
Y
Y_Xuan 已提交
44 45 46 47
#ifdef __HIPCC__
#define PADDLE_CUDA_FP16
#include <hip/hip_fp16.h>
#endif
K
Kexin Zhao 已提交
48

D
dzhwinter 已提交
49
#if !defined(_WIN32)
K
Kexin Zhao 已提交
50
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
D
dzhwinter 已提交
51
#else
P
peizhilin 已提交
52
#define PADDLE_ALIGN(x) __declspec(align(x))
D
dzhwinter 已提交
53
#endif
K
Kexin Zhao 已提交
54

55 56
#define CUDA_ARCH_FP16_SUPPORTED(CUDA_ARCH) (CUDA_ARCH >= 600)

K
Kexin Zhao 已提交
57
namespace paddle {
K
kexinzhao 已提交
58
namespace platform {
K
Kexin Zhao 已提交
59

60 61 62 63 64 65 66
// Forward declare float16 for eigen.h
struct float16;

}  // namespace platform
}  // namespace paddle

#include "paddle/fluid/platform/hostdevice.h"
67
#include "unsupported/Eigen/CXX11/Tensor"
68 69 70 71

namespace paddle {
namespace platform {

K
Kexin Zhao 已提交
72 73 74
// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
// with CUDA half, ARM float16_t, and Eigen::half data types.
//
// The struct stores the raw IEEE 754 binary16 bit pattern in `x` and
// provides explicit conversions to/from half, Eigen::half, float16_t,
// float, double, bool and the fixed-width integer types.
struct PADDLE_ALIGN(2) float16 {
 public:
  // Raw IEEE 754 half-precision bit pattern (1 sign, 5 exponent, 10 mantissa).
  uint16_t x;

  // The following defaulted special class member functions
  // are added to make float16 pass the std::is_trivial test
  float16() = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o) = default;
  float16& operator=(float16&& o) = default;
  ~float16() = default;

// Constructors
#ifdef PADDLE_CUDA_FP16
  // Construct from CUDA/HIP half. On CUDA >= 9.0 `half` is opaque, so the
  // raw bits are extracted through __half_raw; older toolkits expose `.x`.
  HOSTDEVICE inline explicit float16(const half& h) {
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP))
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
#endif
  }
#endif  // PADDLE_CUDA_FP16

  // Eigen::half uses the same underlying bit layout, so copy bits directly.
  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  // Construct from a single-precision float, picking the fastest available
  // conversion path: CUDA/HIP intrinsic, ARM NEON, x86 F16C, or a portable
  // branch-free software routine.
  HOSTDEVICE inline explicit float16(float val) {
#if ((defined(PADDLE_CUDA_FP16)) &&                       \
     ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \
      (defined(__HIP_DEVICE_COMPILE__))))
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);

#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    // Branch-free: uses masked XOR selects (-(cond) is all-ones when cond is
    // true) to clamp to inf/nan and rescale subnormals without branching.
    Bits v, s;
    v.f = val;
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;

#endif
  }

  // true maps to half 1.0 (0x3c00), false to 0.0.
  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  // Generic constructor: any arithmetic type goes through float.
  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

// Assignment operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  // Numeric assignments all funnel through the float constructor above.
  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

// Conversion operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  // Convert to single-precision float, mirroring the constructor's
  // path selection: CUDA/HIP intrinsic, NEON, F16C, or software.
  HOSTDEVICE inline explicit operator float() const {
#if (defined(PADDLE_CUDA_FP16) &&                         \
     ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \
      (defined(__HIP_DEVICE_COMPILE__))))
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);

#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    // Inverse of the float->half routine above; the same masked-XOR select
    // idiom widens the exponent and rescales subnormals branch-free.
    Bits v;
    v.ui = this->x;
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
  }

  // Any nonzero magnitude (sign bit masked off) is true; note +0 and -0
  // both convert to false.
  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  // Integer and double conversions all truncate/round through float.
  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(static_cast<float>(*this));
  }

 private:
  // Type-punning helper for the software conversion routines.
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  // Difference between float (23) and half (10) mantissa widths, and the
  // shift that moves a float sign bit down to the half sign position.
  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
};

K
Kexin Zhao 已提交
365 366 367 368 369
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
// TODO(xuan): change for rocm
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000
// Each operator uses the native half intrinsic on devices with fp16
// arithmetic (SM 5.3+ or HIP device compilation); otherwise it round-trips
// through float via the float16 conversion helpers.
DEVICE inline half operator+(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hadd(a, b);
#else
  float res = static_cast<float>(float16(a)) + static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hsub(a, b);
#else
  float res = static_cast<float>(float16(a)) - static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator*(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hmul(a, b);
#else
  float res = static_cast<float>(float16(a)) * static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

// Division has no half intrinsic here; it is performed in float even on
// fp16-capable devices.
DEVICE inline half operator/(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  float num = __half2float(a);
  float denom = __half2float(b);
  return __float2half(num / denom);
#else
  float res = static_cast<float>(float16(a)) / static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hneg(a);
#else
  float res = -static_cast<float>(float16(a));
  return half(float16(res));
#endif
}

// Compound assignments are defined in terms of the binary operators above.
DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
  a = a + b;
  return a;
}

DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
  a = a - b;
  return a;
}

DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
  a = a * b;
  return a;
}

DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
  a = a / b;
  return a;
}

DEVICE inline bool operator==(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __heq(a, b);
#else
  return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator!=(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hne(a, b);
#else
  return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hlt(a, b);
#else
  return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<=(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hle(a, b);
#else
  return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hgt(a, b);
#else
  return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>=(const half& a, const half& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hge(a, b);
#else
  return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
#endif
}

#endif  // PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
499

500
// Arithmetic operators for float16 on GPU
// Three mutually exclusive implementations are selected at compile time:
//   1. CUDA/HIP builds (PADDLE_CUDA_FP16): half intrinsics on fp16-capable
//      devices, float fallback otherwise.
//   2. ARMv8.2-A CPUs with native fp16 (PADDLE_WITH_NATIVE_FP16): inline
//      NEON assembly on scalar half registers.
//   3. Everything else: software emulation through float.
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return float16(__hadd(half(a), half(b)));
#else
  return float16(static_cast<float>(a) + static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return float16(__hsub(half(a), half(b)));
#else
  return float16(static_cast<float>(a) - static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return float16(__hmul(half(a), half(b)));
#else
  return float16(static_cast<float>(a) * static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
#else
  return float16(static_cast<float>(a) / static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return float16(__hneg(half(a)));
#else
  // Negation is just a sign-bit flip on the raw representation.
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
#endif
}

// Compound assignments delegate to the binary operators above.
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __heq(half(a), half(b));
#else
  return static_cast<float>(a) == static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hne(half(a), half(b));
#else
  return static_cast<float>(a) != static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hlt(half(a), half(b));
#else
  return static_cast<float>(a) < static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hle(half(a), half(b));
#else
  return static_cast<float>(a) <= static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hgt(half(a), half(b));
#else
  return static_cast<float>(a) > static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     (defined(__HIP_DEVICE_COMPILE__)))
  return __hge(half(a), half(b));
#else
  return static_cast<float>(a) >= static_cast<float>(b);
#endif
}

// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
// Each operation loads the raw 16-bit values into NEON half registers,
// performs one scalar fp16 instruction, and stores the result back.
inline float16 operator+(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator*(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator/(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0");
  return res;
}

inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }

// NOTE: a is loaded into v1 and b into v0 on purpose, so that
// "fcmgt h0, h0, h1" computes b > a, i.e. a < b.
inline bool operator<(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Same swapped-load trick as operator<: computes b >= a, i.e. a <= b.
inline bool operator<=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Arithmetic operators for float16, software emulated on other CPU
#else
inline float16 operator+(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) + static_cast<float>(b));
}

inline float16 operator-(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) - static_cast<float>(b));
}

inline float16 operator*(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) * static_cast<float>(b));
}

inline float16 operator/(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) / static_cast<float>(b));
}

inline float16 operator-(const float16& a) {
  // Negation is just a sign-bit flip on the raw representation.
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
}

inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) + static_cast<float>(b));
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) - static_cast<float>(b));
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) * static_cast<float>(b));
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = float16(static_cast<float>(a) / static_cast<float>(b));
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  return static_cast<float>(a) == static_cast<float>(b);
}

inline bool operator!=(const float16& a, const float16& b) {
  return static_cast<float>(a) != static_cast<float>(b);
}

inline bool operator<(const float16& a, const float16& b) {
  return static_cast<float>(a) < static_cast<float>(b);
}

inline bool operator<=(const float16& a, const float16& b) {
  return static_cast<float>(a) <= static_cast<float>(b);
}

inline bool operator>(const float16& a, const float16& b) {
  return static_cast<float>(a) > static_cast<float>(b);
}

inline bool operator>=(const float16& a, const float16& b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
#endif
K
kexinzhao 已提交
877

878 879 880 881 882 883
HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 res;
  res.x = a;
  return res;
}

// True iff a is a NaN. The name is parenthesized so a function-like
// macro `isnan` from <math.h> cannot expand here.
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if (defined(PADDLE_CUDA_FP16) &&                         \
     ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
      (defined(__HIP_DEVICE_COMPILE__))))
  // Device path: hardware half-precision classification (requires SM53+ on
  // CUDA; always available under HIP device compilation).
  return __hisnan(half(a));
#else
  // Host path: mask off the sign bit; any NaN payload compares strictly
  // greater than the infinity pattern 0x7c00.
  return (a.x & 0x7fff) > 0x7c00;
#endif
}

// True iff a is +inf or -inf: exponent all ones with a zero mantissa.
HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (a.x & 0x7fff) == 0x7c00;
}

// True iff a is neither NaN nor an infinity.
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return !((isnan)(a)) && !((isinf)(a));
}

// Absolute value, computed in single precision and rounded back to
// float16. Parenthesized name guards against an `abs` macro.
HOSTDEVICE inline float16(abs)(const float16& a) {
#if (defined(PADDLE_CUDA_FP16) &&                         \
     ((defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
      (defined(__HIP_DEVICE_COMPILE__))))
  // Device path: ::fabs resolves to the CUDA/HIP math library.
  return float16(::fabs(static_cast<float>(a)));
#else
  // Host path: std::abs float overload from <cmath>.
  return float16(std::abs(static_cast<float>(a)));
#endif
}

912 913 914 915 916
inline std::ostream& operator<<(std::ostream& os, const float16& a) {
  os << static_cast<float>(a);
  return os;
}

K
kexinzhao 已提交
917
}  // namespace platform
K
Kexin Zhao 已提交
918
}  // namespace paddle
K
kexinzhao 已提交
919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936

namespace std {

// Override std::is_pod::value for float16.
// Different compilers implement std::is_pod against different C++
// standards. float16 is plain old data under C++11, given that it is both
// trivial and standard-layout. However, std::is_pod in the nvcc 8.0 host
// C++ compiler follows C++0x, which is more restrictive: a type with any
// user-provided constructor does not qualify. We therefore override is_pod
// here following the C++11 definition so that .cu files compile with nvcc.
template <>
struct is_pod<paddle::platform::float16> {
  static const bool value =
      is_trivial<paddle::platform::float16>::value &&
      is_standard_layout<paddle::platform::float16>::value;
};

937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960
// Mark float16 as a floating-point type so generic numeric code treats it
// like float/double. Note the condition — is_same<float16,
// remove_cv<float16>::type> — is a tautology and always evaluates to true;
// the trait is unconditionally true for float16.
template <>
struct is_floating_point<paddle::platform::float16>
    : std::integral_constant<
          bool, std::is_same<paddle::platform::float16,
                             typename std::remove_cv<
                                 paddle::platform::float16>::type>::value> {};
// float16 carries an IEEE sign bit, so it is a signed type.
template <>
struct is_signed<paddle::platform::float16> {
  static const bool value = true;
};

template <>
struct is_unsigned<paddle::platform::float16> {
  static const bool value = false;
};

// std:: overloads so qualified std::isnan/std::isinf calls on float16
// dispatch to the bit-level implementations above instead of converting
// through double.
inline bool isnan(const paddle::platform::float16& a) {
  return paddle::platform::isnan(a);
}

inline bool isinf(const paddle::platform::float16& a) {
  return paddle::platform::isinf(a);
}

961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986
template <>
struct numeric_limits<paddle::platform::float16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  static const bool is_iec559 = false;
  static const bool is_bounded = false;
  static const bool is_modulo = false;
  static const int digits = 11;
  static const int digits10 = 3;
  static const int max_digits10 = 5;
  static const int radix = 2;
  static const int min_exponent = -13;
  static const int min_exponent10 = -4;
  static const int max_exponent = 16;
  static const int max_exponent10 = 4;
  static const bool traps = true;
  static const bool tinyness_before = false;

Y
Y_Xuan 已提交
987
  HOSTDEVICE static paddle::platform::float16(min)() {
988 989
    return paddle::platform::raw_uint16_to_float16(0x400);
  }
Y
Y_Xuan 已提交
990
  HOSTDEVICE static paddle::platform::float16 lowest() {
991 992
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
Y
Y_Xuan 已提交
993
  HOSTDEVICE static paddle::platform::float16(max)() {
994 995
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
Y
Y_Xuan 已提交
996
  HOSTDEVICE static paddle::platform::float16 epsilon() {
997 998
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
Y
Y_Xuan 已提交
999
  HOSTDEVICE static paddle::platform::float16 round_error() {
1000 1001
    return paddle::platform::float16(0.5);
  }
Y
Y_Xuan 已提交
1002
  HOSTDEVICE static paddle::platform::float16 infinity() {
1003 1004
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
Y
Y_Xuan 已提交
1005
  HOSTDEVICE static paddle::platform::float16 quiet_NaN() {
1006 1007
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
Y
Y_Xuan 已提交
1008
  HOSTDEVICE static paddle::platform::float16 signaling_NaN() {
1009 1010
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
Y
Y_Xuan 已提交
1011
  HOSTDEVICE static paddle::platform::float16 denorm_min() {
1012 1013 1014 1015
    return paddle::platform::raw_uint16_to_float16(0x1);
  }
};

K
kexinzhao 已提交
1016
}  // namespace std
1017 1018

namespace Eigen {
1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048

using float16 = paddle::platform::float16;

// Eigen numeric traits for float16. Bit patterns are produced with
// raw_uint16_to_float16, so no float round-trip is involved.
template <>
struct NumTraits<float16> : GenericNumTraits<float16> {
  enum {
    IsSigned = true,
    IsInteger = false,
    IsComplex = false,
    RequireInitialization = false
  };

  // 2^-13, matching Eigen::half's epsilon choice.
  HOSTDEVICE static inline float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
  // Largest finite value: 65504.
  HOSTDEVICE static inline float16 highest() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  // Most negative finite value: -65504.
  HOSTDEVICE static inline float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  HOSTDEVICE static inline float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  // NOTE(review): 0x7c01 has the quiet bit (0x0200) clear, which is a
  // signaling-NaN pattern on most platforms — confirm this is intentional.
  HOSTDEVICE static inline float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7c01);
  }
};

1049 1050 1051
namespace numext {

// Eigen::numext classification hooks: delegate to the bit-level
// implementations in paddle::platform. The parentheses around the callee
// names keep any <math.h> function-like macros from expanding.
template <>
HOSTDEVICE inline bool(isnan)(const float16& a) {
  return (paddle::platform::isnan)(a);
}

template <>
HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (paddle::platform::isinf)(a);
}

template <>
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return (paddle::platform::isfinite)(a);
}

// e^a, computed in single precision and rounded back to float16.
template <>
HOSTDEVICE inline float16 exp(const float16& a) {
  const float v = static_cast<float>(a);
  return float16(::expf(v));
}

// Gauss error function erf(a), computed in single precision.
template <>
HOSTDEVICE inline float16 erf(const float16& a) {
  const float v = static_cast<float>(a);
  return float16(::erff(v));
}

1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115
// Eigen::numext elementwise math specializations: each promotes to float,
// applies the single-precision libm routine, and rounds back to float16.

template <>
HOSTDEVICE inline float16 log(const float16& a) {
  return float16(::logf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
  return float16(::tanhf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
  return float16(::sqrtf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
  return float16(::ceilf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 floor(const float16& a) {
  return float16(::floorf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 round(const float16& a) {
  return float16(::roundf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
  return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}

template <>
HOSTDEVICE inline float16 abs(const float16& a) {
  // Fix: use the float routine ::fabsf, consistent with the other
  // specializations; ::fabs computed in double and then narrowed. The
  // result is value-identical (|x| is exact), but this avoids a pointless
  // double-precision operation on device.
  return float16(::fabsf(static_cast<float>(a)));
}

1116
}  // namespace numext
1117

1118
}  // namespace Eigen