float16.h 28.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <stdint.h>
18
#include <limits>
K
Kexin Zhao 已提交
19

K
Kexin Zhao 已提交
20
#ifdef PADDLE_WITH_CUDA
K
Kexin Zhao 已提交
21
#include <cuda.h>
K
Kexin Zhao 已提交
22 23
#endif  // PADDLE_WITH_CUDA

K
Kexin Zhao 已提交
24 25 26 27 28 29 30 31 32 33 34 35
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

K
Kexin Zhao 已提交
36
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
K
Kexin Zhao 已提交
37 38
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
K
Kexin Zhao 已提交
39 40
#endif

D
dzhwinter 已提交
41
#if !defined(_WIN32)
K
Kexin Zhao 已提交
42
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
D
dzhwinter 已提交
43
#else
P
peizhilin 已提交
44
#define PADDLE_ALIGN(x) __declspec(align(x))
D
dzhwinter 已提交
45
#endif
K
Kexin Zhao 已提交
46 47

namespace paddle {
namespace platform {

// Forward declaration of float16 so that headers such as eigen.h can name the
// type before its full definition later in this file.
struct float16;

}  // namespace platform
}  // namespace paddle

#include "paddle/fluid/platform/hostdevice.h"
57
#include "unsupported/Eigen/CXX11/Tensor"
58 59 60 61

namespace paddle {
namespace platform {

K
Kexin Zhao 已提交
62 63 64
// Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
K
Kexin Zhao 已提交
65
// with CUDA half, ARM float16_t, and Eigen::half data types.
K
Kexin Zhao 已提交
66
struct PADDLE_ALIGN(2) float16 {
67
 public:
K
Kexin Zhao 已提交
68
  uint16_t x;
K
Kexin Zhao 已提交
69

K
kexinzhao 已提交
70 71
  // The following defaulted special class member functions
  // are added to make float16 pass the std::is_trivial test
72 73 74 75 76 77
  float16() = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o) = default;
  float16& operator=(float16&& o) = default;
  ~float16() = default;
K
kexinzhao 已提交
78 79

// Constructors
K
Kexin Zhao 已提交
80
#ifdef PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
81
  HOSTDEVICE inline explicit float16(const half& h) {
K
Kexin Zhao 已提交
82
#if CUDA_VERSION >= 9000
Y
Yu Yang 已提交
83
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
K
Kexin Zhao 已提交
84 85 86 87 88 89
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

K
Kexin Zhao 已提交
90
  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}
K
Kexin Zhao 已提交
91

K
Kexin Zhao 已提交
92
#ifdef PADDLE_WITH_NATIVE_FP16
K
Kexin Zhao 已提交
93
  // __fp16 is a native half precision data type for arm cpu,
94
  // float16_t is an alias for __fp16
K
Kexin Zhao 已提交
95 96
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
K
Kexin Zhao 已提交
97 98 99
  }
#endif

K
Kexin Zhao 已提交
100 101 102 103
  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);
K
Kexin Zhao 已提交
104

105
#elif defined(PADDLE_WITH_NATIVE_FP16)
K
Kexin Zhao 已提交
106 107 108
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);
K
Kexin Zhao 已提交
109

K
Kexin Zhao 已提交
110 111
#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);
K
Kexin Zhao 已提交
112

K
Kexin Zhao 已提交
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v, s;
    v.f = val;
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;
K
Kexin Zhao 已提交
130

K
Kexin Zhao 已提交
131
#endif
K
Kexin Zhao 已提交
132 133
  }

K
Kexin Zhao 已提交
134
  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}
K
Kexin Zhao 已提交
135

K
Kexin Zhao 已提交
136 137 138
  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}
K
Kexin Zhao 已提交
139

140
// Assignment operators
K
Kexin Zhao 已提交
141
#ifdef PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
142
  HOSTDEVICE inline float16& operator=(const half& rhs) {
K
Kexin Zhao 已提交
143
#if CUDA_VERSION >= 9000
Y
Yu Yang 已提交
144
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
K
Kexin Zhao 已提交
145 146 147 148 149 150 151
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

K
Kexin Zhao 已提交
152
  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
K
Kexin Zhao 已提交
153 154 155 156
    x = rhs.x;
    return *this;
  }

