float16.h 23.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
K
Kexin Zhao 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <stdint.h>
K
Kexin Zhao 已提交
18

K
Kexin Zhao 已提交
19
#ifdef PADDLE_WITH_CUDA
K
Kexin Zhao 已提交
20
#include <cuda.h>
K
Kexin Zhao 已提交
21 22
#endif  // PADDLE_WITH_CUDA

K
Kexin Zhao 已提交
23 24 25 26 27 28 29 30 31 32 33 34
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

K
Kexin Zhao 已提交
35
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
K
Kexin Zhao 已提交
36 37
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
K
Kexin Zhao 已提交
38 39
#endif

K
Kexin Zhao 已提交
40
#if defined(__arm__) || defined(__aarch64__)
K
Kexin Zhao 已提交
41 42 43 44 45
#define PADDLE_ARM
#endif

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
K
Kexin Zhao 已提交
46
#include <arm_neon.h>
K
Kexin Zhao 已提交
47 48
#endif

K
Kexin Zhao 已提交
49 50 51
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 62 || PADDLE_CLANG_VER >= 37)
#define PADDLE_WITH_NATIVE_FP16
K
Kexin Zhao 已提交
52 53
#endif

K
Kexin Zhao 已提交
54
#ifndef PADDLE_ARM
K
Kexin Zhao 已提交
55 56
#include <immintrin.h>
#endif  // PADDLE_ARM
K
Kexin Zhao 已提交
57

K
Kexin Zhao 已提交
58
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
K
Kexin Zhao 已提交
59 60

namespace paddle {
K
kexinzhao 已提交
61
namespace platform {
K
Kexin Zhao 已提交
62

63 64 65 66 67 68 69 70 71 72 73 74
// Forward declare float16 for eigen.h
struct float16;

}  // namespace platform
}  // namespace paddle

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace platform {

K
Kexin Zhao 已提交
75 76 77
// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
K
Kexin Zhao 已提交
78
// with CUDA half, ARM float16_t, and Eigen::half data types.
K
Kexin Zhao 已提交
79
// Half-precision floating point value stored as its raw 16-bit pattern in
// `x`. Conversions to and from float pick, at compile time, CUDA intrinsics,
// ARM NEON intrinsics, x86 F16C intrinsics, or a portable bit-manipulation
// routine, depending on which macros the build defines.
struct PADDLE_ALIGN(2) float16 {
 public:
  // Raw 16-bit encoding; bit 15 is the sign (see unary operator-), and
  // 0x3c00 is the pattern used for `true`.
  uint16_t x;

  // The following defaulted special class member functions
  // are added to make float16 pass the std::is_trivial test
  float16() = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o) = default;
  float16& operator=(float16&& o) = default;
  ~float16() = default;

// Constructors
#ifdef PADDLE_CUDA_FP16
  // Copies the bit pattern of a CUDA half.
  HOSTDEVICE inline explicit float16(const half& h) {
#if CUDA_VERSION >= 9000
    // The source is only read, so view it through a pointer-to-const instead
    // of casting constness away.
    x = reinterpret_cast<const __half_raw*>(&h)->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  // Copies the bit pattern of an Eigen::half.
  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  // Converts a float to half precision using the best available path.
  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);

#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v, s;
    v.f = val;
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    if (minN > v.si) {
      // correct subnormals: the scaled product is converted to an integer
      // only here, where it is known to fit. For larger magnitudes the
      // product does not fit in int32_t and the float->int conversion would
      // be undefined behavior, while its result would be discarded anyway.
      s.si = mulN;
      v.si = static_cast<int32_t>(s.f * v.f);
    }
    // Magnitudes above the largest half but below float infinity become inf.
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    // Patterns strictly between infN and nanN are replaced by nanN.
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;

#endif
  }

  // true maps to 0x3c00, false to 0.
  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  // Any other type is first converted to float with static_cast.
  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

// Assignment operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    // Read-only view of the source bits; no const_cast needed.
    x = reinterpret_cast<const __half_raw*>(&rhs)->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

// Conversion operators
#ifdef PADDLE_CUDA_FP16
  HOSTDEVICE inline explicit operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  // Widens the stored half to float using the best available path.
  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);

