// complex.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <stdint.h>

#include <cmath>
#include <complex>
#include <cstring>
#include <iostream>
#include <limits>
#include <type_traits>

#include "paddle/phi/core/hostdevice.h"

#ifdef PADDLE_WITH_CUDA
#include <cuComplex.h>
#include <thrust/complex.h>
#endif  // PADDLE_WITH_CUDA

#ifdef PADDLE_WITH_HIP
#include <hip/hip_complex.h>
#include <thrust/complex.h>  // NOLINT
#endif

#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) __declspec(align(x))
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
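// TODO: flag set whenever CUDA or HIP is enabled; the complex helpers below
// combine it with __CUDA_ARCH__ / __HIPCC__ to select thrust::complex paths.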
#define PADDLE_WITH_CUDA_OR_HIP_COMPLEX
#endif

namespace phi {
namespace dtype {

// A complex number stored as two consecutive `T` members (`real`, `imag`) and
// aligned to `2 * sizeof(T)` bytes via PADDLE_ALIGN.
//
// Provides conversions to and from std::complex, and, in CUDA/HIP builds, to
// and from thrust::complex and the cuComplex / hip_complex types. Conversions
// to built-in scalar types keep only the real part; the exception is bool,
// which is true when either part is nonzero.
template <typename T>
struct PADDLE_ALIGN(sizeof(T) * 2) complex {
 public:
  T real;
  T imag;

  using value_type = T;

  // All special members are defaulted (trivial when T is trivial). The
  // default constructor leaves `real` and `imag` uninitialized.
  complex() = default;
  complex(const complex<T>& o) = default;
  complex& operator=(const complex<T>& o) = default;
  complex(complex<T>&& o) = default;
  complex& operator=(complex<T>&& o) = default;
  ~complex() = default;

  // Constructs `real + imag * i`.
  HOSTDEVICE complex(T real, T imag) : real(real), imag(imag) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

  // Explicit conversion from thrust::complex<T1>; each component is converted
  // to T by assignment.
  template <typename T1>
  HOSTDEVICE inline explicit complex(const thrust::complex<T1>& c) {
    real = c.real();
    imag = c.imag();
  }

  // Explicit conversion to thrust::complex<T1>.
  template <typename T1>
  HOSTDEVICE inline explicit operator thrust::complex<T1>() const {
    return thrust::complex<T1>(real, imag);
  }

#ifdef PADDLE_WITH_HIP
  // Explicit conversions to HIP's complex vector types.
  HOSTDEVICE inline explicit operator hipFloatComplex() const {
    return make_hipFloatComplex(real, imag);
  }

  HOSTDEVICE inline explicit operator hipDoubleComplex() const {
    return make_hipDoubleComplex(real, imag);
  }
#else
  // Explicit conversions to CUDA's cuComplex types.
  HOSTDEVICE inline explicit operator cuFloatComplex() const {
    return make_cuFloatComplex(real, imag);
  }

  HOSTDEVICE inline explicit operator cuDoubleComplex() const {
    return make_cuDoubleComplex(real, imag);
  }
#endif
#endif

  // Implicit conversion from any integral or floating-point scalar: the value
  // becomes the real part and the imaginary part is zero.
  template <typename T1,
            typename std::enable_if<std::is_floating_point<T1>::value ||
                                        std::is_integral<T1>::value,
                                    int>::type = 0>
  HOSTDEVICE complex(const T1& val) {
    real = static_cast<T>(val);
    imag = static_cast<T>(0.0);
  }

  // Explicit complex<double> -> complex<float> conversion. T1 only appears in
  // a non-deduced context, so it is always the default (T); enable_if removes
  // this overload unless T is float.
  template <typename T1 = T>
  HOSTDEVICE explicit complex(
      const typename std::enable_if<std::is_same<T1, float>::value,
                                    complex<double>>::type& val) {
    real = val.real;
    imag = val.imag;
  }

  // Explicit complex<float> -> complex<double> conversion; same non-deduced
  // T1 trick as above, enabled only when T is double.
  template <typename T1 = T>
  HOSTDEVICE explicit complex(
      const typename std::enable_if<std::is_same<T1, double>::value,
                                    complex<float>>::type& val) {
    real = val.real;
    imag = val.imag;
  }

  // Explicit conversion to std::complex<T1>, going through std::complex<T>.
  template <typename T1>
  HOSTDEVICE inline explicit operator std::complex<T1>() const {
    return static_cast<std::complex<T1>>(std::complex<T>(real, imag));
  }

  // Implicit conversion from std::complex<T1>.
  template <typename T1>
  HOSTDEVICE complex(const std::complex<T1>& val)
      : real(val.real()), imag(val.imag()) {}

  // Assigns an integral or floating-point scalar: real part = val, imaginary
  // part = 0.
  template <typename T1,
            typename std::enable_if<std::is_floating_point<T1>::value ||
                                        std::is_integral<T1>::value,
                                    int>::type = 0>
  HOSTDEVICE inline complex& operator=(const T1& val) {
    real = static_cast<T>(val);
    imag = static_cast<T>(0.0);
    return *this;
  }

  // True when either component converts to true (is nonzero).
  HOSTDEVICE inline explicit operator bool() const {
    return static_cast<bool>(this->real) || static_cast<bool>(this->imag);
  }

  // Explicit conversions to scalar types below discard the imaginary part and
  // convert only `real`.
  HOSTDEVICE inline explicit operator int8_t() const {
    return static_cast<int8_t>(this->real);
  }

  HOSTDEVICE inline explicit operator uint8_t() const {
    return static_cast<uint8_t>(this->real);
  }

  HOSTDEVICE inline explicit operator int16_t() const {
    return static_cast<int16_t>(this->real);
  }

  HOSTDEVICE inline explicit operator uint16_t() const {
    return static_cast<uint16_t>(this->real);
  }

  HOSTDEVICE inline explicit operator int32_t() const {
    return static_cast<int32_t>(this->real);
  }

  HOSTDEVICE inline explicit operator uint32_t() const {
    return static_cast<uint32_t>(this->real);
  }

  HOSTDEVICE inline explicit operator int64_t() const {
    return static_cast<int64_t>(this->real);
  }

  HOSTDEVICE inline explicit operator uint64_t() const {
    return static_cast<uint64_t>(this->real);
  }

  HOSTDEVICE inline explicit operator float() const {
    return static_cast<float>(this->real);
  }

  HOSTDEVICE inline explicit operator double() const {
    return static_cast<double>(this->real);
  }
};

// Complex addition: adds real and imaginary parts independently.
template <typename T>
HOSTDEVICE inline complex<T> operator+(const complex<T>& a,
                                       const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  // nvcc device pass or hipcc: delegate to thrust::complex arithmetic.
  const auto sum = thrust::complex<T>(a) + thrust::complex<T>(b);
  return complex<T>(sum);
#else
  const T sum_re = a.real + b.real;
  const T sum_im = a.imag + b.imag;
  return complex<T>(sum_re, sum_im);
#endif
}

// Complex subtraction (a - b): subtracts real and imaginary parts
// independently.
template <typename T>
HOSTDEVICE inline complex<T> operator-(const complex<T>& a,
                                       const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  // nvcc device pass or hipcc: delegate to thrust::complex arithmetic.
  const auto diff = thrust::complex<T>(a) - thrust::complex<T>(b);
  return complex<T>(diff);
#else
  const T diff_re = a.real - b.real;
  const T diff_im = a.imag - b.imag;
  return complex<T>(diff_re, diff_im);
#endif
}

// Complex multiplication:
//   (ar + ai*i) * (br + bi*i) = (ar*br - ai*bi) + (ai*br + bi*ar)*i
template <typename T>
HOSTDEVICE inline complex<T> operator*(const complex<T>& a,
                                       const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  // nvcc device pass or hipcc: delegate to thrust::complex arithmetic.
  const auto product = thrust::complex<T>(a) * thrust::complex<T>(b);
  return complex<T>(product);
#else
  const T prod_re = a.real * b.real - a.imag * b.imag;
  const T prod_im = a.imag * b.real + b.imag * a.real;
  return complex<T>(prod_re, prod_im);
#endif
}

// Complex division a / b. The host path uses the textbook formula
// a * conj(b) / |b|^2, with no rescaling to guard against overflow or
// underflow of |b|^2.
template <typename T>
HOSTDEVICE inline complex<T> operator/(const complex<T>& a,
                                       const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  // nvcc device pass or hipcc: delegate to thrust::complex arithmetic.
  const auto quotient = thrust::complex<T>(a) / thrust::complex<T>(b);
  return complex<T>(quotient);
#else
  const T norm_sq = b.real * b.real + b.imag * b.imag;
  const T quot_re = (a.real * b.real + a.imag * b.imag) / norm_sq;
  const T quot_im = (a.imag * b.real - a.real * b.imag) / norm_sq;
  return complex<T>(quot_re, quot_im);
#endif
}

// Unary negation: negates both components.
template <typename T>
HOSTDEVICE inline complex<T> operator-(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  const thrust::complex<T> z(a.real, a.imag);
  return complex<T>(-z);
#else
  return complex<T>(-a.real, -a.imag);
#endif
}

// In-place complex addition: a = a + b. Returns a reference to `a`.
template <typename T>
HOSTDEVICE inline complex<T>& operator+=(complex<T>& a,  // NOLINT
                                         const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  thrust::complex<T> acc(a.real, a.imag);
  acc += thrust::complex<T>(b.real, b.imag);
  a = complex<T>(acc);
#else
  // Each component only depends on itself, so aliasing of a and b is safe.
  a = complex<T>(a.real + b.real, a.imag + b.imag);
#endif
  return a;
}

// In-place complex subtraction: a = a - b. Returns a reference to `a`.
template <typename T>
HOSTDEVICE inline complex<T>& operator-=(complex<T>& a,  // NOLINT
                                         const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  thrust::complex<T> acc(a.real, a.imag);
  acc -= thrust::complex<T>(b.real, b.imag);
  a = complex<T>(acc);
#else
  // Each component only depends on itself, so aliasing of a and b is safe.
  a = complex<T>(a.real - b.real, a.imag - b.imag);
#endif
  return a;
}

// In-place complex multiplication: a = a * b. Returns a reference to `a`.
template <typename T>
HOSTDEVICE inline complex<T>& operator*=(complex<T>& a,  // NOLINT
                                         const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  a = complex<T>(thrust::complex<T>(a.real, a.imag) *=
                 thrust::complex<T>(b.real, b.imag));
  return a;
#else
  // Snapshot every input component before writing to `a`. The imaginary part
  // needs the ORIGINAL a.real, and `b` may alias `a` (e.g. `x *= x`), in
  // which case writing a.real would also change b.real.
  const T a_re = a.real;
  const T a_im = a.imag;
  const T b_re = b.real;
  const T b_im = b.imag;
  a.real = a_re * b_re - a_im * b_im;
  a.imag = a_im * b_re + b_im * a_re;
  return a;
#endif
}

// In-place complex division: a = a / b. Returns a reference to `a`.
// The host path uses the textbook formula a * conj(b) / |b|^2, with no
// rescaling to guard against overflow or underflow of |b|^2.
template <typename T>
HOSTDEVICE inline complex<T>& operator/=(complex<T>& a,  // NOLINT
                                         const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  a = complex<T>(thrust::complex<T>(a.real, a.imag) /=
                 thrust::complex<T>(b.real, b.imag));
  return a;
#else
  // Snapshot every input component before writing to `a`. The imaginary part
  // needs the ORIGINAL a.real, and `b` may alias `a` (e.g. `x /= x`).
  const T a_re = a.real;
  const T a_im = a.imag;
  const T b_re = b.real;
  const T b_im = b.imag;
  const T denominator = b_re * b_re + b_im * b_im;
  a.real = (a_re * b_re + a_im * b_im) / denominator;
  a.imag = (a_im * b_re - a_re * b_im) / denominator;
  return a;
#endif
}

// Returns a complex<T> whose real part is the integer value `a` converted to
// T and whose imaginary part is zero. Despite the name, the value is
// converted (no bit reinterpretation) and the result type is complex<T> for
// whatever T the caller selects.
template <typename T>
HOSTDEVICE inline complex<T> raw_uint16_to_complex64(uint16_t a) {
  return complex<T>(static_cast<T>(a), static_cast<T>(0.0));
}

// Component-wise equality; a NaN in either part makes the values unequal.
template <typename T>
HOSTDEVICE inline bool operator==(const complex<T>& a, const complex<T>& b) {
  if (!(a.real == b.real)) {
    return false;
  }
  return a.imag == b.imag;
}

// True when either component differs (including when either part is NaN).
template <typename T>
HOSTDEVICE inline bool operator!=(const complex<T>& a, const complex<T>& b) {
  if (a.real != b.real) {
    return true;
  }
  return a.imag != b.imag;
}

// Ordering compares ONLY the real parts; imaginary parts are ignored.
template <typename T>
HOSTDEVICE inline bool operator<(const complex<T>& a, const complex<T>& b) {
  return b.real > a.real;
}

// Ordering compares ONLY the real parts; imaginary parts are ignored.
template <typename T>
HOSTDEVICE inline bool operator<=(const complex<T>& a, const complex<T>& b) {
  return b.real >= a.real;
}

// Ordering compares ONLY the real parts; imaginary parts are ignored.
template <typename T>
HOSTDEVICE inline bool operator>(const complex<T>& a, const complex<T>& b) {
  return b.real < a.real;
}

// Ordering compares ONLY the real parts; imaginary parts are ignored.
template <typename T>
HOSTDEVICE inline bool operator>=(const complex<T>& a, const complex<T>& b) {
  return b.real <= a.real;
}

// Returns the argument with the larger real part. Ties return `a`; if the
// comparison is false (e.g. a NaN real part) `b` is returned. The name is
// parenthesized so a function-like `max` macro cannot expand here.
template <typename T>
HOSTDEVICE inline complex<T>(max)(const complex<T>& a, const complex<T>& b) {
  if (a.real >= b.real) {
    return a;
  }
  return b;
}

// Returns the argument with the smaller real part. Ties return `b`; if the
// comparison is false (e.g. a NaN real part) `b` is returned. The name is
// parenthesized so a function-like `min` macro cannot expand here.
template <typename T>
HOSTDEVICE inline complex<T>(min)(const complex<T>& a, const complex<T>& b) {
  if (a.real < b.real) {
    return a;
  }
  return b;
}

// True when either component is NaN. The name is parenthesized so a
// function-like `isnan` macro cannot expand at the declaration.
template <typename T>
HOSTDEVICE inline bool(isnan)(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  if (::isnan(a.real)) {
    return true;
  }
  return ::isnan(a.imag);
#else
  if (std::isnan(a.real)) {
    return true;
  }
  return std::isnan(a.imag);
#endif
}

// True when either component is infinite.
template <typename T>
HOSTDEVICE inline bool isinf(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  if (::isinf(a.real)) {
    return true;
  }
  return ::isinf(a.imag);
#else
  if (std::isinf(a.real)) {
    return true;
  }
  return std::isinf(a.imag);
#endif
}

// True only when BOTH components are finite (neither infinite nor NaN).
// Previously this used `||`, which reported values such as (inf, 1) or
// (NaN, 0) as finite.
template <typename T>
HOSTDEVICE inline bool isfinite(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  return ::isfinite(a.real) && ::isfinite(a.imag);
#else
  return std::isfinite(a.real) && std::isfinite(a.imag);
#endif
}

// Magnitude |a|, computed by thrust::abs (CUDA device pass / hipcc) or
// std::abs on the host.
template <typename T>
HOSTDEVICE inline T abs(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  const auto z = thrust::complex<T>(a);
  return thrust::abs(z);
#else
  const auto z = std::complex<T>(a);
  return std::abs(z);
#endif
}

// Phase angle of `a`, computed by thrust::arg (CUDA device pass / hipcc) or
// std::arg on the host.
template <typename T>
HOSTDEVICE inline T arg(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  const auto z = thrust::complex<T>(a);
  return thrust::arg(z);
#else
  const auto z = std::complex<T>(a);
  return std::arg(z);
#endif
}

// Complex power a^b, delegating to thrust::pow (CUDA device pass / hipcc) or
// std::pow on the host.
template <typename T>
HOSTDEVICE inline complex<T> pow(const complex<T>& a, const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  const auto base = thrust::complex<T>(a);
  const auto exponent = thrust::complex<T>(b);
  return complex<T>(thrust::pow(base, exponent));
#else
  const auto base = std::complex<T>(a);
  const auto exponent = std::complex<T>(b);
  return complex<T>(std::pow(base, exponent));
#endif
}

// Complex square root, delegating to thrust::sqrt (CUDA device pass / hipcc)
// or std::sqrt on the host.
template <typename T>
HOSTDEVICE inline complex<T> sqrt(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  const auto z = thrust::complex<T>(a);
  return complex<T>(thrust::sqrt(z));
#else
  const auto z = std::complex<T>(a);
  return complex<T>(std::sqrt(z));
#endif
}

// Complex hyperbolic tangent, delegating to thrust::tanh (CUDA device pass /
// hipcc) or std::tanh on the host.
template <typename T>
HOSTDEVICE inline complex<T> tanh(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  const auto z = thrust::complex<T>(a);
  return complex<T>(thrust::tanh(z));
#else
  const auto z = std::complex<T>(a);
  return complex<T>(std::tanh(z));
#endif
}

// Complex natural logarithm, delegating to thrust::log (CUDA device pass /
// hipcc) or std::log on the host.
template <typename T>
HOSTDEVICE inline complex<T> log(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  const auto z = thrust::complex<T>(a);
  return complex<T>(thrust::log(z));
#else
  const auto z = std::complex<T>(a);
  return complex<T>(std::log(z));
#endif
}

// Writes `a` as "real:<re> imag:<im>" and returns the stream.
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const complex<T>& a) {
  return os << "real:" << a.real << " imag:" << a.imag;
}
}  // namespace dtype
}  // namespace phi
namespace std {

template <typename T>
456
struct is_pod<phi::dtype::complex<T>> {
457 458 459 460
  static const bool value = true;
};

template <typename T>
461
struct is_floating_point<phi::dtype::complex<T>>
462 463 464
    : std::integral_constant<bool, false> {};

template <typename T>
465
struct is_signed<phi::dtype::complex<T>> {
466 467 468 469
  static const bool value = false;
};

template <typename T>
470
struct is_unsigned<phi::dtype::complex<T>> {
471 472 473 474
  static const bool value = false;
};

template <typename T>
475 476
inline bool isnan(const phi::dtype::complex<T>& a) {
  return phi::dtype::isnan(a);
477 478 479
}

template <typename T>
480 481
inline bool isinf(const phi::dtype::complex<T>& a) {
  return phi::dtype::isinf(a);
482 483 484
}

template <typename T>
485
struct numeric_limits<phi::dtype::complex<T>> {
486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509
  static const bool is_specialized = false;
  static const bool is_signed = false;
  static const bool is_integer = false;
  static const bool is_exact = false;
  static const bool has_infinity = false;
  static const bool has_quiet_NaN = false;
  static const bool has_signaling_NaN = false;
  static const float_denorm_style has_denorm = denorm_absent;
  static const bool has_denorm_loss = false;
  static const std::float_round_style round_style = std::round_toward_zero;
  static const bool is_iec559 = false;
  static const bool is_bounded = false;
  static const bool is_modulo = false;
  static const int digits = 0;
  static const int digits10 = 0;
  static const int max_digits10 = 0;
  static const int radix = 0;
  static const int min_exponent = 0;
  static const int min_exponent10 = 0;
  static const int max_exponent = 0;
  static const int max_exponent10 = 0;
  static const bool traps = false;
  static const bool tinyness_before = false;

510 511
  static phi::dtype::complex<T>(min)() {
    return phi::dtype::complex<T>(0.0, 0.0);
512
  }
513 514
  static phi::dtype::complex<T> lowest() {
    return phi::dtype::complex<T>(0.0, 0.0);
515
  }
516 517
  static phi::dtype::complex<T>(max)() {
    return phi::dtype::complex<T>(0.0, 0.0);
518
  }
519 520
  static phi::dtype::complex<T> epsilon() {
    return phi::dtype::complex<T>(0.0, 0.0);
521
  }
522 523
  static phi::dtype::complex<T> round_error() {
    return phi::dtype::complex<T>(0.0, 0.0);
524
  }
525 526
  static phi::dtype::complex<T> infinity() {
    return phi::dtype::complex<T>(0.0, 0.0);
527
  }
528 529
  static phi::dtype::complex<T> quiet_NaN() {
    return phi::dtype::complex<T>(0.0, 0.0);
530
  }
531 532
  static phi::dtype::complex<T> signaling_NaN() {
    return phi::dtype::complex<T>(0.0, 0.0);
533
  }
534 535
  static phi::dtype::complex<T> denorm_min() {
    return phi::dtype::complex<T>(0.0, 0.0);
536 537 538 539
  }
};

}  // namespace std