complex64.h 17.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <stdint.h>
#include <limits>
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) __declspec(align(x))
#endif

#ifdef PADDLE_WITH_CUDA
#include <cuComplex.h>
#include <thrust/complex.h>
#endif  // PADDLE_WITH_CUDA

30 31 32 33 34
#ifdef PADDLE_WITH_HIP
#include <hip/hip_complex.h>
#include <thrust/complex.h>  // NOLINT
#endif

35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
#include <cstring>

#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"

// Forward declaration of Eigen::NumTraits so that it can be specialized for
// complex64 further down in this header.
namespace Eigen {
template <typename T>
struct NumTraits;
}  // namespace Eigen

namespace paddle {
namespace platform {

// Single-precision complex number stored as two floats, real part first,
// aligned to 8 bytes.
//
// All special members are defaulted and there are no default member
// initializers, so a default-constructed complex64 has indeterminate
// `real`/`imag`. Use one of the value constructors when a defined value is
// needed.
struct PADDLE_ALIGN(8) complex64 {
 public:
  float real;
  float imag;

  complex64() = default;
  complex64(const complex64& o) = default;
  complex64& operator=(const complex64& o) = default;
  complex64(complex64&& o) = default;
  complex64& operator=(complex64&& o) = default;
  ~complex64() = default;

  // Constructs real + imag * i.
  HOSTDEVICE complex64(float real, float imag) : real(real), imag(imag) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // Interop with thrust::complex<float>, which the device-side operators use.
  HOSTDEVICE inline explicit complex64(const thrust::complex<float>& c) {
    real = c.real();
    imag = c.imag();
  }

  HOSTDEVICE inline explicit operator thrust::complex<float>() const {
    return thrust::complex<float>(real, imag);
  }

#ifdef PADDLE_WITH_HIP
  HOSTDEVICE inline explicit operator hipFloatComplex() const {
    return make_hipFloatComplex(real, imag);
  }
#else
  HOSTDEVICE inline explicit operator cuFloatComplex() const {
    return make_cuFloatComplex(real, imag);
  }
#endif
#endif

  // Implicit construction from real scalars; the imaginary part is zero.
  HOSTDEVICE complex64(const float& val) : real(val), imag(0) {}
  HOSTDEVICE complex64(const double& val)
      : real(static_cast<float>(val)), imag(0) {}
  HOSTDEVICE complex64(const int& val)
      : real(static_cast<float>(val)), imag(0) {}
  HOSTDEVICE complex64(const int64_t& val)
      : real(static_cast<float>(val)), imag(0) {}
  // Narrowing conversion from double-precision complex (both parts).
  HOSTDEVICE complex64(const complex128& val)
      : real(static_cast<float>(val.real)),
        imag(static_cast<float>(val.imag)) {}

  // NOTE(review): this operator is not const-qualified, so it is not viable
  // on a const complex64. For a const `c`, `std::complex<float>(c)` falls back
  // to the implicit `operator float()` below and yields (c.real, 0). Build
  // `std::complex<float>(c.real, c.imag)` explicitly when both parts matter.
  // It is left non-const because a const overload may make such call sites
  // ambiguous with `operator float()` — confirm before changing.
  HOSTDEVICE inline explicit operator std::complex<float>() {
    return std::complex<float>(real, imag);
  }

  // Explicit construction from any other type that static_casts to float
  // (e.g. bool or 8/16-bit integers). The imaginary part is set to zero, so
  // both members are always initialized.
  template <class T>
  HOSTDEVICE inline explicit complex64(const T& val)
      : real(static_cast<float>(val)), imag(0) {}

  HOSTDEVICE complex64(const std::complex<float> val)
      : real(val.real()), imag(val.imag()) {}

