float16.h 18.5 KB
Newer Older
K
Kexin Zhao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <stdint.h>
K
Kexin Zhao 已提交
18

K
Kexin Zhao 已提交
19
#ifdef PADDLE_WITH_CUDA
K
Kexin Zhao 已提交
20
#include <cuda.h>
K
Kexin Zhao 已提交
21 22
#endif  // PADDLE_WITH_CUDA

K
Kexin Zhao 已提交
23
#include "unsupported/Eigen/CXX11/Tensor"
K
Kexin Zhao 已提交
24

K
Kexin Zhao 已提交
25 26
#include "paddle/platform/hostdevice.h"

K
Kexin Zhao 已提交
27 28 29 30 31 32 33 34 35 36 37 38
#ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif  // __GNUC__

#ifdef __clang__
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif  // __clang__

K
Kexin Zhao 已提交
39
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
K
Kexin Zhao 已提交
40 41
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
K
Kexin Zhao 已提交
42 43
#endif

K
Kexin Zhao 已提交
44
#if defined(__arm__) || defined(__aarch64__)
K
Kexin Zhao 已提交
45 46 47 48 49
#define PADDLE_ARM
#endif

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
K
Kexin Zhao 已提交
50
#include <arm_neon.h>
K
Kexin Zhao 已提交
51 52
#endif

K
Kexin Zhao 已提交
53 54 55
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 62 || PADDLE_CLANG_VER >= 37)
#define PADDLE_WITH_NATIVE_FP16
K
Kexin Zhao 已提交
56 57
#endif

K
Kexin Zhao 已提交
58
#ifndef PADDLE_ARM
K
Kexin Zhao 已提交
59 60
#include <immintrin.h>
#endif  // PADDLE_ARM
K
Kexin Zhao 已提交
61

K
Kexin Zhao 已提交
62
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
K
Kexin Zhao 已提交
63 64 65

namespace paddle {

K
Kexin Zhao 已提交
66 67 68
// Use PADDLE_ALIGN(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible
K
Kexin Zhao 已提交
69
// with CUDA half, ARM float16_t, and Eigen::half data types.
K
Kexin Zhao 已提交
70
struct PADDLE_ALIGN(2) float16 {
public:
  // Raw 16-bit storage. Every constructor, assignment operator and
  // conversion operator below reads or writes this single field.
  uint16_t x;

  // Constructors

  // Default construction sets all bits to zero.
  HOSTDEVICE inline float16() : x(0) {}

  HOSTDEVICE inline float16(const float16& h) : x(h.x) {}

#ifdef PADDLE_CUDA_FP16
  // Copies the bit pattern of a CUDA half.
  HOSTDEVICE inline explicit float16(const half& h) {
#if CUDA_VERSION >= 9000
    // h is const, so the cast target must be pointer-to-const:
    // reinterpret_cast is not allowed to cast away constness.
    x = reinterpret_cast<const __half_raw*>(&h)->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  // Copies the bit pattern of an Eigen::half.
  HOSTDEVICE inline explicit float16(const Eigen::half& h) : x(h.x) {}

#ifdef PADDLE_WITH_NATIVE_FP16
  // __fp16 is a native half precision data type for arm cpu,
  // float16_t is an alias for __fp16
  HOSTDEVICE inline explicit float16(const float16_t& h) {
    x = *reinterpret_cast<const uint16_t*>(&h);
  }
#endif

  // Converts a float to half precision. The implementation is selected at
  // compile time: CUDA intrinsic, NEON intrinsics, F16C intrinsic, or a
  // portable bit-manipulation fallback.
  HOSTDEVICE inline explicit float16(float val) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = __float2half(val);
    x = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float32x4_t tmp = vld1q_dup_f32(&val);
    float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
    x = *reinterpret_cast<uint16_t*>(&res);

#elif defined(__F16C__)
    x = _cvtss_sh(val, 0);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v, s;
    v.f = val;
    // Split off the sign bit; the remaining steps work on the magnitude.
    uint32_t sign = v.si & sigN;
    v.si ^= sign;
    sign >>= shiftSign;  // logical shift
    s.si = mulN;
    s.si = s.f * v.f;  // correct subnormals
    // Each line below is a branch-free select: -(cond) is all ones when
    // cond holds and zero otherwise.
    v.si ^= (s.si ^ v.si) & -(minN > v.si);
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;

#endif
  }

  // true maps to 0x3c00 and false to 0.
  HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  // Any other type is first converted to float with static_cast and then
  // goes through the float constructor.
  template <class T>
  HOSTDEVICE inline explicit float16(const T& val)
      : x(float16(static_cast<float>(val)).x) {}

  // Assignment operators

