// float16.h
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef CINN_COMMON_FLOAT16_H
#define CINN_COMMON_FLOAT16_H

#ifdef __cplusplus
#pragma once
#endif  // __cplusplus

#if defined(_M_X64) || defined(__x86_64__) || defined(_M_IX86) || defined(__i386__)
#define __CINN_x86__
#include <immintrin.h>
#endif

#include <stdint.h>

#include <cmath>

// CUDA half-precision support detection.
// When CINN_WITH_CUDA is defined, <cuda.h> is included. If the translation
// unit is additionally being compiled with __CUDACC__ or __CUDACC_RTC__
// defined and CUDA_VERSION is at least 7050, CINN_CUDA_FP16 is defined and
// <cuda_fp16.h> is included, which makes the `half` type available.
#ifdef CINN_WITH_CUDA
#include <cuda.h>

#if (defined(__CUDACC__) || defined(__CUDACC_RTC__)) && CUDA_VERSION >= 7050
#define CINN_CUDA_FP16
#include <cuda_fp16.h>

// Evaluates to true when the given architecture number is at least 600.
#define CUDA_ARCH_FP16_SUPPORTED(CUDA_ARCH) (CUDA_ARCH >= 600)
#endif  // __CUDACC__
#endif  // CINN_WITH_CUDA

// CINN_ALIGN(x): alignment specifier placed between `struct` and the type
// name. In C++ it expands to `__attribute__((aligned(x)))`, or to
// `__declspec(align(x))` when _WIN32 is defined. When not compiling as C++
// it expands to nothing, so no alignment is requested.
#ifdef __cplusplus
#ifndef _WIN32
#define CINN_ALIGN(x) __attribute__((aligned(x)))
#else  // _WIN32
#define CINN_ALIGN(x) __declspec(align(x))
#endif  // _WIN32

#else  // __cplusplus
#define CINN_ALIGN(x)
#endif  // __cplusplus

// The `HOST` macro definition is not used here, it has a potential
// conflict with the enumeration `kHOST` representing the backend.
//
// If `__host__` / `__device__` are not already defined as macros, define
// them as empty so that the function qualifiers used throughout this header
// expand to nothing.
#ifndef __host__
#define __host__
#endif
#ifndef __device__
#define __device__
#endif