  // Assignment from real scalars: `real` receives the converted value and
  // `imag` is cleared.
  HOSTDEVICE inline complex64& operator=(bool b) {
    real = b ? 1 : 0;
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(int8_t val) {
    real = static_cast<float>(val);
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(uint8_t val) {
    real = static_cast<float>(val);
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(int16_t val) {
    real = static_cast<float>(val);
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(uint16_t val) {
    real = static_cast<float>(val);
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(int32_t val) {
    real = static_cast<float>(val);
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(uint32_t val) {
    real = static_cast<float>(val);
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(int64_t val) {
    real = static_cast<float>(val);
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(uint64_t val) {
    real = static_cast<float>(val);
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(float val) {
    real = val;
    imag = 0;
    return *this;
  }

  HOSTDEVICE inline complex64& operator=(double val) {
    real = static_cast<float>(val);
    imag = 0;
    return *this;
  }

  // Implicit conversion to float keeps only the real part.
  HOSTDEVICE inline operator float() const { return this->real; }

  // True when either component is nonzero.
  HOSTDEVICE inline explicit operator bool() const {
    return static_cast<bool>(this->real) || static_cast<bool>(this->imag);
  }

  // Explicit conversions to other real types convert the real part and
  // discard the imaginary part.
  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(this->real);
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(this->real);
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(this->real);
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(this->real);
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(this->real);
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(this->real);
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(this->real);
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(this->real);
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(this->real);
  }

  // Widening conversion to double-precision complex (both parts).
  HOSTDEVICE inline operator complex128() const {
    return complex128(static_cast<double>(this->real),
                      static_cast<double>(this->imag));
  }
};

// Complex addition. Device code (CUDA/HIP) delegates to
// thrust::complex<float>; host code adds the components directly.
HOSTDEVICE inline complex64 operator+(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  const thrust::complex<float> lhs(a.real, a.imag);
  const thrust::complex<float> rhs(b.real, b.imag);
  return complex64(lhs + rhs);
#else
  const float sum_re = a.real + b.real;
  const float sum_im = a.imag + b.imag;
  return complex64(sum_re, sum_im);
#endif
}

// Complex subtraction, component-wise on the host.
HOSTDEVICE inline complex64 operator-(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  const thrust::complex<float> lhs(a.real, a.imag);
  const thrust::complex<float> rhs(b.real, b.imag);
  return complex64(lhs - rhs);
#else
  const float diff_re = a.real - b.real;
  const float diff_im = a.imag - b.imag;
  return complex64(diff_re, diff_im);
#endif
}

// Complex multiplication:
// (ar + ai*i)(br + bi*i) = (ar*br - ai*bi) + (ai*br + bi*ar)*i.
HOSTDEVICE inline complex64 operator*(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  const thrust::complex<float> lhs(a.real, a.imag);
  const thrust::complex<float> rhs(b.real, b.imag);
  return complex64(lhs * rhs);
#else
  const float prod_re = a.real * b.real - a.imag * b.imag;
  const float prod_im = a.imag * b.real + b.imag * a.real;
  return complex64(prod_re, prod_im);
#endif
}

// Complex division. The host path uses the textbook formula with |b|^2 as
// the denominator, with no special handling for b == 0 or overflow.
HOSTDEVICE inline complex64 operator/(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  const thrust::complex<float> lhs(a.real, a.imag);
  const thrust::complex<float> rhs(b.real, b.imag);
  return complex64(lhs / rhs);
#else
  const float denominator = b.real * b.real + b.imag * b.imag;
  const float quot_re = (a.real * b.real + a.imag * b.imag) / denominator;
  const float quot_im = (a.imag * b.real - a.real * b.imag) / denominator;
  return complex64(quot_re, quot_im);
#endif
}

// Unary negation of both components.
HOSTDEVICE inline complex64 operator-(const complex64& a) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  const thrust::complex<float> z(a.real, a.imag);
  return complex64(-z);
#else
  return complex64(-a.real, -a.imag);
#endif
}

// In-place addition: a = a + b; returns a.
HOSTDEVICE inline complex64& operator+=(complex64& a,  // NOLINT
                                        const complex64& b) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  thrust::complex<float> acc(a.real, a.imag);
  acc += thrust::complex<float>(b.real, b.imag);
  a = complex64(acc);
#else
  a.real += b.real;
  a.imag += b.imag;
#endif
  return a;
}

// In-place subtraction: a = a - b; returns a.
HOSTDEVICE inline complex64& operator-=(complex64& a,  // NOLINT
                                        const complex64& b) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  thrust::complex<float> acc(a.real, a.imag);
  acc -= thrust::complex<float>(b.real, b.imag);
  a = complex64(acc);
#else
  a.real -= b.real;
  a.imag -= b.imag;
#endif
  return a;
}

// In-place complex multiplication: a = a * b; returns a.
//
// On the host, both result components are computed from the original
// operands before anything is written back. Assigning `a.real` first and then
// reading it to compute `a.imag` would use the already-updated real part and
// produce a wrong imaginary part. Staging the results also keeps `x *= x`
// correct when `b` aliases `a`.
HOSTDEVICE inline complex64& operator*=(complex64& a,  // NOLINT
                                        const complex64& b) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  a = complex64(thrust::complex<float>(a.real, a.imag) *=
                thrust::complex<float>(b.real, b.imag));
  return a;
#else
  const float re = a.real * b.real - a.imag * b.imag;
  const float im = a.imag * b.real + b.imag * a.real;
  a.real = re;
  a.imag = im;
  return a;
#endif
}

// In-place complex division: a = a / b; returns a.
//
// The host path uses the textbook formula
//   ((ar*br + ai*bi) + (ai*br - ar*bi)*i) / (br*br + bi*bi)
// with no special handling for b == 0 or overflow of the denominator. As with
// operator*=, both components are computed from the original operands before
// writing back, which also makes `x /= x` safe when `b` aliases `a`.
HOSTDEVICE inline complex64& operator/=(complex64& a,  // NOLINT
                                        const complex64& b) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  a = complex64(thrust::complex<float>(a.real, a.imag) /=
                thrust::complex<float>(b.real, b.imag));
  return a;
#else
  const float denominator = b.real * b.real + b.imag * b.imag;
  const float re = (a.real * b.real + a.imag * b.imag) / denominator;
  const float im = (a.imag * b.real - a.real * b.imag) / denominator;
  a.real = re;
  a.imag = im;
  return a;
#endif
}

// Returns a complex64 whose real part is the numeric value of `a` and whose
// imaginary part is zero. The original left the imaginary part uninitialized.
//
// NOTE(review): despite the "raw" in the name, no bit reinterpretation
// happens; `a` is converted to float by value.
HOSTDEVICE inline complex64 raw_uint16_to_complex64(uint16_t a) {
  return complex64(static_cast<float>(a), 0.0f);
}

// Exact component-wise equality; a NaN in either component makes the values
// compare unequal.
HOSTDEVICE inline bool operator==(const complex64& a, const complex64& b) {
  return (a.real == b.real) && (a.imag == b.imag);
}

// Logical negation of operator==.
HOSTDEVICE inline bool operator!=(const complex64& a, const complex64& b) {
  return !(a == b);
}

// The ordering operators compare real parts only and ignore the imaginary
// parts. This is not a mathematical ordering of complex numbers.
HOSTDEVICE inline bool operator<(const complex64& a, const complex64& b) {
  return a.real < b.real;
}

HOSTDEVICE inline bool operator<=(const complex64& a, const complex64& b) {
  return a.real <= b.real;
}

HOSTDEVICE inline bool operator>(const complex64& a, const complex64& b) {
  return b.real < a.real;
}

HOSTDEVICE inline bool operator>=(const complex64& a, const complex64& b) {
  return b.real <= a.real;
}

// True if either component is NaN.
HOSTDEVICE inline bool(isnan)(const complex64& a) {
#if defined(__CUDA_ARCH__)
  // CUDA device intrinsic. __isnanf is not supported on HIP, so this branch
  // is keyed on __CUDA_ARCH__ only.
  return __isnanf(a.real) || __isnanf(a.imag);
#else
  return std::isnan(a.real) || std::isnan(a.imag);
#endif
}

// True if either component is +/-infinity.
HOSTDEVICE inline bool(isinf)(const complex64& a) {
#if defined(__CUDA_ARCH__)
  // CUDA device intrinsic. __isinff is not supported on HIP.
  return __isinff(a.real) || __isinff(a.imag);
#else
  return std::isinf(a.real) || std::isinf(a.imag);
#endif
}

// True when no component is NaN or infinite.
HOSTDEVICE inline bool(isfinite)(const complex64& a) {
  return !(isnan)(a) && !(isinf)(a);
}

// Magnitude (modulus) of `a`, computed by thrust::abs on the device and by
// std::abs(std::complex<float>) on the host.
HOSTDEVICE inline float(abs)(const complex64& a) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  const thrust::complex<float> z(a.real, a.imag);
  return thrust::abs(z);
#else
  const std::complex<float> z(a.real, a.imag);
  return std::abs(z);
#endif
}

// Complex elementary functions. Device code (CUDA/HIP) uses thrust; host code
// uses the <complex> overloads.
//
// On the host, the std::complex argument is built explicitly from both
// components. `std::complex<float>(a)` on a `const complex64&` cannot use
// complex64's non-const explicit conversion operator. It instead converts
// through the implicit `operator float()`, which silently drops the
// imaginary part.

// a raised to the power b.
HOSTDEVICE inline complex64(pow)(const complex64& a, const complex64& b) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  return complex64(thrust::pow(thrust::complex<float>(a.real, a.imag),
                               thrust::complex<float>(b.real, b.imag)));
#else
  return std::pow(std::complex<float>(a.real, a.imag),
                  std::complex<float>(b.real, b.imag));
#endif
}

// Complex square root.
HOSTDEVICE inline complex64(sqrt)(const complex64& a) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  return complex64(thrust::sqrt(thrust::complex<float>(a.real, a.imag)));
#else
  return std::sqrt(std::complex<float>(a.real, a.imag));
#endif
}

// Complex hyperbolic tangent.
HOSTDEVICE inline complex64(tanh)(const complex64& a) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  return complex64(thrust::tanh(thrust::complex<float>(a.real, a.imag)));
#else
  return std::tanh(std::complex<float>(a.real, a.imag));
#endif
}

// Complex natural logarithm.
HOSTDEVICE inline complex64(log)(const complex64& a) {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  return complex64(thrust::log(thrust::complex<float>(a.real, a.imag)));
#else
  return std::log(std::complex<float>(a.real, a.imag));
#endif
}

// Writes `a` as "real:<re> imag:<im>" and returns the stream.
inline std::ostream& operator<<(std::ostream& os, const complex64& a) {
  return os << "real:" << a.real << " imag:" << a.imag;
}

}  // namespace platform
}  // namespace paddle