  HOSTDEVICE inline float16& operator=(const float16& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_CUDA_FP16
  // Copies the bit pattern of a CUDA half.
  HOSTDEVICE inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    // rhs is const, so the cast target must be pointer-to-const:
    // reinterpret_cast is not allowed to cast away constness.
    x = reinterpret_cast<const __half_raw*>(&rhs)->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
    x = rhs.x;
    return *this;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
    x = *reinterpret_cast<const uint16_t*>(&rhs);
    return *this;
  }
#endif

  HOSTDEVICE inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  // The numeric assignments below all build a temporary float16 from the
  // value and copy its bits.
  HOSTDEVICE inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  HOSTDEVICE inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

  // Conversion operators

#ifdef PADDLE_CUDA_FP16
  // Produces a CUDA half carrying the same bit pattern.
  HOSTDEVICE inline explicit operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // PADDLE_CUDA_FP16

  // Produces an Eigen::half carrying the same bit pattern.
  HOSTDEVICE inline explicit operator Eigen::half() const {
    Eigen::half h;
    h.x = x;
    return h;
  }

#ifdef PADDLE_WITH_NATIVE_FP16
  HOSTDEVICE inline explicit operator float16_t() const {
    return *reinterpret_cast<const float16_t*>(this);
  }
#endif

  // Converts the half precision value to float. The implementation is
  // selected at compile time in the same order as the float constructor.
  HOSTDEVICE inline explicit operator float() const {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(PADDLE_WITH_NATIVE_FP16)
    float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
    return vgetq_lane_f32(vcvt_f32_f16(res), 0);

#elif defined(__F16C__)
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v;
    v.ui = this->x;
    // Split off the sign bit and move it to the float32 sign position.
    int32_t sign = v.si & sigC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    // mask is all ones when the value is below norC, in which case the
    // scaled value in s replaces the shifted bits.
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
  }

  // True when any bit other than the sign bit is set.
  HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  // The remaining numeric conversions go through float first.
  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(float(*this));
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(float(*this));
  }

private:
  // Gives the fallback conversion routines access to the bits of a float
  // as a signed or unsigned 32-bit integer.
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  // Constants used only by the fallback conversion routines.
  static const int shift = 13;
  static const int shiftSign = 16;

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1)
                                  << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
};

K
Kexin Zhao 已提交
353 354 355 356 357
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
358 359
#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000

K
Kexin Zhao 已提交
360
// Adds two CUDA half values. Uses the __hadd intrinsic when compiling for
// __CUDA_ARCH__ >= 530; otherwise both operands are widened to float via
// float16, added, and the sum is narrowed back to half.
DEVICE inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return static_cast<half>(float16(lhs + rhs));
#endif
}

// Subtracts b from a. Uses the __hsub intrinsic when compiling for
// __CUDA_ARCH__ >= 530; otherwise the subtraction is carried out in float
// via float16 and the difference is narrowed back to half.
DEVICE inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return static_cast<half>(float16(lhs - rhs));
#endif
}

// Multiplies two CUDA half values. Uses the __hmul intrinsic when compiling
// for __CUDA_ARCH__ >= 530; otherwise the product is computed in float via
// float16 and narrowed back to half.
DEVICE inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return static_cast<half>(float16(lhs * rhs));
#endif
}

// Divides a by b. The division is always performed in float: with
// __half2float/__float2half when compiling for __CUDA_ARCH__ >= 300, and
// through float16 conversions otherwise.
DEVICE inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
  return __float2half(__half2float(a) / __half2float(b));
#else
  const float num = static_cast<float>(float16(a));
  const float denom = static_cast<float>(float16(b));
  return static_cast<half>(float16(num / denom));
#endif
}

398 399 400 401 402 403 404 405
// Unary minus. Uses the __hneg intrinsic when compiling for
// __CUDA_ARCH__ >= 530; otherwise negates the float value obtained via
// float16 and narrows it back to half.
DEVICE inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  const float negated = -static_cast<float>(float16(a));
  return static_cast<half>(float16(negated));
#endif
}
K
Kexin Zhao 已提交
406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427

// In-place addition: stores a + b into a and returns a reference to a.
DEVICE inline half& operator+=(half& a, const half& b) {
  return a = a + b;
}

// In-place subtraction: stores a - b into a and returns a reference to a.
DEVICE inline half& operator-=(half& a, const half& b) {
  return a = a - b;
}

// In-place multiplication: stores a * b into a and returns a reference to a.
DEVICE inline half& operator*=(half& a, const half& b) {
  return a = a * b;
}

// In-place division: stores a / b into a and returns a reference to a.
DEVICE inline half& operator/=(half& a, const half& b) {
  return a = a / b;
}

