/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif  // PADDLE_WITH_CUDA

#include "unsupported/Eigen/CXX11/Tensor"

#include "paddle/platform/hostdevice.h"

// Compiler versions encoded as major * 10 + minor (e.g. GCC 6.2 -> 62),
// or 0 when the corresponding compiler is not the one in use.
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

// CUDA half support (cuda_fp16.h) is enabled only when compiling with nvcc
// (__CUDACC__) against CUDA 7.5 or newer. CUDA_VERSION comes from <cuda.h>,
// which is included above only when PADDLE_WITH_CUDA is defined.
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#endif

#if defined(__arm__) || defined(__aarch64__)
#define PADDLE_ARM
#endif

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#include <arm_neon.h>
#endif

// Native ARM half-precision arithmetic requires NEON, the PADDLE_ARM_FP16
// flag, and GCC >= 6.2 or Clang >= 3.7. PADDLE_ARM_FP16 is not defined in
// this header — presumably supplied by the build system; verify.
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 62 || PADDLE_CLANG_VER >= 37)
#define PADDLE_WITH_NATIVE_FP16
#endif

// x86 intrinsics header; supplies _cvtss_sh / _cvtsh_ss, which float16 uses
// when __F16C__ is defined.
#ifndef PADDLE_ARM
#include <immintrin.h>
#endif  // PADDLE_ARM

#define PADDLE_ALIGN(x) __attribute__((aligned(x)))

namespace paddle {

K
Kexin Zhao 已提交
66 67 68
// Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
K
Kexin Zhao 已提交
69
// with CUDA half, ARM float16_t, and Eigen::half data types.
K
Kexin Zhao 已提交
70
struct PADDLE_ALIGN(2) float16 {
K
Kexin Zhao 已提交
71
public:
K
Kexin Zhao 已提交
72
  uint16_t x;
K
Kexin Zhao 已提交
73

K
Kexin Zhao 已提交
74
  HOSTDEVICE inline float16() : x(0) {}
K
Kexin Zhao 已提交
75

K
Kexin Zhao 已提交
76
  HOSTDEVICE inline float16(const float16& h) : x(h.x) {}
K
Kexin Zhao 已提交
77 78

#ifdef PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
79
  HOSTDEVICE inline explicit float16(const half& h) {
K
Kexin Zhao 已提交
80 81 82 83 84 85 86 87
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(&h)->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

K
Kexin Zhao 已提交
88
  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}
K
Kexin Zhao 已提交
89

K
Kexin Zhao 已提交
90
#ifdef PADDLE_WITH_NATIVE_FP16
K
Kexin Zhao 已提交
91
  // __fp16 is a native half precision data type for arm cpu,
K
Kexin Zhao 已提交
92 93
  // float16_t is an alias for __fp16 in arm_fp16.h,
  // which is included in arm_neon.h.
K
Kexin Zhao 已提交
94 95
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
K
Kexin Zhao 已提交
96 97 98
  }
#endif

K
Kexin Zhao 已提交
99 100 101 102
  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);
K
Kexin Zhao 已提交
103

K
Kexin Zhao 已提交
104 105 106 107
#elif defined(PADDLE_NEON)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);
K
Kexin Zhao 已提交
108

K
Kexin Zhao 已提交
109 110
#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);
K
Kexin Zhao 已提交
111

K
Kexin Zhao 已提交
112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v, s;
    v.f = val;
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;
K
Kexin Zhao 已提交
129

K
Kexin Zhao 已提交
130
#endif
K
Kexin Zhao 已提交
131 132
  }

K
Kexin Zhao 已提交
133
  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}
K
Kexin Zhao 已提交
134

K
Kexin Zhao 已提交
135 136 137
  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}
K
Kexin Zhao 已提交
138

K
Kexin Zhao 已提交
139
  HOSTDEVICE inline float16& operator=(const float16& rhs) {
K
Kexin Zhao 已提交
140 141 142 143 144
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
145
  HOSTDEVICE inline float16& operator=(const half& rhs) {
K
Kexin Zhao 已提交
146 147 148 149 150 151 152 153 154
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(&rhs)->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

K
Kexin Zhao 已提交
155
  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
K
Kexin Zhao 已提交
156 157 158 159
    x = rhs.x;
    return *this;
  }

K
Kexin Zhao 已提交
160 161 162
#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
K
Kexin Zhao 已提交
163 164 165 166
    return *this;
  }
#endif

K
Kexin Zhao 已提交
167
  HOSTDEVICE inline float16& operator=(bool b) {
K
Kexin Zhao 已提交
168 169 170 171
    x = b ? 0x3c00 : 0;
    return *this;
  }

K
Kexin Zhao 已提交
172 173
  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
174
    return *this;
K
Kexin Zhao 已提交
175 176
  }

K
Kexin Zhao 已提交
177 178
  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
179 180 181
    return *this;
  }

K
Kexin Zhao 已提交
182 183
  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
184 185 186
    return *this;
  }

K
Kexin Zhao 已提交
187 188
  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
189 190 191
    return *this;
  }