K
Kexin Zhao 已提交
157 158 159
#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
K
Kexin Zhao 已提交
160 161 162 163
    return *this;
  }
#endif

K
Kexin Zhao 已提交
164
  HOSTDEVICE inline float16& operator=(bool b) {
K
Kexin Zhao 已提交
165 166 167 168
    x = b ? 0x3c00 : 0;
    return *this;
  }

K
Kexin Zhao 已提交
169 170
  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
171
    return *this;
K
Kexin Zhao 已提交
172 173
  }

K
Kexin Zhao 已提交
174 175
  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
176 177 178
    return *this;
  }

K
Kexin Zhao 已提交
179 180
  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
181 182 183
    return *this;
  }

K
Kexin Zhao 已提交
184 185
  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
186 187 188
    return *this;
  }

K
Kexin Zhao 已提交
189 190
  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
191 192 193
    return *this;
  }

K
Kexin Zhao 已提交
194 195
  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
196 197 198
    return *this;
  }

K
Kexin Zhao 已提交
199 200
  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
201 202 203
    return *this;
  }

K
Kexin Zhao 已提交
204 205
  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
206 207 208
    return *this;
  }

K
Kexin Zhao 已提交
209 210
  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
211 212 213
    return *this;
  }

K
Kexin Zhao 已提交
214 215
  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
216
    return *this;
K
Kexin Zhao 已提交
217
  }
K
Kexin Zhao 已提交
218

219
// Conversion opertors
K
Kexin Zhao 已提交
220
#ifdef PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
221
  HOSTDEVICE inline explicit operator half() const {
K
Kexin Zhao 已提交
222 223 224 225 226 227 228 229 230 231 232
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
233

K
Kexin Zhao 已提交
234
  HOSTDEVICE inline explicit operator Eigen::half() const {
K
Kexin Zhao 已提交
235 236 237 238 239
    Eigen::half h;
    h.x = x;
    return h;
  }

K
Kexin Zhao 已提交
240 241 242
#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
K
Kexin Zhao 已提交
243 244 245
  }
#endif

K
Kexin Zhao 已提交
246 247 248 249 250
  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

251
#elif defined(PADDLE_WITH_NATIVE_FP16)
K
Kexin Zhao 已提交
252 253
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);
K
Kexin Zhao 已提交
254

K
Kexin Zhao 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277
#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v;
    v.ui = this->x;
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
K
Kexin Zhao 已提交
278 279
  }

K
Kexin Zhao 已提交
280 281 282
  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  HOSTDEVICE inline explicit operator int8_t() const {
283
    return static_cast<int8_t>(static_cast<float>(*this));
K
Kexin Zhao 已提交
284 285
  }

K
Kexin Zhao 已提交
286
  HOSTDEVICE inline explicit operator uint8_t() const {
287
    return static_cast<uint8_t>(static_cast<float>(*this));
K
Kexin Zhao 已提交
288 289
  }

K
Kexin Zhao 已提交
290
  HOSTDEVICE inline explicit operator int16_t() const {
291
    return static_cast<int16_t>(static_cast<float>(*this));
K
Kexin Zhao 已提交
292 293
  }

K
Kexin Zhao 已提交
294
  HOSTDEVICE inline explicit operator uint16_t() const {
295
    return static_cast<uint16_t>(static_cast<float>(*this));
K
Kexin Zhao 已提交
296 297
  }

K
Kexin Zhao 已提交
298
  HOSTDEVICE inline explicit operator int32_t() const {
299
    return static_cast<int32_t>(static_cast<float>(*this));
K
Kexin Zhao 已提交
300 301
  }

K
Kexin Zhao 已提交
302
  HOSTDEVICE inline explicit operator uint32_t() const {
303
    return static_cast<uint32_t>(static_cast<float>(*this));
K
Kexin Zhao 已提交
304 305
  }

K
Kexin Zhao 已提交
306
  HOSTDEVICE inline explicit operator int64_t() const {
307
    return static_cast<int64_t>(static_cast<float>(*this));
K
Kexin Zhao 已提交
308 309
  }

K
Kexin Zhao 已提交
310
  HOSTDEVICE inline explicit operator uint64_t() const {
311
    return static_cast<uint64_t>(static_cast<float>(*this));
K
Kexin Zhao 已提交
312 313
  }