namespace std {

// NOTE(review): the C++ standard ([meta.rqmts]) makes adding specializations
// of <type_traits> templates undefined behavior. These are kept because
// existing callers may depend on them.

// complex64 counts as POD exactly when it is trivial and standard-layout.
template <>
struct is_pod<paddle::platform::complex64> {
  static const bool value =
      is_trivial<paddle::platform::complex64>::value &&
      is_standard_layout<paddle::platform::complex64>::value;
};

// complex64 is reported as a floating-point type.
template <>
struct is_floating_point<paddle::platform::complex64> : std::true_type {};

// complex64 is reported as neither signed nor unsigned.
template <>
struct is_signed<paddle::platform::complex64> {
  static const bool value = false;
};

template <>
struct is_unsigned<paddle::platform::complex64> {
  static const bool value = false;
};

// std:: overloads that forward to the paddle::platform implementations, so
// that `std::isnan(x)` / `std::isinf(x)` accept a complex64.
inline bool isnan(const paddle::platform::complex64& a) {
  return (paddle::platform::isnan)(a);
}

inline bool isinf(const paddle::platform::complex64& a) {
  return (paddle::platform::isinf)(a);
}

// numeric_limits for complex64. is_specialized is false: every boolean
// property is false, every integer property is 0, round_style is
// round_toward_zero, and every value-returning member yields 0 + 0i.
template <>
struct numeric_limits<paddle::platform::complex64> {
  // Boolean properties.
  static const bool is_specialized = false;
  static const bool is_signed = false;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = false;
  static const bool has_quiet_NaN = false;
  static const bool has_signaling_NaN = false;
  static const float_denorm_style has_denorm = denorm_absent;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_toward_zero;
  static const bool is_iec559 = false;
  static const bool is_bounded = false;
  static const bool is_modulo = false;

