/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stdint.h>
#include <limits>
#include <ostream>
#include <type_traits>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif  // PADDLE_WITH_CUDA

// Compiler versions encoded as (major * 10 + minor) so they can be compared
// numerically below; 0 means "not compiled by this compiler".
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

// The CUDA `half` type from cuda_fp16.h is used only when compiling with nvcc
// and CUDA_VERSION >= 7050 (CUDA 7.5).
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif

#if defined(__arm__) || defined(__aarch64__)
#define PADDLE_ARM
#endif

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#include <arm_neon.h>
#endif

// Native half-precision CPU arithmetic requires NEON, the PADDLE_ARM_FP16
// build flag, and GCC >= 6.2 or Clang >= 3.7 (see the version encoding above).
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 62 || PADDLE_CLANG_VER >= 37)
#define PADDLE_WITH_NATIVE_FP16
#endif

// <immintrin.h> supplies the F16C intrinsics (_cvtss_sh / _cvtsh_ss) used
// when __F16C__ is defined. That header only exists on x86 toolchains, so it
// is included for x86 targets only rather than for "anything that is not
// ARM" (which broke e.g. PowerPC and MSVC-on-ARM64 builds).
#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || \
    defined(_M_IX86)
#include <immintrin.h>
#endif  // x86

// Alignment attribute: GCC/Clang attribute syntax everywhere except _WIN32,
// where the __declspec(align(x)) form is used.
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) __declspec(align(x))
#endif

namespace paddle {
namespace platform {

// Forward declaration of float16, placed before the project/Eigen headers
// below are included (original note: needed for eigen.h).
struct float16;

}  // namespace platform
}  // namespace paddle

// NOTE(review): hostdevice.h presumably supplies the HOSTDEVICE / DEVICE
// qualifiers used throughout this file, and the Eigen Tensor header supplies
// Eigen::half — confirm.
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace platform {

// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
// with CUDA half, ARM float16_t, and Eigen::half data types.
//
// float16 keeps a half-precision value as its raw 16-bit pattern in `x`.
// Every converting constructor and conversion operator is explicit.
struct PADDLE_ALIGN(2) float16 {
 public:
  // Raw half-precision bit pattern.
  uint16_t x;

  // The following defaulted special class member functions
  // are added to make float16 pass the std::is_trivial test
  float16() = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o) = default;
  float16& operator=(float16&& o) = default;
  ~float16() = default;

// Constructors
#ifdef PADDLE_CUDA_FP16
  // Bit-copies a CUDA half. From CUDA 9.0 on the bits are read through
  // __half_raw; older toolkits expose them directly as h.x.
  HOSTDEVICE inline explicit float16(const half& h) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  // Bit-copies an Eigen::half.
  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  // Converts a float to half precision using the first available path:
  // __float2half in device code with __CUDA_ARCH__ >= 300, NEON vcvt when
  // native fp16 is enabled, the F16C _cvtss_sh intrinsic (rounding
  // immediate 0), or a portable bit-manipulation fallback.
  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);

#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v, s;
    v.f = val;
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;

#endif
  }

  // true -> 1.0 (bit pattern 0x3c00), false -> +0.0.
  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  // Any other type is first converted with static_cast<float>.
  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

// Assignment operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  // Numeric assignments reuse the matching explicit constructor.
  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

// Conversion operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  // Widens to float using the same backend selection as float16(float).
  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);

#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v;
    v.ui = this->x;
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
  }

  // False only for +0.0 / -0.0 (the sign bit is masked off); any other bit
  // pattern, NaN included, converts to true.
  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  // Integer and double conversions go through float.
  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(static_cast<float>(*this));
  }

 private:
  // Type-punning view used by the portable conversion routines above.
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  // Constants for the portable float <-> half routines.
  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
};

// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
//
// In device code with __CUDA_ARCH__ >= 530 the native half intrinsics are
// used. Otherwise each operand is widened to float through float16, the
// operation is done in float, and the result is narrowed back to half.
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000

DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  float res = static_cast<float>(float16(a)) + static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  float res = static_cast<float>(float16(a)) - static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  float res = static_cast<float>(float16(a)) * static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

// Division is always carried out in float; for __CUDA_ARCH__ >= 300 the
// widening/narrowing uses the __half2float / __float2half intrinsics.
DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  float num = __half2float(a);
  float denom = __half2float(b);
  return __float2half(num / denom);
#else
  float res = static_cast<float>(float16(a)) / static_cast<float>(float16(b));
  return half(float16(res));
#endif
}

DEVICE inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  float res = -static_cast<float>(float16(a));
  return half(float16(res));
#endif
}

// Compound assignments are expressed through the binary operators above.
DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
  a = a + b;
  return a;
}

DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
  a = a - b;
  return a;
}

DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
  a = a * b;
  return a;
}

DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
  a = a / b;
  return a;
}

DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
#endif
}

// Not-equal must be *unordered*: it is true when either operand is NaN, as
// in IEEE 754, the float fallback below, and CUDA 9.0's built-in
// operator!= for half (which uses __hneu). __hne is an ordered comparison
// that returns false for NaN inputs, so !__heq is used instead.
DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return !__heq(a, b);
#else
  return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
#endif
}

DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
#endif
}

#endif  // PADDLE_CUDA_FP16 && CUDA_VERSION < 9000

// Arithmetic and comparison operators for float16 when CUDA fp16 support is
// enabled (usable from both host and device code)
#if defined(PADDLE_CUDA_FP16)
// CUDA build: device code with __CUDA_ARCH__ >= 530 uses the half
// intrinsics; host code and older devices compute in float and round the
// result back to float16.
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hadd(half(a), half(b)));
#else
  return float16(static_cast<float>(a) + static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hsub(half(a), half(b)));
#else
  return float16(static_cast<float>(a) - static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hmul(half(a), half(b)));
#else
  return float16(static_cast<float>(a) * static_cast<float>(b));
#endif
}

HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
#else
  return float16(static_cast<float>(a) / static_cast<float>(b));
#endif
}

// Outside the intrinsic path, negation just flips the sign bit.
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hneg(half(a)));
#else
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
#endif
}

// Compound assignments are expressed through the binary operators above.
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(half(a), half(b));
#else
  return static_cast<float>(a) == static_cast<float>(b);
#endif
}

// Not-equal must be *unordered* (true if either operand is NaN) so device
// and host agree with IEEE 754 semantics. __hne is an ordered comparison
// that yields false for NaN inputs, so !__heq (equivalent to __hneu) is used.
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return !__heq(half(a), half(b));
#else
  return static_cast<float>(a) != static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(half(a), half(b));
#else
  return static_cast<float>(a) < static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(half(a), half(b));
#else
  return static_cast<float>(a) <= static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(half(a), half(b));
#else
  return static_cast<float>(a) > static_cast<float>(b);
#endif
}

HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(half(a), half(b));
#else
  return static_cast<float>(a) >= static_cast<float>(b);
#endif
}

// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
// ARMv8.2-A native fp16 path. Each operator loads the 16-bit operands into
// lane 0 of v0 / v1, runs the scalar half-precision instruction on h0 / h1,
// and stores lane 0 of v0 through res_ptr. The asm has no output operands;
// the result is written through memory, which the "memory" clobber declares.
inline float16 operator+(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator*(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator/(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

inline float16 operator-(const float16& a) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0");
  return res;
}

// Compound assignments are expressed through the binary operators above.
inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

// Comparisons store the 16-bit compare result from h0 into `res`; a non-zero
// value means the predicate holds.
inline bool operator==(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }

// `<` and `<=` load `a` into v1 and `b` into v0, so fcmgt / fcmge evaluate
// b > a and b >= a respectively.
inline bool operator<(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator<=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

inline bool operator>=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Arithmetic and comparison operators for float16, emulated in software
// (through float) on other CPUs
#else
// Portable fallback: each operation is computed in float and rounded back
// through float16(float); comparisons are plain float comparisons.
inline float16 operator+(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) + static_cast<float>(b));
}

inline float16 operator-(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) - static_cast<float>(b));
}