K
Kexin Zhao 已提交
314
  HOSTDEVICE inline explicit operator double() const {
315
    return static_cast<double>(static_cast<float>(*this));
K
Kexin Zhao 已提交
316
  }
K
Kexin Zhao 已提交
317

318
 private:
K
Kexin Zhao 已提交
319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
K
Kexin Zhao 已提交
347 348
};

K
Kexin Zhao 已提交
349 350 351 352 353
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
354 355
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000

K
Kexin Zhao 已提交
356
// Binary arithmetic on CUDA half. On sm_53+ the native half intrinsic is
// used. Otherwise both operands are widened to float, the operation is done
// in float, and the result is rounded back to half through float16.
DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs + rhs));
#endif
}

DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs - rhs));
#endif
}

DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs * rhs));
#endif
}

// Division does not use a half intrinsic. On sm_30+ it converts with
// __half2float / __float2half; otherwise it goes through float16.
DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  return __float2half(__half2float(a) / __half2float(b));
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return half(float16(lhs / rhs));
#endif
}

394 395 396 397
// Negation of a CUDA half: __hneg on sm_53+, otherwise negated in float and
// rounded back to half through float16.
DEVICE inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  return half(float16(-static_cast<float>(float16(a))));
#endif
}
K
Kexin Zhao 已提交
402

403
// Compound assignment for CUDA half, expressed through the corresponding
// binary operator. Each returns a reference to the updated left operand.
DEVICE inline half& operator+=(half& a, const half& b) {  // NOLINT
  return a = a + b;
}

DEVICE inline half& operator-=(half& a, const half& b) {  // NOLINT
  return a = a - b;
}

DEVICE inline half& operator*=(half& a, const half& b) {  // NOLINT
  return a = a * b;
}

DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
  return a = a / b;
}

// Comparisons on CUDA half. On sm_53+ the native half comparison intrinsic
// is used; otherwise both operands are widened to float and compared there.
DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs == rhs;
#endif
}

DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs != rhs;
#endif
}

DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs < rhs;
#endif
}

DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs <= rhs;
#endif
}

DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs > rhs;
#endif
}

DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs >= rhs;
#endif
}

471
#endif  // PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
472

473
// Arithmetic operators for float16 on GPU
K
Kexin Zhao 已提交
474 475 476
#if defined(PADDLE_CUDA_FP16)
// Binary arithmetic on float16. In device code for sm_53+ the operands are
// converted to CUDA half and the native intrinsic is used. Everywhere else
// the operation is performed in float and rounded back to float16.
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half result = __hadd(half(a), half(b));
  return float16(result);
#else
  const float result = static_cast<float>(a) + static_cast<float>(b);
  return float16(result);
#endif
}

HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half result = __hsub(half(a), half(b));
  return float16(result);
#else
  const float result = static_cast<float>(a) - static_cast<float>(b);
  return float16(result);
#endif
}

HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half result = __hmul(half(a), half(b));
  return float16(result);
#else
  const float result = static_cast<float>(a) * static_cast<float>(b);
  return float16(result);
#endif
}

// Division is always carried out in float (via __half2float on sm_30+).
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  return float16(__half2float(half(a)) / __half2float(half(b)));
#else
  const float result = static_cast<float>(a) / static_cast<float>(b);
  return float16(result);
#endif
}

K
Kexin Zhao 已提交
510 511
// Negation of a float16: __hneg in device code for sm_53+, otherwise the
// sign bit (0x8000) is flipped directly.
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hneg(half(a)));
#else
  float16 negated = a;
  negated.x ^= 0x8000;
  return negated;
#endif
}

520
// Compound assignment for float16, expressed through the corresponding
// binary operator. Each returns a reference to the updated left operand.
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  return a = a + b;
}

HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  return a = a - b;
}

HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  return a = a * b;
}

HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  return a = a / b;
}

K
Kexin Zhao 已提交
540 541
// Comparisons on float16. In device code for sm_53+ the native half
// comparison intrinsic is used; otherwise both operands are widened to float.
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
#endif
}

HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
#endif
}

HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
#endif
}

HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
#endif
}

HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
#endif
}

HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(half(a), half(b));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
#endif
}

// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
590
// Binary arithmetic using ARMv8.2-A half-precision instructions. Each
// operand is loaded into lane 0 of v0/v1, the scalar h-register operation is
// applied, and lane 0 of v0 is stored into the result's bits.
inline float16 operator+(const float16& a, const float16& b) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[lhs]]\n"
      "ld1 {v1.h}[0], [%[rhs]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[dst]]\n"
      :  // outputs: none, the result is stored through %[dst]
      :  // inputs
      [lhs] "r"(&a.x), [rhs] "r"(&b.x), [dst] "r"(&out.x)
      :  // clobbers
      "memory", "v0", "v1");
  return out;
}

inline float16 operator-(const float16& a, const float16& b) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[lhs]]\n"
      "ld1 {v1.h}[0], [%[rhs]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[dst]]\n"
      :  // outputs: none, the result is stored through %[dst]
      :  // inputs
      [lhs] "r"(&a.x), [rhs] "r"(&b.x), [dst] "r"(&out.x)
      :  // clobbers
      "memory", "v0", "v1");
  return out;
}

inline float16 operator*(const float16& a, const float16& b) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[lhs]]\n"
      "ld1 {v1.h}[0], [%[rhs]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[dst]]\n"
      :  // outputs: none, the result is stored through %[dst]
      :  // inputs
      [lhs] "r"(&a.x), [rhs] "r"(&b.x), [dst] "r"(&out.x)
      :  // clobbers
      "memory", "v0", "v1");
  return out;
}

inline float16 operator/(const float16& a, const float16& b) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[lhs]]\n"
      "ld1 {v1.h}[0], [%[rhs]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[dst]]\n"
      :  // outputs: none, the result is stored through %[dst]
      :  // inputs
      [lhs] "r"(&a.x), [rhs] "r"(&b.x), [dst] "r"(&out.x)
      :  // clobbers
      "memory", "v0", "v1");
  return out;
}
K
Kexin Zhao 已提交
653

654
// Negation via the ARMv8.2-A half-precision FNEG instruction.
inline float16 operator-(const float16& a) {
  float16 out;
  asm volatile(
      "ld1 {v0.h}[0], [%[src]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[dst]]\n"
      :  // outputs: none, the result is stored through %[dst]
      :  // inputs
      [src] "r"(&a.x), [dst] "r"(&out.x)
      :  // clobbers
      "memory", "v0");
  return out;
}

669
// Compound assignment, expressed through the corresponding binary operator.
// Each returns a reference to the updated left operand.
inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  return a = a + b;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  return a = a - b;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  return a = a * b;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  return a = a / b;
}

689
inline bool operator==(const float16& a, const float16& b) {
K
Kexin Zhao 已提交
690 691 692 693 694 695 696 697
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
698
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
K
Kexin Zhao 已提交
699 700 701 702 703 704
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

705
inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }
K
Kexin Zhao 已提交
706

707
inline bool operator<(const float16& a, const float16& b) {
K
Kexin Zhao 已提交
708 709 710 711 712 713 714 715
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
716
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
K
Kexin Zhao 已提交
717 718 719 720 721 722
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

723
inline bool operator<=(const float16& a, const float16& b) {
K
Kexin Zhao 已提交
724 725 726 727 728 729 730 731
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
732
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
K
Kexin Zhao 已提交
733 734 735 736 737 738
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

739
inline bool operator>(const float16& a, const float16& b) {
K
Kexin Zhao 已提交
740 741 742 743 744 745 746 747
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
748
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
K
Kexin Zhao 已提交
749 750 751 752 753 754
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

755
inline bool operator>=(const float16& a, const float16& b) {
K
Kexin Zhao 已提交
756 757 758 759 760 761 762 763
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
764
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)),
K
Kexin Zhao 已提交
765 766 767 768 769 770
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

K
Kexin Zhao 已提交
771
// Arithmetic operators for float16, software emulated on other CPU
K
Kexin Zhao 已提交
772
#else
773
// Software-emulated binary arithmetic. Both operands are widened to float,
// the operation is done in float, and the result is rounded to float16.
inline float16 operator+(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs + rhs);
}

inline float16 operator-(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs - rhs);
}

inline float16 operator*(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs * rhs);
}

inline float16 operator/(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs / rhs);
}

789
// Negation flips the sign bit (0x8000) directly, with no float round trip.
inline float16 operator-(const float16& a) {
  float16 negated = a;
  negated.x ^= 0x8000;
  return negated;
}