#ifdef __cplusplus
namespace cinn {
namespace common {
#endif  // __cplusplus

// Use CINN_ALIGN(2) to ensure that each float16 is allocated and
// aligned on at least a 2-byte boundary, which allows efficient
// memory access to the float16 struct and also makes float16
// layout-compatible with CUDA half.
struct CINN_ALIGN(2) float16 {
  // Raw IEEE 754 binary16 bit pattern: 1 sign bit, 5 exponent bits and
  // 10 mantissa bits.
  uint16_t x;

#ifdef __cplusplus
  // The following defaulted special class member functions
  // are added to make float16 pass the std::is_trivial test
  float16()                 = default;
  float16(const float16& o) = default;
  float16& operator=(const float16& o) = default;
  float16(float16&& o)                 = default;
  float16& operator=(float16&& o) = default;
  ~float16()                      = default;

// Constructors
#ifdef CINN_CUDA_FP16
  // Builds a float16 from a CUDA `half` by copying its 16 raw bits.
  __host__ __device__ inline explicit float16(const half& h) {
#if (CUDA_VERSION >= 9000)
    // From CUDA 9.0 on, the bits of `half` are read through __half_raw.
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
#else
    x = h.x;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // CINN_CUDA_FP16

  // Converts a single-precision float to half precision.
  // Three implementations are selected at compile time:
  //   1. device code with CINN_CUDA_FP16: the __float2half intrinsic;
  //   2. x86 with F16C: the _cvtss_sh instruction;
  //   3. otherwise: a portable bit-manipulation routine (the mantissa bits
  //      that do not fit are discarded by a right shift).
  __host__ __device__ inline explicit float16(float val) {
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
    half tmp = __float2half(val);
    x        = *reinterpret_cast<uint16_t*>(&tmp);

#elif defined(__F16C__) && defined(__CINN_x86__)
    x = _cvtss_sh(val, 0);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v, s;
    v.f           = val;
    uint32_t sign = v.si & sigN;  // isolate the float32 sign bit
    v.si ^= sign;                 // v now holds the magnitude only
    sign >>= shiftSign;           // logical shift: sign moves to bit 15
    // Magnitudes below the smallest normal half are rescaled by 2^37 (mulN)
    // and truncated to an integer, which yields the subnormal mantissa once
    // shifted below. This is done only when needed: for larger magnitudes,
    // infinities and NaNs the product does not fit in int32_t and the
    // float-to-integer conversion would be undefined behavior.
    if (minN > v.si) {
      s.si = mulN;
      v.si = static_cast<int32_t>(s.f * v.f);  // correct subnormals
    }
    // Finite values larger than the largest half become infinity.
    v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
    // NaNs whose payload would be shifted out entirely are replaced by nanN
    // so that they stay NaN instead of turning into infinity.
    v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
    v.ui >>= shift;  // logical shift
    // Re-bias the exponent: first for infinity/NaN, then for normal values.
    v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
    v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
    x = v.ui | sign;

#endif
  }

  // true maps to 0x3c00 (the half encoding of 1.0), false to +0.
  __host__ __device__ inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

  // Any other type is first converted to float with static_cast and then
  // goes through the float constructor.
  template <class T>
  __host__ __device__ inline explicit float16(const T& val) : x(float16(static_cast<float>(val)).x) {}

// Assignment operators
#ifdef CINN_CUDA_FP16
  // Copies the raw bits of a CUDA `half` into this object.
  __host__ __device__ inline float16& operator=(const half& rhs) {
#if CUDA_VERSION >= 9000
    x = reinterpret_cast<__half_raw*>(const_cast<half*>(&rhs))->x;
#else
    x = rhs.x;
#endif
    return *this;
  }
#endif

  // true becomes 1.0 (0x3c00), false becomes +0.
  __host__ __device__ inline float16& operator=(bool b) {
    x = b ? 0x3c00 : 0;
    return *this;
  }

  // The arithmetic assignments below all build a temporary float16 from the
  // value (through the converting constructors) and copy its bits.
  __host__ __device__ inline float16& operator=(int8_t val) {
    x = float16(val).x;
    return *this;
  }

  __host__ __device__ inline float16& operator=(uint8_t val) {
    x = float16(val).x;
    return *this;
  }

  __host__ __device__ inline float16& operator=(int16_t val) {
    x = float16(val).x;
    return *this;
  }

  __host__ __device__ inline float16& operator=(uint16_t val) {
    x = float16(val).x;
    return *this;
  }

  __host__ __device__ inline float16& operator=(int32_t val) {
    x = float16(val).x;
    return *this;
  }

  __host__ __device__ inline float16& operator=(uint32_t val) {
    x = float16(val).x;
    return *this;
  }

  __host__ __device__ inline float16& operator=(int64_t val) {
    x = float16(val).x;
    return *this;
  }

  __host__ __device__ inline float16& operator=(uint64_t val) {
    x = float16(val).x;
    return *this;
  }

  __host__ __device__ inline float16& operator=(float val) {
    x = float16(val).x;
    return *this;
  }

  __host__ __device__ inline float16& operator=(double val) {
    x = float16(val).x;
    return *this;
  }

// Conversion operators
#ifdef CINN_CUDA_FP16
  // Returns a CUDA `half` carrying the same 16 raw bits.
  __host__ __device__ inline half to_half() const {
#if CUDA_VERSION >= 9000
    __half_raw h;
    h.x = x;
    return half(h);
#else
    half h;
    h.x = x;
    return h;
#endif  // CUDA_VERSION >= 9000
  }
#endif  // CINN_CUDA_FP16

  // Converts to single precision (implicit). The implementation is chosen
  // with the same conditions as the float constructor: the __half2float
  // intrinsic in device code, _cvtsh_ss on x86 with F16C, and a portable
  // bit-manipulation routine otherwise.
  __host__ __device__ inline operator float() const {
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
    half tmp = *reinterpret_cast<const half*>(this);
    return __half2float(tmp);

#elif defined(__F16C__) && defined(__CINN_x86__)
    // Same guard as the float constructor: <immintrin.h> is only included
    // when __CINN_x86__ is defined.
    return _cvtsh_ss(this->x);

#else
    // Conversion routine adapted from
    // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
    Bits v;
    v.ui         = this->x;
    int32_t sign = v.si & sigC;  // isolate the half sign bit (bit 15)
    v.si ^= sign;                // v now holds the magnitude only
    sign <<= shiftSign;          // sign moves to bit 31
    // Re-bias the exponent: first for normal values, then for infinity/NaN.
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    // For subnormal halves the value is mantissa * 2^-24 (mulC), computed
    // with a float multiplication; `mask` selects that result below.
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;

#endif
  }

  // True for every value except +0 and -0 (the sign bit is ignored).
  __host__ __device__ inline explicit operator bool() const { return (x & 0x7fff) != 0; }

  // The integer conversions below go through float and then apply the
  // corresponding static_cast.
  __host__ __device__ inline explicit operator int8_t() const { return static_cast<int8_t>(static_cast<float>(*this)); }

  __host__ __device__ inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(static_cast<float>(*this));
  }