#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v;
    v.ui = this->x;
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
  }

  // True unless the value is +0 or -0 (the sign bit is masked off).
  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  // The remaining conversions go through float first.
  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(float(*this));
  }

 private:
  // Lets the software conversion routines view the same 32 bits as a float,
  // a signed integer, or an unsigned integer.
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  // Constants used only by the software conversion routines above.
  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
};

K
Kexin Zhao 已提交
362 363 364 365 366
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
367 368
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000

K
Kexin Zhao 已提交
369
// Sum of two CUDA halves (only compiled for CUDA versions before 9.0).
DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Device code for compute capability >= 5.3 uses the half intrinsic.
  return __hadd(a, b);
#else
  // Otherwise widen both operands to float, add, and narrow the result.
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return half(float16(lhs + rhs));
#endif
}

// Difference of two CUDA halves (only compiled for CUDA versions before 9.0).
DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Device code for compute capability >= 5.3 uses the half intrinsic.
  return __hsub(a, b);
#else
  // Otherwise widen both operands to float, subtract, and narrow the result.
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return half(float16(lhs - rhs));
#endif
}

// Product of two CUDA halves (only compiled for CUDA versions before 9.0).
DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Device code for compute capability >= 5.3 uses the half intrinsic.
  return __hmul(a, b);
#else
  // Otherwise widen both operands to float, multiply, and narrow the result.
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return half(float16(lhs * rhs));
#endif
}

// Quotient of two CUDA halves (only compiled for CUDA versions before 9.0).
// Both branches divide in float; they differ only in how the operands are
// converted.
DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  const float quotient = __half2float(a) / __half2float(b);
  return __float2half(quotient);
#else
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return half(float16(lhs / rhs));
#endif
}

407 408 409 410 411 412 413 414
// Negation of a CUDA half (only compiled for CUDA versions before 9.0).
DEVICE inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  // Negate in float, then narrow back to half.
  const float widened = float(float16(a));
  return half(float16(-widened));
#endif
}
K
Kexin Zhao 已提交
415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436

// In-place addition, expressed through the binary operator+ for half.
DEVICE inline half& operator+=(half& a, const half& b) {
  return a = a + b;
}

// In-place subtraction, expressed through the binary operator- for half.
DEVICE inline half& operator-=(half& a, const half& b) {
  return a = a - b;
}

// In-place multiplication, expressed through the binary operator* for half.
DEVICE inline half& operator*=(half& a, const half& b) {
  return a = a * b;
}

// In-place division, expressed through the binary operator/ for half.
DEVICE inline half& operator/=(half& a, const half& b) {
  return a = a / b;
}

// Equality of two CUDA halves (only compiled for CUDA versions before 9.0).
DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs == rhs;
#endif
}

// Inequality of two CUDA halves (only compiled for CUDA versions before 9.0).
DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs != rhs;
#endif
}

// Less-than for two CUDA halves (only compiled for CUDA versions before 9.0).
DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs < rhs;
#endif
}

// Less-or-equal for two CUDA halves (only compiled for CUDA versions before
// 9.0).
DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs <= rhs;
#endif
}

// Greater-than for two CUDA halves (only compiled for CUDA versions before
// 9.0).
DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs > rhs;
#endif
}

// Greater-or-equal for two CUDA halves (only compiled for CUDA versions
// before 9.0).
DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(float16(a));
  const float rhs = float(float16(b));
  return lhs >= rhs;
#endif
}

484
#endif  // PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
485

486
// Arithmetic operators for float16 on GPU
K
Kexin Zhao 已提交
487 488 489
#if defined(PADDLE_CUDA_FP16)
// Sum of two float16 values in a CUDA-enabled build.
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Device code for compute capability >= 5.3: add as CUDA halves.
  const half lhs = half(a);
  const half rhs = half(b);
  return float16(__hadd(lhs, rhs));
#else
  // Everywhere else: add in float and narrow the result.
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs + rhs);
#endif
}

K
Kexin Zhao 已提交
496 497
// Difference of two float16 values in a CUDA-enabled build.
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Device code for compute capability >= 5.3: subtract as CUDA halves.
  const half lhs = half(a);
  const half rhs = half(b);
  return float16(__hsub(lhs, rhs));
#else
  // Everywhere else: subtract in float and narrow the result.
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs - rhs);
#endif
}