795 796
// Compound assignment, expressed through the corresponding binary operator
// (which computes in float). Each returns a reference to the updated left
// operand.
inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  a = a + b;
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  a = a - b;
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  a = a * b;
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  a = a / b;
  return a;
}

815
inline bool operator==(const float16& a, const float16& b) {
816
  return static_cast<float>(a) == static_cast<float>(b);
K
Kexin Zhao 已提交
817 818
}

819
inline bool operator!=(const float16& a, const float16& b) {
820
  return static_cast<float>(a) != static_cast<float>(b);
K
Kexin Zhao 已提交
821 822
}

823
inline bool operator<(const float16& a, const float16& b) {
824
  return static_cast<float>(a) < static_cast<float>(b);
K
Kexin Zhao 已提交
825 826
}

827
inline bool operator<=(const float16& a, const float16& b) {
828
  return static_cast<float>(a) <= static_cast<float>(b);
K
Kexin Zhao 已提交
829 830
}

831
inline bool operator>(const float16& a, const float16& b) {
832
  return static_cast<float>(a) > static_cast<float>(b);
K
Kexin Zhao 已提交
833 834
}

835
inline bool operator>=(const float16& a, const float16& b) {
836
  return static_cast<float>(a) >= static_cast<float>(b);
K
Kexin Zhao 已提交
837
}
K
Kexin Zhao 已提交
838
#endif
K
kexinzhao 已提交
839

840 841 842 843 844 845
// Builds a float16 whose bit pattern is exactly `a`; no numeric conversion.
HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 bits;
  bits.x = a;
  return bits;
}

846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861
// Classification helpers. The parenthesized names keep a function-like
// isnan/isinf/isfinite macro from expanding here.

// NaN: all exponent bits set (0x7c00) and a non-zero mantissa.
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hisnan(half(a));
#else
  const uint16_t magnitude = a.x & 0x7fff;
  return magnitude > 0x7c00;
#endif
}

// Infinity: all exponent bits set and a zero mantissa, either sign.
HOSTDEVICE inline bool(isinf)(const float16& a) {
  const uint16_t magnitude = a.x & 0x7fff;
  return magnitude == 0x7c00;
}

// Finite: neither NaN nor infinity.
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return !((isnan)(a) || (isinf)(a));
}

862 863 864 865 866
// Streams a float16 by printing its value widened to float.
inline std::ostream& operator<<(std::ostream& os, const float16& a) {
  return os << static_cast<float>(a);
}

K
kexinzhao 已提交
867
}  // namespace platform
K
Kexin Zhao 已提交
868
}  // namespace paddle
K
kexinzhao 已提交
869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886

namespace std {

// Override the std::is_pod::value for float16
// The reason is that different compilers implemented std::is_pod based on
// different C++ standards. float16 class is a plain old data in C++11 given
// that it is both trivial and standard_layout.
// However, std::is_pod in nvcc 8.0 host c++ compiler follows C++0x and is
// more restricted in that you cannot provide any customized
// constructor in float16. Hence, we override is_pod here following C++11
// so that .cu files can be successfully compiled by nvcc.
template <>
struct is_pod<paddle::platform::float16> {
  static const bool value =
      is_trivial<paddle::platform::float16>::value &&
      is_standard_layout<paddle::platform::float16>::value;
};

887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910
// float16 reports itself as a floating-point type to std trait queries.
template <>
struct is_floating_point<paddle::platform::float16>
    : std::integral_constant<bool, true> {};
// float16 is signed (it has a sign bit) and therefore not unsigned.
// Both specializations derive from std::integral_constant so that they
// provide the complete trait interface (`type`, `value_type`, the bool
// conversion). This also makes `value` a constexpr member that may be
// ODR-used without an out-of-class definition.
template <>
struct is_signed<paddle::platform::float16>
    : std::integral_constant<bool, true> {};

template <>
struct is_unsigned<paddle::platform::float16>
    : std::integral_constant<bool, false> {};

inline bool isnan(const paddle::platform::float16& a) {
  return paddle::platform::isnan(a);
}

inline bool isinf(const paddle::platform::float16& a) {
  return paddle::platform::isinf(a);
}

911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965
template <>
struct numeric_limits<paddle::platform::float16> {
  static const bool is_specialized = true;
  static const bool is_signed = true;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = true;
  static const bool has_quiet_NaN = true;
  static const bool has_signaling_NaN = true;
  static const float_denorm_style has_denorm = denorm_present;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_to_nearest;
  static const bool is_iec559 = false;
  static const bool is_bounded = false;
  static const bool is_modulo = false;
  static const int digits = 11;
  static const int digits10 = 3;
  static const int max_digits10 = 5;
  static const int radix = 2;
  static const int min_exponent = -13;
  static const int min_exponent10 = -4;
  static const int max_exponent = 16;
  static const int max_exponent10 = 4;
  static const bool traps = true;
  static const bool tinyness_before = false;

