/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>

#include <cuda.h>
#include "unsupported/Eigen/CXX11/Tensor"

#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

#ifdef __CUDACC__
#define PADDLE_HOSTDEVICE __host__ __device__
#if CUDA_VERSION >= 7050
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif  // CUDA_VERSION >= 7050
#else
#define PADDLE_HOSTDEVICE
#endif  // __CUDACC__

#ifdef __arm__
#define PADDLE_ARM_32
#endif

#ifdef __aarch64__
#define PADDLE_ARM_64
#endif

#if defined(PADDLE_ARM_32) || defined(PADDLE_ARM_64)
#define PADDLE_ARM
#endif

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#include <arm_neon.h>
#endif

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_32)
#define PADDLE_NEON_32
#endif

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_64)
#define PADDLE_NEON_64
#endif

#ifdef PADDLE_ARM
#ifdef __F16C__
#undef __F16C__
#endif  // __F16C__
#else
#include <immintrin.h>
#endif  // PADDLE_ARM

#define PADDLE_ALIGN(x) __attribute__((aligned(x)))

namespace paddle {

struct float16;

namespace fp16_impl {
// Convert from float to half precision in round-to-nearest-even mode,
// and from half precision back to float.
PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f);
PADDLE_HOSTDEVICE inline float half_to_float(float16 h);
}  // namespace fp16_impl
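
// Reference bit patterns for these conversions (standard IEEE 754 binary16
// encodings, listed here only to illustrate the expected behavior):
//   float_to_half_rn(1.0f).x     == 0x3c00
//   float_to_half_rn(1.5f).x     == 0x3e00
//   float_to_half_rn(65504.0f).x == 0x7bff   // largest finite float16
// half_to_float inverts these mappings; float16 -> float is exact for all
// finite values.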

// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
// with CUDA half, ARM float16_t, and Eigen::half data types.
struct PADDLE_ALIGN(2) float16 {
  uint16_t x;

  PADDLE_HOSTDEVICE inline float16() : x(0) {}

  PADDLE_HOSTDEVICE inline float16(const float16& h) : x(h.x) {}

#ifdef PADDLE_CUDA_FP16
  PADDLE_HOSTDEVICE inline float16(const half& h) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<const __half_raw*>(&h)->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  PADDLE_HOSTDEVICE inline float16(const Eigen::half& h) : x(h.x) {}

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
  // __fp16 is the native half precision data type on ARM CPUs;
  // float16_t is an alias for __fp16 in arm_fp16.h,
  // which is included in arm_neon.h.
  PADDLE_HOSTDEVICE inline float16(const float16_t& h) {
    float16_t tmp = h;
    x = *reinterpret_cast<uint16_t*>(&tmp);
  }
#endif

  PADDLE_HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  PADDLE_HOSTDEVICE inline explicit float16(int8_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(uint8_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(int16_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(uint16_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(int32_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(uint32_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(int64_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(uint64_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(float val) {
    float16 res = fp16_impl::float_to_half_rn(val);
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline explicit float16(double val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(const float16& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_CUDA_FP16
  PADDLE_HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<const __half_raw*>(&rhs)->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  PADDLE_HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
  PADDLE_HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    float16_t tmp = rhs;
    x = *reinterpret_cast<uint16_t*>(&tmp);
    return *this;
  }
#endif

  PADDLE_HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(int8_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(uint8_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(int16_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(uint16_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(int32_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(uint32_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(int64_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(uint64_t val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(float val) {
    float16 res = fp16_impl::float_to_half_rn(val);
    x = res.x;
    return *this;
  }

  PADDLE_HOSTDEVICE inline float16& operator=(double val) {
    float16 res = fp16_impl::float_to_half_rn(static_cast<float>(val));
    x = res.x;
    return *this;
  }

#ifdef PADDLE_CUDA_FP16
  PADDLE_HOSTDEVICE inline operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  PADDLE_HOSTDEVICE inline operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
  PADDLE_HOSTDEVICE inline operator float16_t() const {
    float16 h = *this;
    return *reinterpret_cast<float16_t*>(&h);
  }
#endif

  PADDLE_HOSTDEVICE inline explicit operator bool() const {
    return (x & 0x7fff) != 0;
  }

  PADDLE_HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(fp16_impl::half_to_float(*this));
  }

  PADDLE_HOSTDEVICE inline explicit operator float() const {
    return fp16_impl::half_to_float(*this);
  }

  PADDLE_HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(fp16_impl::half_to_float(*this));
  }
};
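
// Minimal usage sketch (illustrative only). Conversions to and from the
// built-in numeric types are explicit, while the half, Eigen::half, and
// float16_t conversions are implicit, so float16 values can be passed
// straight to CUDA, Eigen, and ARM intrinsics:
//
//   float16 a(1.5f);                      // a.x == 0x3e00
//   float f = static_cast<float>(a);      // f == 1.5f
//   Eigen::half eh = a;                   // implicit, bit pattern preserved
//   bool nonzero = static_cast<bool>(a);  // true for any nonzero payload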

// Arithmetic operators
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
__device__ inline float16 operator+(const float16& a, const float16& b) {
  return float16(__hadd(half(a), half(b)));
}

__device__ inline float16 operator-(const float16& a, const float16& b) {
  return float16(__hsub(half(a), half(b)));
}

__device__ inline float16 operator*(const float16& a, const float16& b) {
  return float16(__hmul(half(a), half(b)));
}

__device__ inline float16 operator/(const float16& a, const float16& b) {
  // TODO(kexinzhao): check the cuda version that starts to support __hdiv
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
}

__device__ inline float16 operator-(const float16& a) {
  return float16(__hneg(half(a)));
}

__device__ inline float16& operator+=(float16& a, const float16& b) {
  a = a + b;
  return a;
}

__device__ inline float16& operator-=(float16& a, const float16& b) {
  a = a - b;
  return a;
}

__device__ inline float16& operator*=(float16& a, const float16& b) {
  a = a * b;
  return a;
}

__device__ inline float16& operator/=(float16& a, const float16& b) {
  a = a / b;
  return a;
}

__device__ inline bool operator==(const float16& a, const float16& b) {
  return __heq(half(a), half(b));
}

__device__ inline bool operator!=(const float16& a, const float16& b) {
  return __hne(half(a), half(b));
}

__device__ inline bool operator<(const float16& a, const float16& b) {
  return __hlt(half(a), half(b));
}

__device__ inline bool operator<=(const float16& a, const float16& b) {
  return __hle(half(a), half(b));
}

__device__ inline bool operator>(const float16& a, const float16& b) {
  return __hgt(half(a), half(b));
}

__device__ inline bool operator>=(const float16& a, const float16& b) {
  return __hge(half(a), half(b));
}

// On ARMv8.2-A CPUs with native half-precision arithmetic
#elif defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 71 || PADDLE_CLANG_VER >= 39)
inline float16 operator+(const float16& a, const float16& b) {
  return float16(vaddh_f16(float16_t(a), float16_t(b)));
}

inline float16 operator-(const float16& a, const float16& b) {
  return float16(vsubh_f16(float16_t(a), float16_t(b)));
}

inline float16 operator*(const float16& a, const float16& b) {
  return float16(vmulh_f16(float16_t(a), float16_t(b)));
}

inline float16 operator/(const float16& a, const float16& b) {
  return float16(vdivh_f16(float16_t(a), float16_t(b)));
}

inline float16 operator-(const float16& a) {
  return float16(vnegh_f16(float16_t(a)));
}

inline float16& operator+=(float16& a, const float16& b) {
  a = a + b;
  return a;
}

inline float16& operator-=(float16& a, const float16& b) {
  a = a - b;
  return a;
}

inline float16& operator*=(float16& a, const float16& b) {
  a = a * b;
  return a;
}

inline float16& operator/=(float16& a, const float16& b) {
  a = a / b;
  return a;
}

inline bool operator==(const float16& a, const float16& b) {
  return static_cast<bool>(vceqh_f16(float16_t(a), float16_t(b)));
}

inline bool operator!=(const float16& a, const float16& b) {
  return !(a == b);
}

inline bool operator<(const float16& a, const float16& b) {
#ifdef PADDLE_NEON_64
  return static_cast<bool>(vclth_f16(float16_t(a), float16_t(b)));
#else
  return float(a) < float(b);
#endif  // PADDLE_NEON_64
}

inline bool operator<=(const float16& a, const float16& b) {
#ifdef PADDLE_NEON_64
  return static_cast<bool>(vcleh_f16(float16_t(a), float16_t(b)));
#else
  return float(a) <= float(b);
#endif  // PADDLE_NEON_64
}

inline bool operator>(const float16& a, const float16& b) {
#ifdef PADDLE_NEON_64
  return static_cast<bool>(vcgth_f16(float16_t(a), float16_t(b)));
#else
  return float(a) > float(b);
#endif  // PADDLE_NEON_64
}

inline bool operator>=(const float16& a, const float16& b) {
#ifdef PADDLE_NEON_64
  return static_cast<bool>(vcgeh_f16(float16_t(a), float16_t(b)));
#else
  return float(a) >= float(b);
#endif  // PADDLE_NEON_64
}

#else  // Software emulation on other CPUs
PADDLE_HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
  return float16(float(a) + float(b));
}

PADDLE_HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
  return float16(float(a) - float(b));
}

PADDLE_HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
  return float16(float(a) * float(b));
}

PADDLE_HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
  return float16(float(a) / float(b));
}

PADDLE_HOSTDEVICE inline float16 operator-(const float16& a) {
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
}

PADDLE_HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
  a = float16(float(a) + float(b));
  return a;
}

PADDLE_HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
  a = float16(float(a) - float(b));
  return a;
}

PADDLE_HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
  a = float16(float(a) * float(b));
  return a;
}

PADDLE_HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
  a = float16(float(a) / float(b));
  return a;
}

PADDLE_HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
  return float(a) == float(b);
}

PADDLE_HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
  return float(a) != float(b);
}

PADDLE_HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
  return float(a) < float(b);
}

PADDLE_HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
  return float(a) <= float(b);
}

PADDLE_HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
  return float(a) > float(b);
}

PADDLE_HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
  return float(a) >= float(b);
}

#endif
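
// Whichever backend is selected above, the operators are intended to behave
// identically. A small example (illustrative only; the values below are
// exactly representable in half precision, so no rounding occurs):
//
//   float16 a(2.0f), b(0.5f);
//   float16 c = a * b + b;  // == float16(1.5f)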

namespace fp16_impl {

// Bit-level view of a float32 value. Writing one member and reading another
// is the type-punning idiom used by the branch-free conversion routines
// below (e.g. with f == 1.0f, ui reads back as 0x3f800000).
union Bits {
  float f;
  int32_t si;
  uint32_t ui;
};

const int shift = 13;      // float32 has 23 mantissa bits, float16 has 10
const int shiftSign = 16;  // sign bit: bit 31 in float32, bit 15 in float16

const int32_t infN = 0x7F800000;
const int32_t maxN = 0x477FE000;  // max flt16 as flt32
const int32_t minN = 0x38800000;  // min flt16 normal as flt32
const int32_t sigN = 0x80000000;  // sign bit

constexpr int32_t infC = infN >> shift;
constexpr int32_t nanN = (infC + 1) << shift;  // minimum flt16 nan as float32
constexpr int32_t maxC = maxN >> shift;
constexpr int32_t minC = minN >> shift;
constexpr int32_t sigC = sigN >> shiftSign;

const int32_t mulN = 0x52000000;  // (1 << 23) / minN
const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
const int32_t norC = 0x00400;     // min flt32 normal downshifted

constexpr int32_t maxD = infC - maxC - 1;
constexpr int32_t minD = minC - subC - 1;

PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  half tmp = __float2half(f);
  return *reinterpret_cast<float16*>(&tmp);

#elif defined(PADDLE_NEON_64)
  float16 res;
  asm volatile(
      "ld1 {v0.s}[0], [%[float_ptr]]\n"
      "fcvt h0, s0\n"
      "st1 {v0.h}[0], [%[half_ptr]]\n"
      :  // outputs
      :  // inputs
      [float_ptr] "r"(&f),
      [half_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0");
  return res;

#elif defined(PADDLE_NEON_32)
  float16 res;
  asm volatile(
      "vld1.32 {d0[0]}, [%[float_ptr]]\n"
      "vcvt.f16.f32 d0, q0\n"
      "vst1.16 {d0[0]}, [%[half_ptr]]\n"
      :  // outputs
      :  // inputs
      [float_ptr] "r"(&f),
      [half_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "d0");
  return res;

#elif defined(__F16C__)
  float16 res;
  res.x = _cvtss_sh(f, 0);
  return res;

#else
  // Conversion routine adapted from
  // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
  Bits v, s;
  v.f = f;
  uint32_t sign = v.si & sigN;
  v.si ^= sign;
  sign >>= shiftSign;  // logical shift
  s.si = mulN;
  s.si = s.f * v.f;  // correct subnormals
  v.si ^= (s.si ^ v.si) & -(minN > v.si);
  v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
  v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
  v.ui >>= shift;  // logical shift
  v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
  v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
  float16 res;
  res.x = v.ui | sign;
  return res;

#endif
}

PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  half tmp = *reinterpret_cast<half*>(&h);
  return __half2float(tmp);

#elif defined(PADDLE_NEON_64)
  float res;
  asm volatile(
      "ld1 {v0.h}[0], [%[half_ptr]]\n"
      "fcvt s0, h0\n"
      "st1 {v0.s}[0], [%[float_ptr]]\n"
      :  // outputs
      :  // inputs
      [half_ptr] "r"(&(h.x)),
      [float_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0");
  return res;

#elif defined(PADDLE_NEON_32)
  float res;
  asm volatile(
      "vld1.16 {d0[0]}, [%[half_ptr]]\n"
      "vcvt.f32.f16 q0, d0\n"
      "vst1.32 {d0[0]}, [%[float_ptr]]\n"
      :  // outputs
      :  // inputs
      [half_ptr] "r"(&(h.x)),
      [float_ptr] "r"(&res)
      :  // clobbers
      "memory", "d0", "d1");
  return res;

#elif defined(__F16C__)
  return _cvtsh_ss(h.x);

#else
  // Conversion routine adapted from
  // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
  Bits v;
  v.ui = h.x;
  int32_t sign = v.si & sigC;
  v.si ^= sign;
  sign <<= shiftSign;
  v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
  v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
  Bits s;
  s.si = mulC;
  s.f *= v.si;
  int32_t mask = -(norC > v.si);
  v.si <<= shift;
  v.si ^= (s.si ^ v.si) & mask;
  v.si |= sign;
  return v.f;

#endif
}

}  // namespace fp16_impl
}  // namespace paddle
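
// End-to-end usage sketch (illustrative only; the include path below is an
// assumption -- adjust it to match where this header lives in the build):
//
//   #include "paddle/math/float16.h"
//   #include <cassert>
//
//   int main() {
//     paddle::float16 a(1.5f);    // round-to-nearest-even conversion
//     paddle::float16 b(2.5f);
//     paddle::float16 c = a + b;  // CUDA intrinsics, ARM NEON intrinsics, or
//                                 // software emulation, depending on target
//     assert(static_cast<float>(c) == 4.0f);  // 1.5 + 2.5 is exact in fp16
//     return 0;
//   }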