  // Integer properties.
  static const int digits = 0;
  static const int digits10 = 0;
  static const int max_digits10 = 0;
  static const int radix = 0;
  static const int min_exponent = 0;
  static const int min_exponent10 = 0;
  static const int max_exponent = 0;
  static const int max_exponent10 = 0;
  static const bool traps = false;
  static const bool tinyness_before = false;

  // Value members: all return zero.
  static paddle::platform::complex64(min)() {
    return paddle::platform::complex64(0.0f, 0.0f);
  }
  static paddle::platform::complex64 lowest() {
    return paddle::platform::complex64(0.0f, 0.0f);
  }
  static paddle::platform::complex64(max)() {
    return paddle::platform::complex64(0.0f, 0.0f);
  }
  static paddle::platform::complex64 epsilon() {
    return paddle::platform::complex64(0.0f, 0.0f);
  }
  static paddle::platform::complex64 round_error() {
    return paddle::platform::complex64(0.0f, 0.0f);
  }
  static paddle::platform::complex64 infinity() {
    return paddle::platform::complex64(0.0f, 0.0f);
  }
  static paddle::platform::complex64 quiet_NaN() {
    return paddle::platform::complex64(0.0f, 0.0f);
  }
  static paddle::platform::complex64 signaling_NaN() {
    return paddle::platform::complex64(0.0f, 0.0f);
  }
  static paddle::platform::complex64 denorm_min() {
    return paddle::platform::complex64(0.0f, 0.0f);
  }
};

}  // namespace std
namespace Eigen {

using complex64 = paddle::platform::complex64;

// Eigen scalar traits for complex64. It inherits the generic traits of
// std::complex<float>, uses float as its real type, and scales read/add/mul
// costs for two float components.
template <>
struct NumTraits<complex64> : GenericNumTraits<std::complex<float>> {
  using Real = float;
  using Literal = typename NumTraits<float>::Literal;
  enum {
    IsComplex = 1,
    RequireInitialization = NumTraits<float>::RequireInitialization,
    ReadCost = 2 * NumTraits<float>::ReadCost,
    AddCost = 2 * NumTraits<Real>::AddCost,
    MulCost = 4 * NumTraits<Real>::MulCost + 2 * NumTraits<Real>::AddCost
  };