inline float16 operator*(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) * static_cast<float>(b));
}

inline float16 operator/(const float16& a, const float16& b) {
  return float16(static_cast<float>(a) / static_cast<float>(b));
}

// Negation flips the sign bit directly; no rounding is involved.
inline float16 operator-(const float16& a) {
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
}

// Compound assignments go through the binary operators above, like the
// CUDA and ARM backends (same float computation, same result).
inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  return static_cast<float>(a) == static_cast<float>(b);
}

inline bool operator!=(const float16& a, const float16& b) {
  return static_cast<float>(a) != static_cast<float>(b);
}

inline bool operator<(const float16& a, const float16& b) {
  return static_cast<float>(a) < static_cast<float>(b);
}

inline bool operator<=(const float16& a, const float16& b) {
  return static_cast<float>(a) <= static_cast<float>(b);
}

inline bool operator>(const float16& a, const float16& b) {
  return static_cast<float>(a) > static_cast<float>(b);
}

inline bool operator>=(const float16& a, const float16& b) {
  return static_cast<float>(a) >= static_cast<float>(b);
}
#endif
K
kexinzhao 已提交
857

858 859 860 861 862 863
// Wraps a raw half-precision bit pattern in a float16 without any numeric
// conversion: the 16 bits are stored verbatim in `x`.
HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 wrapped;
  wrapped.x = a;  // bit copy, not a value conversion
  return wrapped;
}

864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879
// Classification helpers on the raw bits: the exponent field is 0x7c00 and
// the mantissa field is 0x03ff. The names are parenthesized so that an
// isnan/isinf/isfinite function-like macro cannot expand here.

// NaN: all exponent bits set and a non-zero mantissa.
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hisnan(half(a));
#else
  const bool exponent_all_ones = (a.x & 0x7c00) == 0x7c00;
  const bool mantissa_nonzero = (a.x & 0x03ff) != 0;
  return exponent_all_ones && mantissa_nonzero;
#endif
}

// Infinity (either sign): all exponent bits set and a zero mantissa.
HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (a.x & 0x7c00) == 0x7c00 && (a.x & 0x03ff) == 0;
}

// Finite: neither NaN nor infinite.
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return !((isnan)(a) || (isinf)(a));
}

880 881 882 883 884
// Streams a float16 by printing its value widened to float.
inline std::ostream& operator<<(std::ostream& os, const float16& a) {
  const float widened = static_cast<float>(a);
  return os << widened;
}

K
kexinzhao 已提交
885
}  // namespace platform
K
Kexin Zhao 已提交
886
}  // namespace paddle
K
kexinzhao 已提交
887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904

