// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "paddle/phi/core/hostdevice.h" #ifdef PADDLE_WITH_CUDA #include #include #endif // PADDLE_WITH_CUDA #ifdef PADDLE_WITH_HIP #include #include // NOLINT #endif #ifndef PADDLE_WITH_HIP #if !defined(_WIN32) #define PADDLE_ALIGN(x) __attribute__((aligned(x))) #else #define PADDLE_ALIGN(x) __declspec(align(x)) #endif #else #define PADDLE_ALIGN(x) #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // todo #define PADDLE_WITH_CUDA_OR_HIP_COMPLEX #endif namespace phi { namespace dtype { template struct PADDLE_ALIGN(sizeof(T) * 2) complex { public: T real; T imag; using value_type = T; complex() = default; complex(const complex& o) = default; complex& operator=(const complex& o) = default; complex(complex&& o) = default; complex& operator=(complex&& o) = default; ~complex() = default; HOSTDEVICE complex(T real, T imag) : real(real), imag(imag) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) template HOSTDEVICE inline explicit complex(const thrust::complex& c) { real = c.real(); imag = c.imag(); } template HOSTDEVICE inline explicit operator thrust::complex() const { return thrust::complex(real, imag); } #ifdef PADDLE_WITH_HIP HOSTDEVICE inline explicit operator hipFloatComplex() const { return make_hipFloatComplex(real, imag); } HOSTDEVICE inline explicit operator hipDoubleComplex() const { return make_hipDoubleComplex(real, imag); } #else HOSTDEVICE inline explicit operator cuFloatComplex() const { return make_cuFloatComplex(real, imag); } HOSTDEVICE inline explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real, imag); } #endif #endif template ::value || std::is_integral::value, int>::type = 0> HOSTDEVICE complex(const T1& val) { real = static_cast(val); imag = static_cast(0.0); } template HOSTDEVICE explicit complex( const typename std::enable_if::value, complex>::type& val) { real = val.real; imag = val.imag; } template HOSTDEVICE explicit complex( const typename std::enable_if::value, complex>::type& val) { real = val.real; imag = val.imag; } template HOSTDEVICE inline explicit operator std::complex() const { return static_cast>(std::complex(real, imag)); } template HOSTDEVICE complex(const std::complex& val) : real(val.real()), imag(val.imag()) {} template ::value || std::is_integral::value, int>::type = 0> HOSTDEVICE inline complex& operator=(const T1& val) { real = static_cast(val); imag = static_cast(0.0); return *this; } HOSTDEVICE inline explicit operator bool() const { return static_cast(this->real) || static_cast(this->imag); } HOSTDEVICE inline explicit operator int8_t() const { return static_cast(this->real); } HOSTDEVICE inline explicit operator uint8_t() const { return static_cast(this->real); } HOSTDEVICE inline explicit operator int16_t() const { return static_cast(this->real); } HOSTDEVICE inline explicit operator uint16_t() const { return static_cast(this->real); } HOSTDEVICE inline explicit operator int32_t() const { return static_cast(this->real); } HOSTDEVICE inline explicit operator uint32_t() const { return static_cast(this->real); } HOSTDEVICE inline explicit operator int64_t() const { return static_cast(this->real); } HOSTDEVICE inline explicit operator uint64_t() const { return static_cast(this->real); } HOSTDEVICE inline explicit operator float() const { return static_cast(this->real); } HOSTDEVICE inline explicit operator double() const { return static_cast(this->real); } }; template HOSTDEVICE inline complex operator+(const complex& a, const complex& b) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::complex(a) + thrust::complex(b)); #else return complex(a.real + b.real, a.imag + b.imag); #endif } template HOSTDEVICE inline complex operator-(const complex& a, const complex& b) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::complex(a) - thrust::complex(b)); #else return complex(a.real - b.real, a.imag - b.imag); #endif } template HOSTDEVICE inline complex operator*(const complex& a, const complex& b) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::complex(a) * thrust::complex(b)); #else return complex(a.real * b.real - a.imag * b.imag, a.imag * b.real + b.imag * a.real); #endif } template HOSTDEVICE inline complex operator/(const complex& a, const complex& b) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::complex(a) / thrust::complex(b)); #else T denominator = b.real * b.real + b.imag * b.imag; return complex((a.real * b.real + a.imag * b.imag) / denominator, (a.imag * b.real - a.real * b.imag) / denominator); #endif } template HOSTDEVICE inline complex operator-(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(-thrust::complex(a.real, a.imag)); #else complex res; res.real = -a.real; res.imag = -a.imag; return res; #endif } template HOSTDEVICE inline complex& operator+=(complex& a, // NOLINT const complex& b) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) a = complex(thrust::complex(a.real, a.imag) += thrust::complex(b.real, b.imag)); return a; #else a.real += b.real; a.imag += b.imag; return a; #endif } template HOSTDEVICE inline complex& operator-=(complex& a, // NOLINT const complex& b) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) a = complex(thrust::complex(a.real, a.imag) -= thrust::complex(b.real, b.imag)); return a; #else a.real -= b.real; a.imag -= b.imag; return a; #endif } template HOSTDEVICE inline complex& operator*=(complex& a, // NOLINT const complex& b) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) a = complex(thrust::complex(a.real, a.imag) *= thrust::complex(b.real, b.imag)); return a; #else a.real = a.real * b.real - a.imag * b.imag; a.imag = a.imag * b.real + b.imag * a.real; return a; #endif } template HOSTDEVICE inline complex& operator/=(complex& a, // NOLINT const complex& b) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) a = complex(thrust::complex(a.real, a.imag) /= thrust::complex(b.real, b.imag)); return a; #else T denominator = b.real * b.real + b.imag * b.imag; a.real = (a.real * b.real + a.imag * b.imag) / denominator; a.imag = (a.imag * b.real - a.real * b.imag) / denominator; return a; #endif } template HOSTDEVICE inline complex raw_uint16_to_complex64(uint16_t a) { complex res; res.real = a; res.imag = 0.0; return res; } template HOSTDEVICE inline bool operator==(const complex& a, const complex& b) { return a.real == b.real && a.imag == b.imag; } template HOSTDEVICE inline bool operator!=(const complex& a, const complex& b) { return a.real != b.real || a.imag != b.imag; } template HOSTDEVICE inline bool operator<(const complex& a, const complex& b) { return a.real < b.real; } template HOSTDEVICE inline bool operator<=(const complex& a, const complex& b) { return a.real <= b.real; } template HOSTDEVICE inline bool operator>(const complex& a, const complex& b) { return a.real > b.real; } template HOSTDEVICE inline bool operator>=(const complex& a, const complex& b) { return a.real >= b.real; } template HOSTDEVICE inline complex(max)(const complex& a, const complex& b) { return (a.real >= b.real) ? a : b; } template HOSTDEVICE inline complex(min)(const complex& a, const complex& b) { return (a.real < b.real) ? a : b; } template HOSTDEVICE inline bool(isnan)(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return ::isnan(a.real) || ::isnan(a.imag); #else return std::isnan(a.real) || std::isnan(a.imag); #endif } template HOSTDEVICE inline bool isinf(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return ::isinf(a.real) || ::isinf(a.imag); #else return std::isinf(a.real) || std::isinf(a.imag); #endif } template HOSTDEVICE inline bool isfinite(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return ::isfinite(a.real) || ::isfinite(a.imag); #else return std::isfinite(a.real) || std::isfinite(a.imag); #endif } template HOSTDEVICE inline T abs(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return thrust::abs(thrust::complex(a)); #else return std::abs(std::complex(a)); #endif } template HOSTDEVICE inline T arg(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return thrust::arg(thrust::complex(a)); #else return std::arg(std::complex(a)); #endif } template HOSTDEVICE inline complex pow(const complex& a, const complex& b) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::pow(thrust::complex(a), thrust::complex(b))); #else return complex(std::pow(std::complex(a), std::complex(b))); #endif } template HOSTDEVICE inline complex sqrt(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::sqrt(thrust::complex(a))); #else return complex(std::sqrt(std::complex(a))); #endif } template HOSTDEVICE inline complex sin(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::sin(thrust::complex(a))); #else return complex(std::sin(std::complex(a))); #endif } template HOSTDEVICE inline complex cos(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::cos(thrust::complex(a))); #else return complex(std::cos(std::complex(a))); #endif } template HOSTDEVICE inline complex tan(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::tan(thrust::complex(a))); #else return complex(std::tan(std::complex(a))); #endif } template HOSTDEVICE inline complex tanh(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::tanh(thrust::complex(a))); #else return complex(std::tanh(std::complex(a))); #endif } template HOSTDEVICE inline complex conj(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::conj(thrust::complex(a))); #else return complex(std::conj(std::complex(a))); #endif } template HOSTDEVICE inline complex exp(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::exp(thrust::complex(a))); #else return complex(std::exp(std::complex(a))); #endif } template HOSTDEVICE inline complex log(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ (defined(__CUDA_ARCH__) || defined(__HIPCC__)) return complex(thrust::log(thrust::complex(a))); #else return complex(std::log(std::complex(a))); #endif } template inline std::ostream& operator<<(std::ostream& os, const complex& a) { os << "real:" << a.real << " imag:" << a.imag; return os; } } // namespace dtype } // namespace phi namespace std { template struct is_pod> { static const bool value = true; }; template struct is_floating_point> : std::integral_constant {}; template struct is_signed> { static const bool value = false; }; template struct is_unsigned> { static const bool value = false; }; template inline bool isnan(const phi::dtype::complex& a) { return phi::dtype::isnan(a); } template inline bool isinf(const phi::dtype::complex& a) { return phi::dtype::isinf(a); } template struct numeric_limits> { static const bool is_specialized = false; static const bool is_signed = false; static const bool is_integer = false; static const bool is_exact = false; static const bool has_infinity = false; static const bool has_quiet_NaN = false; static const bool has_signaling_NaN = false; static const float_denorm_style has_denorm = denorm_absent; static const bool has_denorm_loss = false; static const std::float_round_style round_style = std::round_toward_zero; static const bool is_iec559 = false; static const bool is_bounded = false; static const bool is_modulo = false; static const int digits = 0; static const int digits10 = 0; static const int max_digits10 = 0; static const int radix = 0; static const int min_exponent = 0; static const int min_exponent10 = 0; static const int max_exponent = 0; static const int max_exponent10 = 0; static const bool traps = false; static const bool tinyness_before = false; static phi::dtype::complex(min)() { return phi::dtype::complex(0.0, 0.0); } static phi::dtype::complex lowest() { return phi::dtype::complex(0.0, 0.0); } static phi::dtype::complex(max)() { return phi::dtype::complex(0.0, 0.0); } static phi::dtype::complex epsilon() { return phi::dtype::complex(0.0, 0.0); } static phi::dtype::complex round_error() { return phi::dtype::complex(0.0, 0.0); } static phi::dtype::complex infinity() { return phi::dtype::complex(0.0, 0.0); } static phi::dtype::complex quiet_NaN() { return phi::dtype::complex(0.0, 0.0); } static phi::dtype::complex signaling_NaN() { return phi::dtype::complex(0.0, 0.0); } static phi::dtype::complex denorm_min() { return phi::dtype::complex(0.0, 0.0); } }; } // namespace std