  __host__ __device__ inline explicit operator int16_t() const {
    return static_cast<int16_t>(static_cast<float>(*this));
  }

  __host__ __device__ inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(static_cast<float>(*this));
  }

  __host__ __device__ inline explicit operator int32_t() const {
    return static_cast<int32_t>(static_cast<float>(*this));
  }

  __host__ __device__ inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(static_cast<float>(*this));
  }

  __host__ __device__ inline explicit operator int64_t() const {
    return static_cast<int64_t>(static_cast<float>(*this));
  }

  __host__ __device__ inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(static_cast<float>(*this));
  }

  // Converts to double by widening the float value (implicit).
  __host__ __device__ inline operator double() const { return static_cast<double>(static_cast<float>(*this)); }

 private:
  // View of the same 32 bits as float, signed and unsigned integer; used by
  // the portable conversion routines.
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  // Constants with suffix N are float32 bit patterns; constants with suffix
  // C are the corresponding patterns after shifting down to half layout.
  static const int shift     = 13;  // 23 float mantissa bits - 10 half mantissa bits
  static const int shiftSign = 16;  // distance between the two sign bit positions

  static const int32_t infN = 0x7F800000;
  static const int32_t maxN = 0x477FE000;  // max flt16 as flt32
  static const int32_t minN = 0x38800000;  // min flt16 normal as flt32
  static const int32_t sigN = 0x80000000;  // sign bit

  static constexpr int32_t infC = infN >> shift;
  static constexpr int32_t nanN = (infC + 1) << shift;  // minimum flt16 nan as float32
  static constexpr int32_t maxC = maxN >> shift;
  static constexpr int32_t minC = minN >> shift;
  static constexpr int32_t sigC = sigN >> shiftSign;

  static const int32_t mulN = 0x52000000;  // (1 << 23) / minN
  static const int32_t mulC = 0x33800000;  // minN / (1 << (23 - shift))
  static const int32_t subC = 0x003FF;     // max flt32 subnormal downshifted
  static const int32_t norC = 0x00400;     // min flt32 normal downshifted

  // Exponent re-bias offsets applied to infinity/NaN (maxD) and to normal
  // values (minD).
  static constexpr int32_t maxD = infC - maxC - 1;
  static constexpr int32_t minD = minC - subC - 1;
#endif  // __cplusplus
};

// Pack of eight `float` members, declared with CINN_ALIGN(32).
struct CINN_ALIGN(32) float8 {
  float x;
  float y;
  float z;
  float w;
  float v;
  float u;
  float t;
  float s;
};

// Pack of eight `float16` members, declared with CINN_ALIGN(16).
struct CINN_ALIGN(16) half8 {
  float16 x;
  float16 y;
  float16 z;
  float16 w;
  float16 v;
  float16 u;
  float16 t;
  float16 s;
};

// Pack of four `float16` members, declared with CINN_ALIGN(8).
struct CINN_ALIGN(8) half4 {
  float16 x;
  float16 y;
  float16 z;
  float16 w;
};