namespace std {

// Override std::is_pod<float16>::value.
// Different compilers implement std::is_pod against different C++ standards.
// Under C++11 the float16 class is a plain-old-data (POD) type, because it is
// both trivial and standard-layout. However, std::is_pod in the nvcc 8.0 host
// C++ compiler follows C++0x, which is stricter: it rejects any user-provided
// constructor in float16. We therefore specialize is_pod with the C++11
// definition so that .cu files compile successfully under nvcc.
// Deriving from integral_constant, rather than declaring a bare
// `static const bool value`, keeps the standard trait interface (`value`,
// `type`, `value_type`, conversion to bool) available to generic code.
template <>
struct is_pod<paddle::platform::float16>
    : integral_constant<
          bool, is_trivial<paddle::platform::float16>::value &&
                    is_standard_layout<paddle::platform::float16>::value> {};

905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928
// Report float16 as a floating-point type. This specialization applies to
// the unqualified float16 type only, so the trait is unconditionally true.
template <>
struct is_floating_point<paddle::platform::float16> : std::true_type {};
// float16 is a signed type. Deriving from true_type, instead of declaring a
// bare `static const bool value`, gives the specialization the standard trait
// interface (constexpr `value`, `type`, `value_type`, conversion to bool).
template <>
struct is_signed<paddle::platform::float16> : true_type {};

// float16 is not an unsigned type. Deriving from false_type, instead of
// declaring a bare `static const bool value`, gives the specialization the
// standard trait interface (constexpr `value`, `type`, `value_type`,
// conversion to bool).
template <>
struct is_unsigned<paddle::platform::float16> : false_type {};

// Overloads so that std::isnan, std::isinf and std::isfinite accept float16,
// forwarding to the paddle::platform implementations. As elsewhere in this
// header, the function names are parenthesized both in the declarations and
// at the call sites, so a function-like isnan/isinf/isfinite macro (as some C
// math headers define) is never expanded here.
inline bool(isnan)(const paddle::platform::float16& a) {
  return (paddle::platform::isnan)(a);
}

inline bool(isinf)(const paddle::platform::float16& a) {
  return (paddle::platform::isinf)(a);
}

// Completes the classification set: paddle::platform::isfinite exists for
// float16, so std::isfinite should accept it too.
inline bool(isfinite)(const paddle::platform::float16& a) {
  return (paddle::platform::isfinite)(a);
}

929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983
// std::numeric_limits for float16, the IEEE 754 binary16 layout: 1 sign bit,
// 5 exponent bits (bias 15) and 10 stored mantissa bits.
template <>
struct numeric_limits<paddle::platform::float16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  static const bool is_iec559 = false;
  // float16 represents only a finite set of values, so it is bounded.
  static const bool is_bounded = true;
  static const bool is_modulo = false;
  static const int digits = 11;  // 10 stored mantissa bits + implicit bit
  static const int digits10 = 3;
  static const int max_digits10 = 5;
  static const int radix = 2;
  static const int min_exponent = -13;
  static const int min_exponent10 = -4;
  static const int max_exponent = 16;
  static const int max_exponent10 = 4;
  // NOTE(review): kept as-is; confirm whether float16 arithmetic can trap.
  static const bool traps = true;
  static const bool tinyness_before = false;

  // Smallest positive normal value, 2^-14.
  static paddle::platform::float16(min)() {
    return paddle::platform::raw_uint16_to_float16(0x400);
  }
  // Most negative finite value, -65504.
  static paddle::platform::float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  // Largest finite value, 65504.
  static paddle::platform::float16(max)() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  // Machine epsilon: the gap between 1.0 (0x3c00) and the next representable
  // value (0x3c01), i.e. 2^-10. Exponent field 5 (5 - 15 = -10) with a zero
  // mantissa gives 0x1400; 0x0800 would encode 2^-13.
  static paddle::platform::float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x1400);
  }
  // Maximum rounding error under round-to-nearest: 0.5 (bit pattern 0x3800).
  static paddle::platform::float16 round_error() {
    return paddle::platform::raw_uint16_to_float16(0x3800);
  }
  static paddle::platform::float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  // Quiet NaN: all-ones exponent with the most significant mantissa bit set.
  static paddle::platform::float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  // Signaling NaN: all-ones exponent, most significant mantissa bit clear,
  // and a non-zero mantissa. (0x7e00 would be a quiet NaN.)
  static paddle::platform::float16 signaling_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7d00);
  }
  // Smallest positive subnormal value, 2^-24.
  static paddle::platform::float16 denorm_min() {
    return paddle::platform::raw_uint16_to_float16(0x1);
  }
};

K
kexinzhao 已提交
984
}  // namespace std
985 986

namespace Eigen {
987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016

typedef paddle::platform::float16 float16;  // short name for the Eigen code below

// Eigen NumTraits for float16. Special values are built from raw binary16
// bit patterns via paddle::platform::raw_uint16_to_float16.
template <>
struct NumTraits<float16> : GenericNumTraits<float16> {
  enum {
    IsSigned = true,
    IsInteger = false,
    IsComplex = false,
    RequireInitialization = false
  };