// Equality of two CUDA half values. Uses the __heq intrinsic when compiling
// for __CUDA_ARCH__ >= 530; otherwise compares the float values obtained
// via float16.
DEVICE inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs == rhs;
#endif
}

// Inequality of two CUDA half values. Uses the __hne intrinsic when
// compiling for __CUDA_ARCH__ >= 530; otherwise compares the float values
// obtained via float16.
DEVICE inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs != rhs;
#endif
}

// Less-than on two CUDA half values. Uses the __hlt intrinsic when compiling
// for __CUDA_ARCH__ >= 530; otherwise compares the float values obtained
// via float16.
DEVICE inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs < rhs;
#endif
}

// Less-than-or-equal on two CUDA half values. Uses the __hle intrinsic when
// compiling for __CUDA_ARCH__ >= 530; otherwise compares the float values
// obtained via float16.
DEVICE inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs <= rhs;
#endif
}

// Greater-than on two CUDA half values. Uses the __hgt intrinsic when
// compiling for __CUDA_ARCH__ >= 530; otherwise compares the float values
// obtained via float16.
DEVICE inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs > rhs;
#endif
}

// Greater-than-or-equal on two CUDA half values. Uses the __hge intrinsic
// when compiling for __CUDA_ARCH__ >= 530; otherwise compares the float
// values obtained via float16.
DEVICE inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs >= rhs;
#endif
}

475
#endif  // PADDLE_CUDA_FP16
K
Kexin Zhao 已提交
476 477

// Arithmetic operators on ARMv8.2-A CPU
478
#if defined(PADDLE_WITH_NATIVE_FP16)
K
Kexin Zhao 已提交
479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545
// Half-precision addition via inline assembly (PADDLE_WITH_NATIVE_FP16 path).
// The 16-bit patterns a.x and b.x are loaded into lane 0 of v0 and v1,
// `fadd` is applied to the scalar h0/h1 views, and lane 0 of v0 is stored
// into res.x.
HOST inline float16 operator+(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fadd h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0", "v1");
  return res;
}

// Half-precision subtraction (a - b) via inline assembly
// (PADDLE_WITH_NATIVE_FP16 path). a.x and b.x are loaded into lane 0 of v0
// and v1, `fsub` is applied to h0/h1, and lane 0 of v0 is stored into res.x.
HOST inline float16 operator-(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fsub h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0", "v1");
  return res;
}

// Half-precision multiplication via inline assembly
// (PADDLE_WITH_NATIVE_FP16 path). a.x and b.x are loaded into lane 0 of v0
// and v1, `fmul` is applied to h0/h1, and lane 0 of v0 is stored into res.x.
HOST inline float16 operator*(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fmul h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0", "v1");
  return res;
}

// Half-precision division (a / b) via inline assembly
// (PADDLE_WITH_NATIVE_FP16 path). a.x and b.x are loaded into lane 0 of v0
// and v1, `fdiv` is applied to h0/h1, and lane 0 of v0 is stored into res.x.
HOST inline float16 operator/(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fdiv h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0", "v1");
  return res;
}
K
Kexin Zhao 已提交
546

K
Kexin Zhao 已提交
547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673
// Unary minus via inline assembly (PADDLE_WITH_NATIVE_FP16 path). a.x is
// loaded into lane 0 of v0, `fneg` is applied to h0, and lane 0 of v0 is
// stored into res.x.
HOST inline float16 operator-(const float16& a) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "fneg h0, h0\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [res_ptr] "r"(&(res.x))
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0");
  return res;
}

// In-place addition: stores a + b into a and returns a reference to a.
HOST inline float16& operator+=(float16& a, const float16& b) {
  return a = a + b;
}

// In-place subtraction: stores a - b into a and returns a reference to a.
HOST inline float16& operator-=(float16& a, const float16& b) {
  return a = a - b;
}

// In-place multiplication: stores a * b into a and returns a reference to a.
HOST inline float16& operator*=(float16& a, const float16& b) {
  return a = a * b;
}

// In-place division: stores a / b into a and returns a reference to a.
HOST inline float16& operator/=(float16& a, const float16& b) {
  return a = a / b;
}