#ifdef __cplusplus
// Arithmetic operators on GPU
// CUDA 9.0 provides built-in arithmetic operators for half while
// CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
// for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
// CUDA 9.0 regarding the half data type.
// ROCm provides built-in arithmetic operators for half unless
// __HIP_NO_HALF_OPERATORS__ is defined.
#if defined(CINN_CUDA_FP16) && CUDA_VERSION < 9000
// Sum of two CUDA `half` values. Uses the __hadd intrinsic when compiling
// for __CUDA_ARCH__ >= 530; otherwise both operands are widened to float
// through float16 and the float sum is narrowed back.
__device__ inline half operator+(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hadd(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return float16(lhs + rhs).to_half();
#endif
}

// Difference of two CUDA `half` values. Uses the __hsub intrinsic when
// compiling for __CUDA_ARCH__ >= 530; otherwise the subtraction is carried
// out in float through float16.
__device__ inline half operator-(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hsub(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return float16(lhs - rhs).to_half();
#endif
}

// Product of two CUDA `half` values. Uses the __hmul intrinsic when
// compiling for __CUDA_ARCH__ >= 530; otherwise the multiplication is
// carried out in float through float16.
__device__ inline half operator*(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return float16(lhs * rhs).to_half();
#endif
}

// Quotient of two CUDA `half` values. Both branches divide in float: for
// __CUDA_ARCH__ >= 530 the conversions use __half2float / __float2half,
// otherwise they go through float16.
__device__ inline half operator/(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  const float quotient = __half2float(a) / __half2float(b);
  return __float2half(quotient);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return float16(lhs / rhs).to_half();
#endif
}

// Unary minus for CUDA `half`. Uses the __hneg intrinsic when compiling for
// __CUDA_ARCH__ >= 530; otherwise the value is negated as a float.
__device__ inline half operator-(const half& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hneg(a);
#else
  const float negated = -static_cast<float>(float16(a));
  return float16(negated).to_half();
#endif
}

// In-place addition for CUDA `half`: stores a + b into a and returns a.
__device__ inline half& operator+=(half& a, const half& b) {  // NOLINT
  return a = a + b;
}

// In-place subtraction for CUDA `half`: stores a - b into a and returns a.
__device__ inline half& operator-=(half& a, const half& b) {  // NOLINT
  return a = a - b;
}

// In-place multiplication for CUDA `half`: stores a * b into a and returns a.
__device__ inline half& operator*=(half& a, const half& b) {  // NOLINT
  return a = a * b;
}

// In-place division for CUDA `half`: stores a / b into a and returns a.
__device__ inline half& operator/=(half& a, const half& b) {  // NOLINT
  return a = a / b;
}

// Equality of two CUDA `half` values. Uses __heq when compiling for
// __CUDA_ARCH__ >= 530; otherwise compares the values as floats.
__device__ inline bool operator==(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs == rhs;
#endif
}

// Inequality of two CUDA `half` values. Uses __hne when compiling for
// __CUDA_ARCH__ >= 530; otherwise compares the values as floats.
__device__ inline bool operator!=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs != rhs;
#endif
}

// Less-than for two CUDA `half` values. Uses __hlt when compiling for
// __CUDA_ARCH__ >= 530; otherwise compares the values as floats.
__device__ inline bool operator<(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs < rhs;
#endif
}

// Less-than-or-equal for two CUDA `half` values. Uses __hle when compiling
// for __CUDA_ARCH__ >= 530; otherwise compares the values as floats.
__device__ inline bool operator<=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs <= rhs;
#endif
}

// Greater-than for two CUDA `half` values. Uses __hgt when compiling for
// __CUDA_ARCH__ >= 530; otherwise compares the values as floats.
__device__ inline bool operator>(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs > rhs;
#endif
}

// Greater-than-or-equal for two CUDA `half` values. Uses __hge when
// compiling for __CUDA_ARCH__ >= 530; otherwise compares the values as
// floats.
__device__ inline bool operator>=(const half& a, const half& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a, b);
#else
  const float lhs = static_cast<float>(float16(a));
  const float rhs = static_cast<float>(float16(b));
  return lhs >= rhs;
#endif
}

#endif  // CINN_CUDA_FP16

// Arithmetic operators for float16 (usable on both host and device)
// Sum of two float16 values. In device code for __CUDA_ARCH__ >= 530 (with
// CINN_CUDA_FP16) the __hadd intrinsic is used; everywhere else the addition
// is done in float and the result converted back.
__host__ __device__ inline float16 operator+(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hadd(a.to_half(), b.to_half()));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs + rhs);
#endif
}

// Difference of two float16 values. In device code for __CUDA_ARCH__ >= 530
// (with CINN_CUDA_FP16) the __hsub intrinsic is used; everywhere else the
// subtraction is done in float and the result converted back.
__host__ __device__ inline float16 operator-(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hsub(a.to_half(), b.to_half()));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs - rhs);
#endif
}