K
Kexin Zhao 已提交
504 505
// Product of two float16 values in a CUDA-enabled build.
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Device code for compute capability >= 5.3: multiply as CUDA halves.
  const half lhs = half(a);
  const half rhs = half(b);
  return float16(__hmul(lhs, rhs));
#else
  // Everywhere else: multiply in float and narrow the result.
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs * rhs);
#endif
}

K
Kexin Zhao 已提交
512 513 514
// Quotient of two float16 values in a CUDA-enabled build. Both branches
// divide in float; they differ only in how the operands are widened.
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  const float quotient = __half2float(half(a)) / __half2float(half(b));
  return float16(quotient);
#else
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs / rhs);
#endif
}

K
Kexin Zhao 已提交
523 524
// Negation of a float16 in a CUDA-enabled build.
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half value = half(a);
  return float16(__hneg(value));
#else
  // Flip bit 15 (the sign bit) of a copy of the operand.
  float16 negated(a);
  negated.x ^= 0x8000;
  return negated;
#endif
}

K
Kexin Zhao 已提交
533
// In-place addition, expressed through the binary operator+ for float16.
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
  return a = a + b;
}

K
Kexin Zhao 已提交
538
// In-place subtraction, expressed through the binary operator- for float16.
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
  return a = a - b;
}

K
Kexin Zhao 已提交
543
// In-place multiplication, expressed through the binary operator* for
// float16.
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
  return a = a * b;
}

K
Kexin Zhao 已提交
548
// In-place division, expressed through the binary operator/ for float16.
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
  return a = a / b;
}

K
Kexin Zhao 已提交
553 554
// Equality of two float16 values in a CUDA-enabled build.
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half lhs = half(a);
  const half rhs = half(b);
  return __heq(lhs, rhs);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs == rhs;
#endif
}

K
Kexin Zhao 已提交
561 562
// Inequality of two float16 values in a CUDA-enabled build.
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half lhs = half(a);
  const half rhs = half(b);
  return __hne(lhs, rhs);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs != rhs;
#endif
}

K
Kexin Zhao 已提交
569 570
// Less-than for two float16 values in a CUDA-enabled build.
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half lhs = half(a);
  const half rhs = half(b);
  return __hlt(lhs, rhs);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs < rhs;
#endif
}

K
Kexin Zhao 已提交
577 578
// Less-or-equal for two float16 values in a CUDA-enabled build.
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half lhs = half(a);
  const half rhs = half(b);
  return __hle(lhs, rhs);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs <= rhs;
#endif
}

K
Kexin Zhao 已提交
585 586
// Greater-than for two float16 values in a CUDA-enabled build.
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half lhs = half(a);
  const half rhs = half(b);
  return __hgt(lhs, rhs);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs > rhs;
#endif
}

K
Kexin Zhao 已提交
593 594
// Greater-or-equal for two float16 values in a CUDA-enabled build.
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half lhs = half(a);
  const half rhs = half(b);
  return __hge(lhs, rhs);
#else
  // Compare the float equivalents of both operands.
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs >= rhs;
#endif
}

// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
603
// Sum of two float16 values using ARM half-precision instructions.
// The asm loads the two operands' bits into lane 0 of v0 and v1, adds them
// with fadd on h0/h1, and stores lane 0 of v0 into the result's bits.
inline float16 operator+(const float16& a, const float16& b) {
  float16 sum;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &(sum.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits), [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0", "v1");
  return sum;
}

619
// Difference of two float16 values using ARM half-precision instructions.
// The asm loads the two operands' bits into lane 0 of v0 and v1, subtracts
// with fsub on h0/h1, and stores lane 0 of v0 into the result's bits.
inline float16 operator-(const float16& a, const float16& b) {
  float16 diff;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &(diff.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits), [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0", "v1");
  return diff;
}

635
// Product of two float16 values using ARM half-precision instructions.
// The asm loads the two operands' bits into lane 0 of v0 and v1, multiplies
// with fmul on h0/h1, and stores lane 0 of v0 into the result's bits.
inline float16 operator*(const float16& a, const float16& b) {
  float16 product;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &(product.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits), [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0", "v1");
  return product;
}

651
// Quotient of two float16 values using ARM half-precision instructions.
// The asm loads the two operands' bits into lane 0 of v0 and v1, divides
// with fdiv on h0/h1, and stores lane 0 of v0 into the result's bits.
inline float16 operator/(const float16& a, const float16& b) {
  float16 quotient;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &(quotient.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits), [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0", "v1");
  return quotient;
}
K
Kexin Zhao 已提交
666

667
// Negation of a float16 using ARM half-precision instructions.
// The asm loads the operand's bits into lane 0 of v0, applies fneg to h0,
// and stores lane 0 of v0 into the result's bits.
inline float16 operator-(const float16& a) {
  float16 negated;
  const uint16_t* src_bits = &(a.x);
  uint16_t* out_bits = &(negated.x);
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(src_bits),
      [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0");
  return negated;
}

682
// In-place addition, expressed through the binary operator+ for float16.
inline float16& operator+=(float16& a, const float16& b) {
  return a = a + b;
}

687
// In-place subtraction, expressed through the binary operator- for float16.
inline float16& operator-=(float16& a, const float16& b) {
  return a = a - b;
}

692
// In-place multiplication, expressed through the binary operator* for
// float16.
inline float16& operator*=(float16& a, const float16& b) {
  return a = a * b;
}

697
// In-place division, expressed through the binary operator/ for float16.
inline float16& operator/=(float16& a, const float16& b) {
  return a = a / b;
}

702
// Equality of two float16 values using the ARM fcmeq instruction.
// a goes into v0 and b into v1; the 16-bit result of the compare is stored
// to a local and any nonzero value is reported as true.
inline bool operator==(const float16& a, const float16& b) {
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits), [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0", "v1");
  const bool holds = (cmp_bits & 0xffff) != 0;
  return holds;
}

718
// Inequality is defined as the logical negation of operator== for float16.
inline bool operator!=(const float16& a, const float16& b) {
  const bool same = (a == b);
  return !same;
}
K
Kexin Zhao 已提交
719

720
// Less-than for two float16 values using the ARM fcmgt instruction.
// The operands are loaded swapped (a into v1, b into v0), so the
// greater-than compare of h0 against h1 tests b > a.
inline bool operator<(const float16& a, const float16& b) {
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits), [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0", "v1");
  const bool holds = (cmp_bits & 0xffff) != 0;
  return holds;
}

736
// Less-or-equal for two float16 values using the ARM fcmge instruction.
// The operands are loaded swapped (a into v1, b into v0), so the
// greater-or-equal compare of h0 against h1 tests b >= a.
inline bool operator<=(const float16& a, const float16& b) {
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits), [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0", "v1");
  const bool holds = (cmp_bits & 0xffff) != 0;
  return holds;
}

752
// Greater-than for two float16 values using the ARM fcmgt instruction.
// a goes into v0 and b into v1; any nonzero stored result is reported as
// true.
inline bool operator>(const float16& a, const float16& b) {
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits), [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0", "v1");
  const bool holds = (cmp_bits & 0xffff) != 0;
  return holds;
}

768
// Greater-or-equal for two float16 values using the ARM fcmge instruction.
// a goes into v0 and b into v1; any nonzero stored result is reported as
// true.
inline bool operator>=(const float16& a, const float16& b) {
  uint16_t cmp_bits;
  const uint16_t* lhs_bits = &(a.x);
  const uint16_t* rhs_bits = &(b.x);
  uint16_t* out_bits = &cmp_bits;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(lhs_bits), [b_ptr] "r"(rhs_bits), [res_ptr] "r"(out_bits)
      :  // clobbers
      // The result is written through an input pointer, hence "memory".
      "memory", "v0", "v1");
  const bool holds = (cmp_bits & 0xffff) != 0;
  return holds;
}

K
Kexin Zhao 已提交
784
// Arithmetic operators for float16, software emulated on other CPU
K
Kexin Zhao 已提交
785
#else
786
// Software-emulated sum: add in float, then narrow back to float16.
inline float16 operator+(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs + rhs);
}

790
// Software-emulated difference: subtract in float, then narrow back to
// float16.
inline float16 operator-(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs - rhs);
}

794
// Software-emulated product: multiply in float, then narrow back to float16.
inline float16 operator*(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs * rhs);
}

798
// Software-emulated quotient: divide in float, then narrow back to float16.
inline float16 operator/(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return float16(lhs / rhs);
}

802
// Negation by flipping bit 15 (the sign bit) of a copy of the operand.
inline float16 operator-(const float16& a) {
  float16 negated(a);
  negated.x ^= 0x8000;
  return negated;
}

808
// In-place software-emulated addition: the sum is computed in float and
// stored back into `a`.
inline float16& operator+=(float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return a = float16(lhs + rhs);
}

813
// In-place software-emulated subtraction: the difference is computed in
// float and stored back into `a`.
inline float16& operator-=(float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return a = float16(lhs - rhs);
}

818
// In-place software-emulated multiplication: the product is computed in
// float and stored back into `a`.
inline float16& operator*=(float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return a = float16(lhs * rhs);
}

823
// In-place software-emulated division: the quotient is computed in float
// and stored back into `a`.
inline float16& operator/=(float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return a = float16(lhs / rhs);
}

828
// Software-emulated equality: compares the float equivalents.
inline bool operator==(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs == rhs;
}

832
// Software-emulated inequality: compares the float equivalents.
inline bool operator!=(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs != rhs;
}

836
// Software-emulated less-than: compares the float equivalents.
inline bool operator<(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs < rhs;
}

840
// Software-emulated less-or-equal: compares the float equivalents.
inline bool operator<=(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs <= rhs;
}

844
// Software-emulated greater-than: compares the float equivalents.
inline bool operator>(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs > rhs;
}

848
// Software-emulated greater-or-equal: compares the float equivalents.
inline bool operator>=(const float16& a, const float16& b) {
  const float lhs = float(a);
  const float rhs = float(b);
  return lhs >= rhs;
}
K
Kexin Zhao 已提交
851
#endif
K
kexinzhao 已提交
852

853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868
// Reports whether `a` holds a NaN. The function name is parenthesized in
// the declaration, as in the sibling isinf/isfinite helpers.
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const half value = half(a);
  return __hisnan(value);
#else
  // With the sign bit masked off, anything above 0x7c00 counts as NaN.
  const int magnitude_bits = a.x & 0x7fff;
  return magnitude_bits > 0x7c00;
#endif
}