  // Precision queries delegate to the underlying real type.
  EIGEN_DEVICE_FUNC
  static inline Real epsilon() { return NumTraits<Real>::epsilon(); }
  EIGEN_DEVICE_FUNC
  static inline Real dummy_precision() {
    return NumTraits<Real>::dummy_precision();
  }
  EIGEN_DEVICE_FUNC
  static inline int digits10() { return NumTraits<Real>::digits10(); }
};

namespace numext {

// Eigen::numext specializations for complex64. Most forward to the
// paddle::platform implementations; exp is computed inline.
template <>
HOSTDEVICE inline bool(isnan)(const complex64& a) {
  return (paddle::platform::isnan)(a);
}

template <>
HOSTDEVICE inline bool(isinf)(const complex64& a) {
  return (paddle::platform::isinf)(a);
}

template <>
HOSTDEVICE inline bool(isfinite)(const complex64& a) {
  return (paddle::platform::isfinite)(a);
}

// e^(x + iy) = e^x * (cos y + i sin y).
template <>
HOSTDEVICE inline complex64 exp(const complex64& a) {
  const float scale = ::expf(a.real);
  return complex64(scale * ::cosf(a.imag), scale * ::sinf(a.imag));
}

template <>
HOSTDEVICE inline complex64 log(const complex64& a) {
  return (paddle::platform::log)(a);
}

template <>
HOSTDEVICE inline complex64 tanh(const complex64& a) {
  return (paddle::platform::tanh)(a);
}

template <>
HOSTDEVICE inline complex64 sqrt(const complex64& a) {
  return (paddle::platform::sqrt)(a);
}

// Component-wise rounding of both the real and imaginary parts.
// All three use the single-precision libm functions. floor previously used
// the double ::floor on the imaginary part, which needlessly promoted the
// value to double.
template <>
HOSTDEVICE inline complex64 ceil(const complex64& a) {
  return complex64(::ceilf(a.real), ::ceilf(a.imag));
}

template <>
HOSTDEVICE inline complex64 floor(const complex64& a) {
  return complex64(::floorf(a.real), ::floorf(a.imag));
}

template <>
HOSTDEVICE inline complex64 round(const complex64& a) {
  return complex64(::roundf(a.real), ::roundf(a.imag));
}

// Forwards Eigen's numext pow/abs to the paddle::platform implementations.
template <>
HOSTDEVICE inline complex64 pow(const complex64& a, const complex64& b) {
  return (paddle::platform::pow)(a, b);
}

template <>
HOSTDEVICE inline float abs(const complex64& a) {
  return (paddle::platform::abs)(a);
}

}  // namespace numext
}  // namespace Eigen

// Makes the name MKL_Complex8 refer to paddle::platform::complex64.
// NOTE(review): presumably this lets MKL headers use complex64 as their
// single-precision complex type, which requires this header to be included
// before any MKL header — confirm against the include order at call sites.
#define MKL_Complex8 paddle::platform::complex64