K
Kexin Zhao 已提交
192 193
  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
194 195 196
    return *this;
  }

K
Kexin Zhao 已提交
197 198
  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
199 200 201
    return *this;
  }

K
Kexin Zhao 已提交
202 203
  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
204 205 206
    return *this;
  }

K
Kexin Zhao 已提交
207 208
  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
209 210 211
    return *this;
  }

K
Kexin Zhao 已提交
212 213
  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
214 215 216
    return *this;
  }

K
Kexin Zhao 已提交
217 218
  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
K
Kexin Zhao 已提交
219
    return *this;
K
Kexin Zhao 已提交
220
  }
K
Kexin Zhao 已提交
221 222

#ifdef PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
223
  HOSTDEVICE inline explicit operator half() const {
K
Kexin Zhao 已提交
224 225 226 227 228 229 230 231 232 233 234
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
235

K
Kexin Zhao 已提交
236
  HOSTDEVICE inline explicit operator Eigen::half() const {
K
Kexin Zhao 已提交
237 238 239 240 241
    Eigen::half h;
    h.x = x;
    return h;
  }

K
Kexin Zhao 已提交
242 243 244
#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
K
Kexin Zhao 已提交
245 246 247
  }
#endif

K
Kexin Zhao 已提交
248 249 250 251 252 253 254 255
  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_NEON)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);
K
Kexin Zhao 已提交
256

K
Kexin Zhao 已提交
257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v;
    v.ui = this->x;
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
K
Kexin Zhao 已提交
280 281
  }

K
Kexin Zhao 已提交
282 283 284 285
  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(float(*this));
K
Kexin Zhao 已提交
286 287
  }

K
Kexin Zhao 已提交
288 289
  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(float(*this));
K
Kexin Zhao 已提交
290 291
  }

K
Kexin Zhao 已提交
292 293
  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(float(*this));
K
Kexin Zhao 已提交
294 295
  }

K
Kexin Zhao 已提交
296 297
  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(float(*this));
K
Kexin Zhao 已提交
298 299
  }

K
Kexin Zhao 已提交
300 301
  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(float(*this));
K
Kexin Zhao 已提交
302 303
  }

K
Kexin Zhao 已提交
304 305
  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(float(*this));
K
Kexin Zhao 已提交
306 307
  }

K
Kexin Zhao 已提交
308 309
  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(float(*this));
K
Kexin Zhao 已提交
310 311
  }

K
Kexin Zhao 已提交
312 313
  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(float(*this));
K
Kexin Zhao 已提交
314 315
  }

K
Kexin Zhao 已提交
316 317
  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(float(*this));
K
Kexin Zhao 已提交
318
  }
K
Kexin Zhao 已提交
319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348

private:
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
K
Kexin Zhao 已提交
349 350
};

// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
//
// This branch is selected only while compiling device code
// (__CUDA_ARCH__ defined) for compute capability 5.3+ with CUDA < 9.0.
// Division has no half intrinsic here, so it is computed in float.
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && \
    __CUDA_ARCH__ >= 530 && CUDA_VERSION < 9000
DEVICE inline half operator+(const half& a, const half& b) {
  return __hadd(a, b);
}

DEVICE inline half operator-(const half& a, const half& b) {
  return __hsub(a, b);
}

DEVICE inline half operator*(const half& a, const half& b) {
  return __hmul(a, b);
}

DEVICE inline half operator/(const half& a, const half& b) {
  float num = __half2float(a);
  float denom = __half2float(b);
  return __float2half(num / denom);
}

DEVICE inline half operator-(const half& a) { return __hneg(a); }

DEVICE inline half& operator+=(half& a, const half& b) {
  a = a + b;
  return a;
}

DEVICE inline half& operator-=(half& a, const half& b) {
  a = a - b;
  return a;
}

DEVICE inline half& operator*=(half& a, const half& b) {
  a = a * b;
  return a;
}

DEVICE inline half& operator/=(half& a, const half& b) {
  a = a / b;
  return a;
}

DEVICE inline bool operator==(const half& a, const half& b) {
  return __heq(a, b);
}

DEVICE inline bool operator!=(const half& a, const half& b) {
  return __hne(a, b);
}

DEVICE inline bool operator<(const half& a, const half& b) {
  return __hlt(a, b);
}

DEVICE inline bool operator<=(const half& a, const half& b) {
  return __hle(a, b);
}

DEVICE inline bool operator>(const half& a, const half& b) {
  return __hgt(a, b);
}

DEVICE inline bool operator>=(const half& a, const half& b) {
  return __hge(a, b);
}

// float16 operators for the same configuration. The software-emulated
// float16 operators in this header sit in the final #else branch, which is
// not compiled when this branch is selected. Without the definitions below,
// device code in this configuration would have no float16 arithmetic or
// comparison operators at all. Each operator reinterprets its operands as
// half (explicit float16 <-> half conversions) and uses the half intrinsics.
DEVICE inline float16 operator+(const float16& a, const float16& b) {
  return float16(__hadd(half(a), half(b)));
}