  static paddle::platform::float16(min)() {
    return paddle::platform::raw_uint16_to_float16(0x400);
  }
  static paddle::platform::float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  static paddle::platform::float16(max)() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  static paddle::platform::float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x0800);
  }
  static paddle::platform::float16 round_error() {
    return paddle::platform::float16(0.5);
  }
  static paddle::platform::float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  static paddle::platform::float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  static paddle::platform::float16 signaling_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
  static paddle::platform::float16 denorm_min() {
    return paddle::platform::raw_uint16_to_float16(0x1);
  }
};

K
kexinzhao 已提交
966
}  // namespace std
967 968

namespace Eigen {

using float16 = paddle::platform::float16;

// Eigen numeric traits for float16, so Eigen expressions/tensors can use it
// as a scalar type. Values are given as raw IEEE 754 binary16 bit patterns.
template <>
struct NumTraits<float16> : GenericNumTraits<float16> {
  enum {
    IsSigned = true,
    IsInteger = false,
    IsComplex = false,
    RequireInitialization = false
  };

  // Machine epsilon: difference between 1.0 and the next representable value,
  // 2^-10 for 10 fraction bits. (0x0800, used previously, encodes 2^-13.)
  HOSTDEVICE static inline float16 epsilon() {
    return paddle::platform::raw_uint16_to_float16(0x1400);
  }
  // Loose tolerance for approximate comparisons.
  HOSTDEVICE static inline float16 dummy_precision() { return float16(1e-2f); }
  // Largest finite value: 65504.
  HOSTDEVICE static inline float16 highest() {
    return paddle::platform::raw_uint16_to_float16(0x7bff);
  }
  // Most negative finite value: -65504.
  HOSTDEVICE static inline float16 lowest() {
    return paddle::platform::raw_uint16_to_float16(0xfbff);
  }
  // Positive infinity.
  HOSTDEVICE static inline float16 infinity() {
    return paddle::platform::raw_uint16_to_float16(0x7c00);
  }
  // Quiet NaN: the quiet bit (0x0200) must be set. The previous 0x7c01 has
  // that bit clear, which encodes a signaling NaN.
  HOSTDEVICE static inline float16 quiet_NaN() {
    return paddle::platform::raw_uint16_to_float16(0x7e00);
  }
};

namespace numext {

// Classification: forward to the paddle::platform implementations. The
// parenthesized names guard against function-like isnan/isinf macros.
template <>
HOSTDEVICE inline bool(isnan)(const float16& a) {
  return (paddle::platform::isnan)(a);
}

template <>
HOSTDEVICE inline bool(isinf)(const float16& a) {
  return (paddle::platform::isinf)(a);
}

template <>
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  return (paddle::platform::isfinite)(a);
}

// Elementwise math: compute in single precision and round back to float16.
template <>
HOSTDEVICE inline float16 exp(const float16& a) {
  return float16(::expf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 erf(const float16& a) {
  return float16(::erff(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 log(const float16& a) {
  return float16(::logf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
  return float16(::tanhf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
  return float16(::sqrtf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
  return float16(::ceilf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 floor(const float16& a) {
  return float16(::floorf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 round(const float16& a) {
  return float16(::roundf(static_cast<float>(a)));
}

template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
  return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}

// Use the single-precision fabsf like the other helpers; the previous ::fabs
// round-tripped through double, which is needless and slow on GPUs.
template <>
HOSTDEVICE inline float16 abs(const float16& a) {
  return float16(::fabsf(static_cast<float>(a)));
}

}  // namespace numext

}  // namespace Eigen