// Product of two float16 values. In device code for __CUDA_ARCH__ >= 530
// (with CINN_CUDA_FP16) the __hmul intrinsic is used; everywhere else the
// multiplication is done in float and the result converted back.
__host__ __device__ inline float16 operator*(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hmul(a.to_half(), b.to_half()));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs * rhs);
#endif
}

// Quotient of two float16 values. Both branches divide in float; they differ
// only in how the operands are widened (__half2float in device code for
// __CUDA_ARCH__ >= 530, float16's own conversion otherwise).
__host__ __device__ inline float16 operator/(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // TODO(kexinzhao): check which cuda version starts to support __hdiv
  const float quotient = __half2float(a.to_half()) / __half2float(b.to_half());
  return float16(quotient);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return float16(lhs / rhs);
#endif
}

// Unary minus for float16. In device code for __CUDA_ARCH__ >= 530 (with
// CINN_CUDA_FP16) the __hneg intrinsic is used; otherwise the sign bit
// (0x8000) of the raw representation is flipped.
__host__ __device__ inline float16 operator-(const float16& a) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return float16(__hneg(a.to_half()));
#else
  float16 negated = a;
  negated.x ^= 0x8000;
  return negated;
#endif
}

// In-place addition for float16: stores a + b into a and returns a.
__host__ __device__ inline float16& operator+=(float16& a, const float16& b) {  // NOLINT
  return a = a + b;
}

// In-place subtraction for float16: stores a - b into a and returns a.
__host__ __device__ inline float16& operator-=(float16& a, const float16& b) {  // NOLINT
  return a = a - b;
}

// In-place multiplication for float16: stores a * b into a and returns a.
__host__ __device__ inline float16& operator*=(float16& a, const float16& b) {  // NOLINT
  return a = a * b;
}

// In-place division for float16: stores a / b into a and returns a.
__host__ __device__ inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
  return a = a / b;
}