// Equality via inline assembly (PADDLE_WITH_NATIVE_FP16 path). a.x and b.x
// are loaded into lane 0 of v0 and v1, `fcmeq` is applied to h0/h1, and the
// resulting 16-bit value is stored into res. Returns true when that value
// is non-zero.
HOST inline bool operator==(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmeq h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Inequality, defined as the logical negation of operator==.
HOST inline bool operator!=(const float16& a, const float16& b) {
  const bool equal = (a == b);
  return !equal;
}

// Less-than via inline assembly (PADDLE_WITH_NATIVE_FP16 path). Note the
// swapped registers: a.x is loaded into v1 and b.x into v0, so the `fcmgt`
// on h0, h1 has b as its first operand and a as its second. The resulting
// 16-bit value is stored into res; returns true when it is non-zero.
HOST inline bool operator<(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Less-than-or-equal via inline assembly (PADDLE_WITH_NATIVE_FP16 path).
// Note the swapped registers: a.x is loaded into v1 and b.x into v0, so the
// `fcmge` on h0, h1 has b as its first operand and a as its second. The
// resulting 16-bit value is stored into res; returns true when it is
// non-zero.
HOST inline bool operator<=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v1.h}[0], [%[a_ptr]]\n"
      "ld1 {v0.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Greater-than via inline assembly (PADDLE_WITH_NATIVE_FP16 path). a.x and
// b.x are loaded into lane 0 of v0 and v1, `fcmgt` is applied to h0/h1, and
// the resulting 16-bit value is stored into res. Returns true when that
// value is non-zero.
HOST inline bool operator>(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmgt h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Greater-than-or-equal via inline assembly (PADDLE_WITH_NATIVE_FP16 path).
// a.x and b.x are loaded into lane 0 of v0 and v1, `fcmge` is applied to
// h0/h1, and the resulting 16-bit value is stored into res. Returns true
// when that value is non-zero.
HOST inline bool operator>=(const float16& a, const float16& b) {
  uint16_t res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"
      "ld1 {v1.h}[0], [%[b_ptr]]\n"
      "fcmge h0, h0, h1\n"
      "st1 {v0.h}[0], [%[res_ptr]]\n"
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)),
      [b_ptr] "r"(&(b.x)),
      [res_ptr] "r"(&res)
      :  // clobbers
      // The result is written through res_ptr rather than an output operand,
      // hence the "memory" clobber.
      "memory", "v0", "v1");
  return (res & 0xffff) != 0;
}

// Arithmetic operators, software emulated on other CPU
#else
// Software-emulated addition: both operands are widened to float, added,
// and the sum is converted back to float16.
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
  const float sum = static_cast<float>(a) + static_cast<float>(b);
  return float16(sum);
}

K
Kexin Zhao 已提交
677
// Software-emulated subtraction: both operands are widened to float,
// subtracted, and the difference is converted back to float16.
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
  const float difference = static_cast<float>(a) - static_cast<float>(b);
  return float16(difference);
}

K
Kexin Zhao 已提交
681
// Software-emulated multiplication: both operands are widened to float,
// multiplied, and the product is converted back to float16.
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
  const float product = static_cast<float>(a) * static_cast<float>(b);
  return float16(product);
}

K
Kexin Zhao 已提交
685
// Software-emulated division: both operands are widened to float, divided,
// and the quotient is converted back to float16.
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
  const float quotient = static_cast<float>(a) / static_cast<float>(b);
  return float16(quotient);
}

K
Kexin Zhao 已提交
689
// Unary minus: returns a copy of a with bit 0x8000 (the sign bit) flipped.
HOSTDEVICE inline float16 operator-(const float16& a) {
  float16 negated(a);
  negated.x ^= 0x8000;
  return negated;
}

K
Kexin Zhao 已提交
695
// In-place addition computed in float; returns a reference to a.
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
  const float sum = static_cast<float>(a) + static_cast<float>(b);
  return a = float16(sum);
}

K
Kexin Zhao 已提交
700
// In-place subtraction computed in float; returns a reference to a.
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
  const float difference = static_cast<float>(a) - static_cast<float>(b);
  return a = float16(difference);
}

K
Kexin Zhao 已提交
705
// In-place multiplication computed in float; returns a reference to a.
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
  const float product = static_cast<float>(a) * static_cast<float>(b);
  return a = float16(product);
}

K
Kexin Zhao 已提交
710
// In-place division computed in float; returns a reference to a.
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
  const float quotient = static_cast<float>(a) / static_cast<float>(b);
  return a = float16(quotient);
}

K
Kexin Zhao 已提交
715
// Equality, evaluated on the float values of both operands.
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
}

K
Kexin Zhao 已提交
719
// Inequality, evaluated on the float values of both operands.
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
}

K
Kexin Zhao 已提交
723
// Less-than, evaluated on the float values of both operands.
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
}

K
Kexin Zhao 已提交
727
// Less-than-or-equal, evaluated on the float values of both operands.
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
}

K
Kexin Zhao 已提交
731
// Greater-than, evaluated on the float values of both operands.
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
}

K
Kexin Zhao 已提交
735
// Greater-than-or-equal, evaluated on the float values of both operands.
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
}
K
Kexin Zhao 已提交
738
#endif
K
Kexin Zhao 已提交
739
}  // namespace paddle