  // Machine epsilon 2^-10 (0x1400): the gap between 1.0 (0x3c00) and the
  // next float16 (0x3c01). 0x0800 would encode 2^-13.
  HOSTDEVICE static inline float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x1400);
  }
  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
  // Largest finite value, 65504.
  HOSTDEVICE static inline float16 highest() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  // Most negative finite value, -65504.
  HOSTDEVICE static inline float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  HOSTDEVICE static inline float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  // Quiet NaN: all-ones exponent with the most significant mantissa bit set.
  // The previous 0x7c01 has that bit clear, which encodes a signaling NaN.
  HOSTDEVICE static inline float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
};

1017 1018 1019
namespace numext {

// Eigen's numext::isnan for float16 delegates to paddle::platform::isnan.
// The parenthesized name suppresses any function-like isnan macro.
template <>
HOSTDEVICE inline bool(isnan)(const float16& a) {
  return (::paddle::platform::isnan)(a);
}

// Eigen's numext::isinf for float16 delegates to paddle::platform::isinf.
// The parenthesized name suppresses any function-like isinf macro.
template <>
HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (::paddle::platform::isinf)(a);
}

// Eigen's numext::isfinite for float16 delegates to
// paddle::platform::isfinite. The parenthesized name suppresses any
// function-like isfinite macro.
template <>
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return (::paddle::platform::isfinite)(a);
}

1034 1035 1036
// float16 exp: evaluated in single precision with ::expf, then converted
// back to float16.
template <>
HOSTDEVICE inline float16 exp(const float16& a) {
  const float x = static_cast<float>(a);
  return float16(::expf(x));
}

// float16 erf: evaluated in single precision with ::erff, then converted
// back to float16.
template <>
HOSTDEVICE inline float16 erf(const float16& a) {
  const float x = static_cast<float>(a);
  return float16(::erff(x));
}

1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083
// float16 natural log: evaluated in single precision with ::logf, then
// converted back to float16.
template <>
HOSTDEVICE inline float16 log(const float16& a) {
  const float x = static_cast<float>(a);
  return float16(::logf(x));
}

// float16 tanh: evaluated in single precision with ::tanhf, then converted
// back to float16.
template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
  const float x = static_cast<float>(a);
  return float16(::tanhf(x));
}

// float16 square root: evaluated in single precision with ::sqrtf, then
// converted back to float16.
template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
  const float x = static_cast<float>(a);
  return float16(::sqrtf(x));
}

// float16 ceil: evaluated in single precision with ::ceilf, then converted
// back to float16.
template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
  const float x = static_cast<float>(a);
  return float16(::ceilf(x));
}

// float16 floor: evaluated in single precision with ::floorf, then converted
// back to float16.
template <>
HOSTDEVICE inline float16 floor(const float16& a) {
  const float x = static_cast<float>(a);
  return float16(::floorf(x));
}

// float16 round: evaluated in single precision with ::roundf, then converted
// back to float16.
template <>
HOSTDEVICE inline float16 round(const float16& a) {
  const float x = static_cast<float>(a);
  return float16(::roundf(x));
}

// float16 pow: both operands are widened to float, ::powf is evaluated in
// single precision, and the result is converted back to float16.
template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
  const float base = static_cast<float>(a);
  const float exponent = static_cast<float>(b);
  return float16(::powf(base, exponent));
}

// float16 absolute value, evaluated in single precision.
// Uses ::fabsf rather than ::fabs: ::fabs takes a double, which forces a
// float -> double -> float round trip. That round trip is exact here, so
// results are unchanged, but it is wasted work and inconsistent with the
// other float16 overloads, which all call the float (f-suffixed) functions.
template <>
HOSTDEVICE inline float16 abs(const float16& a) {
  return float16(::fabsf(static_cast<float>(a)));
}

1084
}  // namespace numext
1085

1086
}  // namespace Eigen