// Reports whether `a` holds an infinity of either sign: with the sign bit
// masked off, the bits must equal 0x7c00 exactly.
HOSTDEVICE inline bool(isinf)(const float16& a) {
  const int magnitude_bits = a.x & 0x7fff;
  return magnitude_bits == 0x7c00;
}

// Reports whether `a` is neither NaN nor infinite.
HOSTDEVICE inline bool(isfinite)(const float16& a) {
  if ((isnan)(a)) {
    return false;
  }
  return !((isinf)(a));
}

K
kexinzhao 已提交
869
}  // namespace platform
K
Kexin Zhao 已提交
870
}  // namespace paddle
K
kexinzhao 已提交
871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889

namespace std {

// Override the std::is_pod::value for float16
// The reason is that different compilers implemented std::is_pod based on
// different C++ standards. float16 class is a plain old data in C++11 given
// that it is both trivial and standard_layout.
// However, std::is_pod in nvcc 8.0 host c++ compiler follows C++0x and is
// more restricted in that you cannot provide any customized
// constructor in float16. Hence, we override is_pod here following C++11
// so that .cu files can be successfully compiled by nvcc.
template <>
struct is_pod<paddle::platform::float16> {
  static const bool value =
      is_trivial<paddle::platform::float16>::value &&
      is_standard_layout<paddle::platform::float16>::value;
};

}  // namespace std
890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913

namespace Eigen {
namespace numext {

// Eigen::numext::isnan for float16 forwards to paddle::platform::isnan.
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(
    const paddle::platform::float16& a) {
  const bool is_nan = (paddle::platform::isnan)(a);
  return is_nan;
}

// Eigen::numext::isinf for float16 forwards to paddle::platform::isinf.
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(
    const paddle::platform::float16& a) {
  const bool is_inf = (paddle::platform::isinf)(a);
  return is_inf;
}

// Eigen::numext::isfinite for float16 forwards to
// paddle::platform::isfinite.
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(
    const paddle::platform::float16& a) {
  const bool is_finite = (paddle::platform::isfinite)(a);
  return is_finite;
}

}  // namespace numext
}  // namespace Eigen