DEVICE inline float16 operator-(const float16& a, const float16& b) {
  return float16(__hsub(half(a), half(b)));
}

DEVICE inline float16 operator*(const float16& a, const float16& b) {
  return float16(__hmul(half(a), half(b)));
}

DEVICE inline float16 operator/(const float16& a, const float16& b) {
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
}

DEVICE inline float16 operator-(const float16& a) {
  return float16(__hneg(half(a)));
}

DEVICE inline float16& operator+=(float16& a, const float16& b) {
  a = a + b;
  return a;
}

DEVICE inline float16& operator-=(float16& a, const float16& b) {
  a = a - b;
  return a;
}

DEVICE inline float16& operator*=(float16& a, const float16& b) {
  a = a * b;
  return a;
}

DEVICE inline float16& operator/=(float16& a, const float16& b) {
  a = a / b;
  return a;
}

DEVICE inline bool operator==(const float16& a, const float16& b) {
  return __heq(half(a), half(b));
}

DEVICE inline bool operator!=(const float16& a, const float16& b) {
  return __hne(half(a), half(b));
}

DEVICE inline bool operator<(const float16& a, const float16& b) {
  return __hlt(half(a), half(b));
}

DEVICE inline bool operator<=(const float16& a, const float16& b) {
  return __hle(half(a), half(b));
}

DEVICE inline bool operator>(const float16& a, const float16& b) {
  return __hgt(half(a), half(b));
}

DEVICE inline bool operator>=(const float16& a, const float16& b) {
  return __hge(half(a), half(b));
}

// Arithmetic operators on ARMv8.2-A CPU
// Each operator loads the 16-bit operands into lane 0 of NEON registers
// v0/v1, executes a scalar half-precision instruction on h0/h1, and stores
// lane 0 of v0 back to memory. The asm blocks have no output operands; the
// "memory" clobber plus `volatile` keep the compiler from dropping or
// reordering the store into `res`.
#elif defined(PADDLE_WITH_NATIVE_FP16)
HOST inline float16 operator+(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

HOST inline float16 operator-(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

HOST inline float16 operator*(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

HOST inline float16 operator/(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}

HOST inline float16 operator-(const float16& a) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0");
  return res;
}

HOST inline float16& operator+=(float16& a, const float16& b) {
  a = a + b;
  return a;
}

HOST inline float16& operator-=(float16& a, const float16& b) {
  a = a - b;
  return a;
}

HOST inline float16& operator*=(float16& a, const float16& b) {
  a = a * b;
  return a;
}

HOST inline float16& operator/=(float16& a, const float16& b) {
  a = a / b;
  return a;
}

// Comparisons: fcm* writes all-ones to h0 when the predicate holds and zero
// otherwise; that 16-bit mask is stored to `res` and tested for non-zero.
HOST inline bool operator==(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

HOST inline bool operator!=(const float16& a, const float16& b) {
  return !(a == b);
}

HOST inline bool operator<(const float16& a, const float16& b) {
  uint16_t res;
  // Operands are loaded swapped (b -> v0, a -> v1), so fcmgt evaluates
  // b > a, which is a < b.
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

HOST inline bool operator<=(const float16& a, const float16& b) {
  uint16_t res;
  // Operands are loaded swapped (b -> v0, a -> v1), so fcmge evaluates
  // b >= a, which is a <= b.
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

HOST inline bool operator>(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

HOST inline bool operator>=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// NOTE: an alternative implementation of these operators using the ARMv8.2-A
// scalar FP16 intrinsics (vaddh_f16, vsubh_f16, vmulh_f16, vdivh_f16,
// vnegh_f16, vceqh_f16, vclth_f16, vcleh_f16, vcgth_f16, vcgeh_f16) was kept
// here as commented-out code. The inline assembly above is what is compiled.
// Arithmetic operators, software emulated on other CPU
// This branch is also compiled for device code that the CUDA branch above
// does not cover (CUDA >= 9.0, or compute capability below 5.3). Every
// operation widens both operands to float, computes in single precision,
// and rounds the result back to float16.
#else
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
  return float16(float(a) + float(b));
}

HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
  return float16(float(a) - float(b));
}

HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
  return float16(float(a) * float(b));
}

HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
  return float16(float(a) / float(b));
}

// Negation only flips the sign bit (bit 15), so it is exact and needs no
// round trip through float.
HOSTDEVICE inline float16 operator-(const float16& a) {
  float16 res;
  res.x = a.x ^ 0x8000;
  return res;
}

HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
  a = float16(float(a) + float(b));
  return a;
}

HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
  a = float16(float(a) - float(b));
  return a;
}

HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
  a = float16(float(a) * float(b));
  return a;
}

HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
  a = float16(float(a) / float(b));
  return a;
}

// Comparisons use float semantics: NaN compares unequal to everything and
// +0 == -0.
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
  return float(a) == float(b);
}

HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
  return float(a) != float(b);
}

HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
  return float(a) < float(b);
}

HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
  return float(a) <= float(b);
}

HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
  return float(a) > float(b);
}

HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
  return float(a) >= float(b);
}

#endif
}  // namespace paddle