// Equality of two float16 values. Uses __heq in device code for
// __CUDA_ARCH__ >= 530 (with CINN_CUDA_FP16); otherwise compares the values
// as floats.
__host__ __device__ inline bool operator==(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __heq(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
#endif
}

// Inequality of two float16 values. Uses __hne in device code for
// __CUDA_ARCH__ >= 530 (with CINN_CUDA_FP16); otherwise compares the values
// as floats.
__host__ __device__ inline bool operator!=(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hne(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
#endif
}

// Less-than for two float16 values. Uses __hlt in device code for
// __CUDA_ARCH__ >= 530 (with CINN_CUDA_FP16); otherwise compares the values
// as floats.
__host__ __device__ inline bool operator<(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hlt(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
#endif
}

// Less-than-or-equal for two float16 values. Uses __hle in device code for
// __CUDA_ARCH__ >= 530 (with CINN_CUDA_FP16); otherwise compares the values
// as floats.
__host__ __device__ inline bool operator<=(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hle(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
#endif
}

// Greater-than for two float16 values. Uses __hgt in device code for
// __CUDA_ARCH__ >= 530 (with CINN_CUDA_FP16); otherwise compares the values
// as floats.
__host__ __device__ inline bool operator>(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hgt(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
#endif
}

// Greater-than-or-equal for two float16 values. Uses __hge in device code
// for __CUDA_ARCH__ >= 530 (with CINN_CUDA_FP16); otherwise compares the
// values as floats.
__host__ __device__ inline bool operator>=(const float16& a, const float16& b) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hge(a.to_half(), b.to_half());
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
#endif
}
#endif  // __cplusplus

// Reinterprets a 16-bit pattern as a float16: the argument is stored
// verbatim in the `x` member, no numeric conversion takes place.
__host__ __device__ inline float16 raw_uint16_to_float16(uint16_t a) {
  float16 reinterpreted;
  reinterpreted.x = a;
  return reinterpreted;
}

// Reports whether `a` is a NaN. Uses __hisnan in device code for
// __CUDA_ARCH__ >= 530 (with CINN_CUDA_FP16); otherwise checks that the
// bits below the sign bit are greater than 0x7c00.
__host__ __device__ inline bool(isnan)(const float16& a) {
#if defined(CINN_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hisnan(a.to_half());
#else
  const int magnitude_bits = a.x & 0x7fff;  // drop the sign bit
  return magnitude_bits > 0x7c00;
#endif
}

// Reports whether `a` is positive or negative infinity, i.e. whether the
// bits below the sign bit are exactly 0x7c00.
__host__ __device__ inline bool(isinf)(const float16& a) {
  const int magnitude_bits = a.x & 0x7fff;  // drop the sign bit
  return magnitude_bits == 0x7c00;
}

// Reports whether `a` is neither a NaN nor an infinity.
__host__ __device__ inline bool(isfinite)(const float16& a) {
  if ((isnan)(a)) {
    return false;
  }
  return !((isinf)(a));
}

// Absolute value of a float16. Uses __habs in device code for
// __CUDA_ARCH__ >= 530 (with CINN_CUDA_FP16); otherwise the value is widened
// to float, passed through fabsf and converted back.
__host__ __device__ inline float16(abs)(const float16& a) {
#if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
  return float16(__habs(a.to_half()));
#else
  const float magnitude = fabsf(float(a));
  return float16(magnitude);
#endif
}

// Natural logarithm of a float16, evaluated with std::log on the float
// value and converted back to float16.
__host__ __device__ inline float16(log)(const float16& a) {
  const float widened = static_cast<float>(a);
  return float16(std::log(widened));
}

#ifdef __cplusplus
}  // namespace common
}  // namespace cinn
#endif  // __cplusplus

#if defined(__cplusplus) && defined(CINN_CUDA_FP16)
// float16 overload of __shfl_sync: the value is converted with to_half(),
// forwarded together with the remaining arguments to __shfl_sync, and the
// result is wrapped in a float16 again.
__device__ inline cinn::common::float16 __shfl_sync(unsigned mask,
                                                    cinn::common::float16 var,
                                                    int srcLane,
                                                    int width = warpSize) {
  const half shuffled = __shfl_sync(mask, var.to_half(), srcLane, width);
  return cinn::common::float16(shuffled);
}

// float16 overload of __shfl_up_sync: the value is converted with to_half(),
// forwarded together with the remaining arguments to __shfl_up_sync, and the
// result is wrapped in a float16 again.
__device__ inline cinn::common::float16 __shfl_up_sync(unsigned mask,
                                                       cinn::common::float16 var,
                                                       unsigned int delta,
                                                       int width = warpSize) {
  const half shuffled = __shfl_up_sync(mask, var.to_half(), delta, width);
  return cinn::common::float16(shuffled);
}

// float16 overload of __shfl_down_sync: the value is converted with
// to_half(), forwarded together with the remaining arguments to
// __shfl_down_sync, and the result is wrapped in a float16 again.
__device__ inline cinn::common::float16 __shfl_down_sync(unsigned mask,
                                                         cinn::common::float16 var,
                                                         unsigned int delta,
                                                         int width = warpSize) {
  const half shuffled = __shfl_down_sync(mask, var.to_half(), delta, width);
  return cinn::common::float16(shuffled);
}

// float16 overload of __shfl_xor_sync: the value is converted with
// to_half(), forwarded together with the remaining arguments to
// __shfl_xor_sync, and the result is wrapped in a float16 again.
__device__ inline cinn::common::float16 __shfl_xor_sync(unsigned mask,
                                                        cinn::common::float16 var,
                                                        int laneMask,
                                                        int width = warpSize) {
  const half shuffled = __shfl_xor_sync(mask, var.to_half(), laneMask, width);
  return cinn::common::float16(shuffled);
}

// Returns `a` when a > b holds and `b` otherwise (so `b` is returned when
// the comparison is false for any reason, including equal operands).
__host__ __device__ inline cinn::common::float16 max(const cinn::common::float16& a, const cinn::common::float16& b) {
  if (a > b) {
    return a;
  }
  return b;
}
// Returns `a` when a < b holds and `b` otherwise (so `b` is returned when
// the comparison is false for any reason, including equal operands).
__host__ __device__ inline cinn::common::float16 min(const cinn::common::float16& a, const cinn::common::float16& b) {
  if (a < b) {
    return a;
  }
  return b;
}
#endif  // __cplusplus && CINN_CUDA_FP16

#endif  // CINN_COMMON_FLOAT16_H