From 738bf20e4dec299cde18d67ed36ea6ddc88a7ba6 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 20 May 2021 14:16:05 +0800 Subject: [PATCH] Add complex template type (#32857) * add complex template file * add numtraits for complex template * add complex template type register * modify specify template of complex * modify specify template of complex * modify specify template of complex * modify specify template of complex * make TensorCheckerVisitor support complex type * fix operator= error * add complex template * add complex template type * add complex template type to pyarray transform * add complex template type to pyarray transform * remove complex type for dlpack register * set dlpack supprot complex type * set dlpack supprot complex type * set dlpack supprot complex type * remove explict for complex constructor * add complex unit test file --- paddle/fluid/framework/data_type.h | 11 + .../framework/details/nan_inf_utils_detail.cc | 64 ++- .../framework/details/nan_inf_utils_detail.cu | 6 +- .../framework/details/nan_inf_utils_detail.h | 8 +- paddle/fluid/framework/dlpack_tensor.cc | 16 +- paddle/fluid/framework/dlpack_tensor_test.cc | 7 + .../fluid/operators/math/concat_and_split.h | 26 +- paddle/fluid/operators/math/math_function.cc | 14 + paddle/fluid/operators/math/math_function.cu | 10 + paddle/fluid/platform/CMakeLists.txt | 2 + paddle/fluid/platform/complex.h | 537 ++++++++++++++++++ paddle/fluid/platform/complex_test.cc | 324 +++++++++++ paddle/fluid/platform/complex_test.cu | 361 ++++++++++++ paddle/fluid/platform/eigen_ext.h | 179 ++++++ paddle/fluid/pybind/tensor_py.h | 51 ++ tools/parallel_UT_rule.py | 1 + 16 files changed, 1598 insertions(+), 19 deletions(-) create mode 100644 paddle/fluid/platform/complex.h create mode 100644 paddle/fluid/platform/complex_test.cc create mode 100644 paddle/fluid/platform/complex_test.cu diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index c8f73a5469..648a32420a 100644 
--- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/platform/bfloat16.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex128.h" #include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/eigen_ext.h" @@ -30,6 +31,8 @@ struct bfloat16; struct complex128; struct complex64; struct float16; +template +struct complex; } // namespace platform } // namespace paddle @@ -61,6 +64,10 @@ struct DataTypeTrait { _ForEachDataTypeHelper_(callback, uint8_t, UINT8); \ _ForEachDataTypeHelper_(callback, int16_t, INT16); \ _ForEachDataTypeHelper_(callback, int8_t, INT8); \ + _ForEachDataTypeHelper_(callback, ::paddle::platform::complex, \ + COMPLEX64); \ + _ForEachDataTypeHelper_(callback, ::paddle::platform::complex, \ + COMPLEX128); \ _ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \ _ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128); @@ -69,6 +76,10 @@ struct DataTypeTrait { _ForEachDataTypeHelper_(callback, double, FP64); \ _ForEachDataTypeHelper_(callback, int, INT32); \ _ForEachDataTypeHelper_(callback, int64_t, INT64); \ + _ForEachDataTypeHelper_(callback, ::paddle::platform::complex, \ + COMPLEX64); \ + _ForEachDataTypeHelper_(callback, ::paddle::platform::complex, \ + COMPLEX128); \ _ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \ _ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128); diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc index 0fdb97db20..829772448e 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cc +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc @@ -163,6 +163,11 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num, omp_in) 
#pragma omp declare reduction(+ : paddle::platform::complex128 : omp_out += \ omp_in) +#pragma omp declare reduction(+ : paddle::platform::complex < \ + float > : omp_out += omp_in) +#pragma omp declare reduction(+ : paddle::platform::complex < \ + double > : omp_out += omp_in) + #endif template @@ -268,12 +273,69 @@ void CheckNanInf( op_type)); } } + +template <> +void CheckNanInf>( + const paddle::platform::complex* value, const size_t numel, + int print_num, const std::string& op_type, const std::string& var_name) { + float real_sum = 0.0f; +#pragma omp parallel for reduction(+ : real_sum) + for (size_t i = 0; i < numel; ++i) { + real_sum += (value[i].real - value[i].real); + } + + float imag_sum = 0.0f; +#pragma omp parallel for reduction(+ : imag_sum) + for (size_t i = 0; i < numel; ++i) { + imag_sum += (value[i].imag - value[i].imag); + } + + if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) || + std::isinf(imag_sum)) { + // hot fix for compile failed in gcc4.8 + // here also need print detail info of nan or inf later + PADDLE_THROW(platform::errors::PreconditionNotMet( + "There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name, + op_type)); + } +} + +template <> + void CheckNanInf>> + (const paddle::platform::complex* value, const size_t numel, + int print_num, const std::string& op_type, const std::string& var_name) { + double real_sum = 0.0; +#pragma omp parallel for reduction(+ : real_sum) + for (size_t i = 0; i < numel; ++i) { + real_sum += (value[i].real - value[i].real); + } + + double imag_sum = 0.0; +#pragma omp parallel for reduction(+ : imag_sum) + for (size_t i = 0; i < numel; ++i) { + imag_sum += (value[i].imag - value[i].imag); + } + + if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) || + std::isinf(imag_sum)) { + // hot fix for compile failed in gcc4.8 + // here also need print detail info of nan or inf later + PADDLE_THROW(platform::errors::PreconditionNotMet( + "There are `nan` or 
`inf` in tensor (%s) of operator (%s).", var_name, + op_type)); + } +} + #endif template <> template void TensorCheckerVisitor::apply( - typename std::enable_if::value>::type*) const { + typename std::enable_if< + std::is_floating_point::value || + std::is_same>::value || + std::is_same>::value>::type*) + const { // use env strategy control in future, -1=print_all. int print_num = 3; CheckNanInf(tensor_.data(), tensor_.numel(), print_num, op_type_, diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cu b/paddle/fluid/framework/details/nan_inf_utils_detail.cu index 96d1a9fb94..a9ea336e42 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cu +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cu @@ -123,7 +123,11 @@ __global__ void CheckNanInfKernel(const T* value, const size_t numel, template <> template void TensorCheckerVisitor::apply( - typename std::enable_if::value>::type*) const { + typename std::enable_if< + std::is_floating_point::value || + std::is_same>::value || + std::is_same>::value>::type*) + const { int print_num = 3; auto* dev_ctx = reinterpret_cast( diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.h b/paddle/fluid/framework/details/nan_inf_utils_detail.h index b4459e5a7c..10b7ab0bc9 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.h +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.h @@ -46,8 +46,12 @@ struct TensorCheckerVisitor { } template - void apply(typename std::enable_if::value>::type* = - 0) const; + void apply( + typename std::enable_if< + std::is_floating_point::value || + std::is_same>::value || + std::is_same>::value>::type* = + 0) const; std::string op_type_; std::string var_name_; diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index 3833b027d2..54d8fc92b2 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -28,9 +28,19 @@ namespace internal { template static 
::DLDataType GetDLDataTypeCode() { ::DLDataType dtype; - if (std::is_same::value || - std::is_same::value || - std::is_floating_point::value) { + if (std::is_same>::value || + std::is_same>::value || + std::is_same::value || + std::is_same::value) { + // The current dlpack library version is v0.2, and does not define + // kDLComplex value. But kDLComplex is defined by 5U in v0.4, so we set + // dtype.code to 5U directly here. After the dlpack library version being + // upgraded to v0.4, it should be written as follow. + // dtype.code = kDLComplex; + dtype.code = 5U; + } else if (std::is_same::value || + std::is_same::value || + std::is_floating_point::value) { dtype.code = kDLFloat; } else if (std::is_unsigned::value) { dtype.code = kDLUInt; diff --git a/paddle/fluid/framework/dlpack_tensor_test.cc b/paddle/fluid/framework/dlpack_tensor_test.cc index d03437034d..1a79ada0be 100644 --- a/paddle/fluid/framework/dlpack_tensor_test.cc +++ b/paddle/fluid/framework/dlpack_tensor_test.cc @@ -28,6 +28,13 @@ namespace framework { namespace { // NOLINT template constexpr uint8_t GetDLDataTypeCode() { + if (std::is_same>::value || + std::is_same>::value || + std::is_same::value || + std::is_same::value) { + return static_cast(5); + } + return std::is_same::value || std::is_floating_point::value ? 
static_cast(kDLFloat) diff --git a/paddle/fluid/operators/math/concat_and_split.h b/paddle/fluid/operators/math/concat_and_split.h index d6ad3aec22..a79a9da0b3 100644 --- a/paddle/fluid/operators/math/concat_and_split.h +++ b/paddle/fluid/operators/math/concat_and_split.h @@ -65,16 +65,18 @@ class SplitFunctor { } // namespace operators } // namespace paddle -#define FOR_ALL_TYPES(macro) \ - macro(int); \ - macro(float); \ - macro(double); \ - macro(bool); \ - macro(int64_t); \ - macro(int16_t); \ - macro(uint8_t); \ - macro(int8_t); \ - macro(::paddle::platform::float16); \ - macro(::paddle::platform::bfloat16); \ - macro(::paddle::platform::complex64); \ +#define FOR_ALL_TYPES(macro) \ + macro(int); \ + macro(float); \ + macro(double); \ + macro(bool); \ + macro(int64_t); \ + macro(int16_t); \ + macro(uint8_t); \ + macro(int8_t); \ + macro(::paddle::platform::float16); \ + macro(::paddle::platform::bfloat16); \ + macro(::paddle::platform::complex); \ + macro(::paddle::platform::complex); \ + macro(::paddle::platform::complex64); \ macro(::paddle::platform::complex128) diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 56217b4dc7..d01a39ecb7 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -47,6 +47,10 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant; template struct SetConstant; +template struct SetConstant>; +template struct SetConstant>; #ifdef PADDLE_WITH_XPU template struct SetConstant; @@ -59,6 +63,10 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant; template struct SetConstant; +template struct SetConstant>; +template struct SetConstant>; #endif #define DEFINE_CPU_TRANS(RANK) \ @@ -74,6 +82,10 @@ template struct SetConstant; template struct Transpose; \ template struct Transpose; \ template struct Transpose; \ + template struct Transpose, RANK>; \ + template 
struct Transpose, RANK>; \ template struct Transpose; \ template struct Transpose); +DEFINE_CPU_TRANS_NORMAL(platform::complex); struct TensorSetConstantCPU { TensorSetConstantCPU(framework::Tensor* tensor, float value) diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index f94c1bf696..c5c78c87f7 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -43,6 +43,10 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant; template struct SetConstant; +template struct SetConstant>; +template struct SetConstant>; #define DEFINE_GPU_TRANS(RANK) \ template struct Transpose; \ @@ -52,6 +56,10 @@ template struct SetConstant; template struct Transpose; \ template struct Transpose; \ template struct Transpose; \ + template struct Transpose, RANK>; \ + template struct Transpose, RANK>; \ template struct Transpose; \ template struct Transpose; @@ -145,6 +153,8 @@ DEFINE_GPU_TRANS_NORMAL(uint8_t); DEFINE_GPU_TRANS_NORMAL(int8_t); DEFINE_GPU_TRANS_NORMAL(complex64); DEFINE_GPU_TRANS_NORMAL(complex128); +DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex); +DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex); struct TensorSetConstantGPU { TensorSetConstantGPU(const platform::DeviceContext& context, diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 0827d6a5ae..12a54fd7e8 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -187,10 +187,12 @@ endif() cc_test(profiler_test SRCS profiler_test.cc DEPS profiler) cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor) cc_test(bfloat16_test SRCS bfloat16_test.cc DEPS lod_tensor) +cc_test(complex_test SRCS complex_test.cc DEPS lod_tensor) IF(WITH_GPU) nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor) nv_test(bfloat16_gpu_test SRCS bfloat16_test.cu DEPS lod_tensor) + nv_test(complex_gpu_test SRCS 
complex_test.cu DEPS lod_tensor) nv_test(test_limit_gpu_memory SRCS test_limit_gpu_memory.cu DEPS gpu_info flags) nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info) ENDIF() diff --git a/paddle/fluid/platform/complex.h b/paddle/fluid/platform/complex.h new file mode 100644 index 0000000000..2c1b42ea48 --- /dev/null +++ b/paddle/fluid/platform/complex.h @@ -0,0 +1,537 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include +#include +#include +#include +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif // PADDLE_WITH_CUDA + +#ifdef PADDLE_WITH_HIP +#include +#include // NOLINT +#endif + +#if !defined(_WIN32) +#define PADDLE_ALIGN(x) __attribute__((aligned(x))) +#else +#define PADDLE_ALIGN(x) __declspec(align(x)) +#endif + +#if (defined(__CUDACC__) || defined(__HIPCC__)) +#define HOSTDEVICE __host__ __device__ +#define DEVICE __device__ +#define HOST __host__ +#else +#define HOSTDEVICE +#define DEVICE +#define HOST +#endif + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +// todo +#define PADDLE_WITH_CUDA_OR_HIP_COMPLEX +#endif + +namespace paddle { +namespace platform { + +template +struct PADDLE_ALIGN(sizeof(T) * 2) complex { + public: + T real; + T imag; + + complex() = default; + complex(const complex& o) = default; + complex& operator=(const complex& o) = default; + complex(complex&& o) = default; + complex& operator=(complex&& o) = default; + ~complex() = default; + + HOSTDEVICE complex(T real, T imag) : real(real), imag(imag) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + + template + HOSTDEVICE inline explicit complex(const thrust::complex& c) { + real = c.real(); + imag = c.imag(); + } + + template + HOSTDEVICE inline explicit operator thrust::complex() const { + return thrust::complex(real, imag); + } + +#ifdef PADDLE_WITH_HIP + HOSTDEVICE inline explicit operator hipFloatComplex() const { + return make_hipFloatComplex(real, imag); + } + + HOSTDEVICE inline explicit operator hipDoubleComplex() const { + return make_hipDoubleComplex(real, imag); + } +#else + HOSTDEVICE inline explicit operator cuFloatComplex() const { + return make_cuFloatComplex(real, imag); + } + + HOSTDEVICE inline explicit operator cuDoubleComplex() const { + return make_cuDoubleComplex(real, imag); + } +#endif +#endif + + template ::value || + std::is_integral::value, + int>::type = 0> + HOSTDEVICE complex(const T1& val) { + real = 
static_cast(val); + imag = static_cast(0.0); + } + + template + HOSTDEVICE explicit complex( + const std::enable_if_t::value, complex>& + val) { + real = val.real; + imag = val.imag; + } + + template + HOSTDEVICE explicit complex( + const std::enable_if_t::value, complex>& + val) { + real = val.real; + imag = val.imag; + } + + template + HOSTDEVICE inline explicit operator std::complex() const { + return static_cast>(std::complex(real, imag)); + } + + template + HOSTDEVICE complex(const std::complex& val) + : real(val.real()), imag(val.imag()) {} + + template ::value || + std::is_integral::value, + int>::type = 0> + HOSTDEVICE inline complex& operator=(const T1& val) { + real = static_cast(val); + imag = static_cast(0.0); + return *this; + } + + HOSTDEVICE inline explicit operator bool() const { + return static_cast(this->real) || static_cast(this->imag); + } + + HOSTDEVICE inline explicit operator int8_t() const { + return static_cast(this->real); + } + + HOSTDEVICE inline explicit operator uint8_t() const { + return static_cast(this->real); + } + + HOSTDEVICE inline explicit operator int16_t() const { + return static_cast(this->real); + } + + HOSTDEVICE inline explicit operator uint16_t() const { + return static_cast(this->real); + } + + HOSTDEVICE inline explicit operator int32_t() const { + return static_cast(this->real); + } + + HOSTDEVICE inline explicit operator uint32_t() const { + return static_cast(this->real); + } + + HOSTDEVICE inline explicit operator int64_t() const { + return static_cast(this->real); + } + + HOSTDEVICE inline explicit operator uint64_t() const { + return static_cast(this->real); + } + + HOSTDEVICE inline explicit operator float() const { + return static_cast(this->real); + } + + HOSTDEVICE inline explicit operator double() const { + return static_cast(this->real); + } +}; + +template +HOSTDEVICE inline complex operator+(const complex& a, + const complex& b) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + 
(defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::complex(a) + thrust::complex(b)); +#else + return complex(a.real + b.real, a.imag + b.imag); +#endif +} + +template +HOSTDEVICE inline complex operator-(const complex& a, + const complex& b) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::complex(a) - thrust::complex(b)); +#else + return complex(a.real - b.real, a.imag - b.imag); +#endif +} + +template +HOSTDEVICE inline complex operator*(const complex& a, + const complex& b) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::complex(a) * thrust::complex(b)); +#else + return complex(a.real * b.real - a.imag * b.imag, + a.imag * b.real + b.imag * a.real); +#endif +} + +template +HOSTDEVICE inline complex operator/(const complex& a, + const complex& b) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::complex(a) / thrust::complex(b)); +#else + T denominator = b.real * b.real + b.imag * b.imag; + return complex((a.real * b.real + a.imag * b.imag) / denominator, + (a.imag * b.real - a.real * b.imag) / denominator); +#endif +} + +template +HOSTDEVICE inline complex operator-(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(-thrust::complex(a.real, a.imag)); +#else + complex res; + res.real = -a.real; + res.imag = -a.imag; + return res; +#endif +} + +template +HOSTDEVICE inline complex& operator+=(complex& a, // NOLINT + const complex& b) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + a = complex(thrust::complex(a.real, a.imag) += + thrust::complex(b.real, b.imag)); + return a; +#else + a.real += b.real; + a.imag += b.imag; + return a; +#endif +} + +template +HOSTDEVICE inline 
complex& operator-=(complex& a, // NOLINT + const complex& b) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + a = complex(thrust::complex(a.real, a.imag) -= + thrust::complex(b.real, b.imag)); + return a; +#else + a.real -= b.real; + a.imag -= b.imag; + return a; +#endif +} + +template +HOSTDEVICE inline complex& operator*=(complex& a, // NOLINT + const complex& b) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + a = complex(thrust::complex(a.real, a.imag) *= + thrust::complex(b.real, b.imag)); + return a; +#else + // a.imag needs the original a.real, so multiply into a temporary first. + a = a * b; + return a; +#endif +} + +template +HOSTDEVICE inline complex& operator/=(complex& a, // NOLINT + const complex& b) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + a = complex(thrust::complex(a.real, a.imag) /= + thrust::complex(b.real, b.imag)); + return a; +#else + // The imaginary part must be computed from the original a.real, so + // divide into a temporary (via operator/) before assigning. + a = a / b; + return a; +#endif +} + +template +HOSTDEVICE inline complex raw_uint16_to_complex64(uint16_t a) { + complex res; + res.real = a; + res.imag = 0.0; + return res; +} + +template +HOSTDEVICE inline bool operator==(const complex& a, const complex& b) { + return a.real == b.real && a.imag == b.imag; +} + +template +HOSTDEVICE inline bool operator!=(const complex& a, const complex& b) { + return a.real != b.real || a.imag != b.imag; +} + +template +HOSTDEVICE inline bool operator<(const complex& a, const complex& b) { + return a.real < b.real; +} + +template +HOSTDEVICE inline bool operator<=(const complex& a, const complex& b) { + return a.real <= b.real; +} + +template +HOSTDEVICE inline bool operator>(const complex& a, const complex& b) { + return a.real > b.real; +} + 
+template +HOSTDEVICE inline bool operator>=(const complex& a, const complex& b) { + return a.real >= b.real; +} + +template +HOSTDEVICE inline complex max(const complex& a, const complex& b) { + return (a.real >= b.real) ? a : b; +} + +template +HOSTDEVICE inline complex min(const complex& a, const complex& b) { + return (a.real < b.real) ? a : b; +} + +template +HOSTDEVICE inline bool(isnan)(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return ::isnan(a.real) || ::isnan(a.imag); +#else + return std::isnan(a.real) || std::isnan(a.imag); +#endif +} + +template +HOSTDEVICE inline bool isinf(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return ::isinf(a.real) || ::isinf(a.imag); +#else + return std::isinf(a.real) || std::isinf(a.imag); +#endif +} + +template +HOSTDEVICE inline bool isfinite(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return ::isfinite(a.real) && ::isfinite(a.imag); // both parts finite +#else + return std::isfinite(a.real) && std::isfinite(a.imag); // both parts finite +#endif +} + +template +HOSTDEVICE inline T abs(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return thrust::abs(thrust::complex(a)); +#else + return std::abs(std::complex(a)); +#endif +} + +template +HOSTDEVICE inline complex pow(const complex& a, const complex& b) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::pow(thrust::complex(a), thrust::complex(b))); +#else + return complex(std::pow(std::complex(a), std::complex(b))); +#endif +} + +template +HOSTDEVICE inline complex sqrt(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::sqrt(thrust::complex(a))); 
+#else + return complex(std::sqrt(std::complex(a))); +#endif +} + +template +HOSTDEVICE inline complex tanh(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::tanh(thrust::complex(a))); +#else + return complex(std::tanh(std::complex(a))); +#endif +} + +template +HOSTDEVICE inline complex log(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::log(thrust::complex(a))); +#else + return complex(std::log(std::complex(a))); +#endif +} + +template +inline std::ostream& operator<<(std::ostream& os, const complex& a) { + os << "real:" << a.real << " imag:" << a.imag; + return os; +} + +} // namespace platform +} // namespace paddle + +namespace std { + +template +struct is_pod> { + static const bool value = true; +}; + +template +struct is_floating_point> + : std::integral_constant {}; + +template +struct is_signed> { + static const bool value = false; +}; + +template +struct is_unsigned> { + static const bool value = false; +}; + +template +inline bool isnan(const paddle::platform::complex& a) { + return paddle::platform::isnan(a); +} + +template +inline bool isinf(const paddle::platform::complex& a) { + return paddle::platform::isinf(a); +} + +template +struct numeric_limits> { + static const bool is_specialized = false; + static const bool is_signed = false; + static const bool is_integer = false; + static const bool is_exact = false; + static const bool has_infinity = false; + static const bool has_quiet_NaN = false; + static const bool has_signaling_NaN = false; + static const float_denorm_style has_denorm = denorm_absent; + static const bool has_denorm_loss = false; + static const std::float_round_style round_style = std::round_toward_zero; + static const bool is_iec559 = false; + static const bool is_bounded = false; + static const bool is_modulo = false; + static const int digits = 
0; + static const int digits10 = 0; + static const int max_digits10 = 0; + static const int radix = 0; + static const int min_exponent = 0; + static const int min_exponent10 = 0; + static const int max_exponent = 0; + static const int max_exponent10 = 0; + static const bool traps = false; + static const bool tinyness_before = false; + + static paddle::platform::complex min() { + return paddle::platform::complex(0.0, 0.0); + } + static paddle::platform::complex lowest() { + return paddle::platform::complex(0.0, 0.0); + } + static paddle::platform::complex max() { + return paddle::platform::complex(0.0, 0.0); + } + static paddle::platform::complex epsilon() { + return paddle::platform::complex(0.0, 0.0); + } + static paddle::platform::complex round_error() { + return paddle::platform::complex(0.0, 0.0); + } + static paddle::platform::complex infinity() { + return paddle::platform::complex(0.0, 0.0); + } + static paddle::platform::complex quiet_NaN() { + return paddle::platform::complex(0.0, 0.0); + } + static paddle::platform::complex signaling_NaN() { + return paddle::platform::complex(0.0, 0.0); + } + static paddle::platform::complex denorm_min() { + return paddle::platform::complex(0.0, 0.0); + } +}; + +} // namespace std diff --git a/paddle/fluid/platform/complex_test.cc b/paddle/fluid/platform/complex_test.cc new file mode 100644 index 0000000000..4d13161e94 --- /dev/null +++ b/paddle/fluid/platform/complex_test.cc @@ -0,0 +1,324 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/complex.h" +#include +#include "paddle/fluid/platform/eigen_ext.h" + +#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +TEST(complex, conversion_cpu) { + // *********** complex ************* + // float to complex + EXPECT_EQ(complex().real, 0.0f); + EXPECT_EQ(complex().imag, 0.0f); + + EXPECT_EQ(complex(1.0f, 1.0f).real, 1.0f); + EXPECT_EQ(complex(1.0f, 1.0f).imag, 1.0f); + EXPECT_EQ(complex(0.0f, 1.0f).real, 0.0f); + EXPECT_EQ(complex(0.0f, 1.0f).imag, 1.0f); + + EXPECT_EQ(complex(1.0f).real, 1.0f); + EXPECT_EQ(complex(1.0f).imag, 0.0f); + + // int to complex + EXPECT_EQ(complex(1).real, 1.0f); + EXPECT_EQ(complex(0).real, 0.0f); + EXPECT_EQ(complex(2).real, 2.0f); + EXPECT_EQ(complex(-2).real, -2.0f); + + // bool to complex + EXPECT_EQ(complex(true).real, 1.0f); + EXPECT_EQ(complex(true).imag, 0.0f); + + // complex to complex + EXPECT_EQ(complex(complex(1.0, 2.0)).real, 1.0f); + EXPECT_EQ(complex(complex(1.0, 2.0)).imag, 2.0f); + + // std::complex to complex + EXPECT_EQ(complex(std::complex(1.0f, 2.0f)).real, 1.0f); + EXPECT_EQ(complex(std::complex(1.0f, 2.0f)).imag, 2.0f); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).real, 1.0f); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).imag, 2.0f); + + // Assignment operator + complex c = 1.0f; + EXPECT_EQ(c.real, 1.0f); + EXPECT_EQ(c.imag, 0.0f); + c = complex(2.0, 2.0); + EXPECT_EQ(c.real, 2.0f); + EXPECT_EQ(c.imag, 2.0f); + + // Conversion operator + EXPECT_EQ(static_cast(complex(0.5f)), 0.5f); + EXPECT_NEAR(static_cast(complex(0.33333)), 0.33333, 0.01); + EXPECT_EQ(static_cast(complex(-1)), -1); + EXPECT_EQ(static_cast(complex(true)), true); + + // *********** complex ************* + // 
double to complex + EXPECT_EQ(complex().real, 0.0); + EXPECT_EQ(complex().imag, 0.0); + + EXPECT_EQ(complex(1.0, 1.0).real, 1.0); + EXPECT_EQ(complex(1.0, 1.0).imag, 1.0); + EXPECT_EQ(complex(0.0, 1.0).real, 0.0); + EXPECT_EQ(complex(0.0, 1.0).imag, 1.0); + + EXPECT_EQ(complex(1.0).real, 1.0); + EXPECT_EQ(complex(1.0).imag, 0.0); + + // int to complex + EXPECT_EQ(complex(1).real, 1.0); + EXPECT_EQ(complex(0).real, 0.0); + EXPECT_EQ(complex(2).real, 2.0); + EXPECT_EQ(complex(-2).real, -2.0); + + // bool to complex + EXPECT_EQ(complex(true).real, 1.0); + EXPECT_EQ(complex(true).imag, 0.0); + + // complex to complex + EXPECT_EQ(complex(complex(1.0f, 2.0f)).real, 1.0); + EXPECT_EQ(complex(complex(1.0f, 2.0f)).imag, 2.0); + + // std::complex to complex + EXPECT_EQ(complex(std::complex(1.0, 2.0)).real, 1.0); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).imag, 2.0); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).real, 1.0); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).imag, 2.0); + + // Assignment operator + complex c1 = 1.0; + EXPECT_EQ(c1.real, 1.0); + EXPECT_EQ(c1.imag, 0.0); + c1 = complex(2.0, 2.0); + EXPECT_EQ(c1.real, 2.0); + EXPECT_EQ(c1.imag, 2.0); + + // Conversion operator + EXPECT_EQ(static_cast(complex(0.5)), 0.5); + EXPECT_NEAR(static_cast(complex(0.33333)), 0.33333, 0.01); + EXPECT_EQ(static_cast(complex(-1)), -1); + EXPECT_EQ(static_cast(complex(true)), true); +} + +TEST(complex, comparison_cpu) { + // *********** complex ************* + EXPECT_TRUE(complex(1.0f) == complex(1.0f)); + EXPECT_TRUE(complex(1.0f, 2.0f) == complex(1.0f, 2.0f)); + EXPECT_FALSE(complex(-1.0f) == complex(-0.5f)); + EXPECT_TRUE(complex(1.0f) != complex(0.5f)); + EXPECT_FALSE(complex(-1.0f) != complex(-1.0f)); + EXPECT_TRUE(complex(1.0f) < complex(2.0f)); + EXPECT_FALSE(complex(-1.0f) < complex(-1.0f)); + EXPECT_TRUE(complex(1.0f) <= complex(1.0f)); + EXPECT_TRUE(complex(2.0f) > complex(1.0f)); + EXPECT_FALSE(complex(-2.0f) > complex(-2.0f)); + EXPECT_TRUE(complex(2.0f) >=
complex(2.0f)); + + // *********** complex ************* + EXPECT_TRUE(complex(1.0) == complex(1.0)); + EXPECT_TRUE(complex(1.0, 2.0) == complex(1.0, 2.0)); + EXPECT_FALSE(complex(-1.0) == complex(-0.5f)); + EXPECT_TRUE(complex(1.0) != complex(0.5f)); + EXPECT_FALSE(complex(-1.0) != complex(-1.0)); + EXPECT_TRUE(complex(1.0) < complex(2.0)); + EXPECT_FALSE(complex(-1.0) < complex(-1.0)); + EXPECT_TRUE(complex(1.0) <= complex(1.0)); + EXPECT_TRUE(complex(2.0) > complex(1.0)); + EXPECT_FALSE(complex(-2.0) > complex(-2.0)); + EXPECT_TRUE(complex(2.0) >= complex(2.0)); +} + +TEST(complex, arithmetic_cpu) { + // *********** complex ************* + complex a = complex(1, 1) + complex(1, 1); + EXPECT_NEAR(a.real, 2, 0.001); + EXPECT_NEAR(a.imag, 2, 0.001); + + complex b = complex(-5, -5) + complex(5, 5); + EXPECT_EQ(b.real, 0); + EXPECT_EQ(b.imag, 0); + + complex c = + complex(0.33333f, 0.33333f) + complex(0.66667f, 0.66667f); + EXPECT_NEAR(c.real, 1.0f, 0.01); + EXPECT_NEAR(c.imag, 1.0f, 0.01); + + complex d = complex(3) - complex(5); + EXPECT_EQ(d.real, -2); + EXPECT_EQ(d.imag, 0); + + complex e = + complex(0.66667f, 0.66667f) - complex(0.33333f, 0.33333f); + EXPECT_NEAR(e.real, 0.33334f, 0.01); + EXPECT_NEAR(e.imag, 0.33334f, 0.01); + + complex f = complex(0.33f, 0.33f) * complex(0.2f, 0.2f); + EXPECT_NEAR(f.real, 0.0f, 0.01); + EXPECT_NEAR(f.imag, 0.132f, 0.01); + + complex g = complex(0.33f, 0.33f) / complex(0.2f, 0.2f); + EXPECT_NEAR(g.real, 1.65f, 0.01); + EXPECT_NEAR(g.imag, 0.0f, 0.01); + + complex h = -complex(0.33f, 0.33f); + EXPECT_NEAR(h.real, -0.33f, 0.01); + EXPECT_NEAR(h.imag, -0.33f, 0.01); + h = -complex(-0.33f, -0.33f); + EXPECT_NEAR(h.real, 0.33f, 0.01); + EXPECT_NEAR(h.imag, 0.33f, 0.01); + + complex i = complex(1.0, 1.0); + i += complex(2.0, 2.0); + EXPECT_NEAR(i.real, 3.0f, 0.01); + EXPECT_NEAR(i.imag, 3.0f, 0.01); + i -= complex(1.0, 1.0); + EXPECT_NEAR(i.real, 2.0f, 0.01); + EXPECT_NEAR(i.imag, 2.0f, 0.01); + i *= complex(3, 2); + 
EXPECT_NEAR(i.real, 2.0f, 0.01); + EXPECT_NEAR(i.imag, 10.0f, 0.01); + i /= complex(3, 2); + EXPECT_NEAR(i.real, 2.0f, 0.01); + EXPECT_NEAR(i.imag, 2.0f, 0.01); + + // *********** complex ************* + complex a1 = complex(1, 1) + complex(1, 1); + EXPECT_NEAR(a1.real, 2, 0.001); + EXPECT_NEAR(a1.imag, 2, 0.001); + + complex b1 = complex(-5, -5) + complex(5, 5); + EXPECT_EQ(b1.real, 0); + EXPECT_EQ(b1.imag, 0); + + complex c1 = + complex(0.33333f, 0.33333f) + complex(0.66667f, 0.66667f); + EXPECT_NEAR(c1.real, 1.0f, 0.01); + EXPECT_NEAR(c1.imag, 1.0f, 0.01); + + complex d1 = complex(3) - complex(5); + EXPECT_EQ(d1.real, -2); + EXPECT_EQ(d1.imag, 0); + + complex e1 = + complex(0.66667f, 0.66667f) - complex(0.33333f, 0.33333f); + EXPECT_NEAR(e1.real, 0.33334f, 0.01); + EXPECT_NEAR(e1.imag, 0.33334f, 0.01); + + complex f1 = + complex(0.33f, 0.33f) * complex(0.2f, 0.2f); + EXPECT_NEAR(f1.real, 0.0f, 0.01); + EXPECT_NEAR(f1.imag, 0.132f, 0.01); + + complex g1 = + complex(0.33f, 0.33f) / complex(0.2f, 0.2f); + EXPECT_NEAR(g1.real, 1.65f, 0.01); + EXPECT_NEAR(g1.imag, 0.0f, 0.01); + + complex h1 = -complex(0.33f, 0.33f); + EXPECT_NEAR(h1.real, -0.33f, 0.01); + EXPECT_NEAR(h1.imag, -0.33f, 0.01); + h1 = -complex(-0.33f, -0.33f); + EXPECT_NEAR(h1.real, 0.33f, 0.01); + EXPECT_NEAR(h1.imag, 0.33f, 0.01); + + complex i1 = complex(1.0, 1.0); + i1 += complex(2.0, 2.0); + EXPECT_NEAR(i1.real, 3.0f, 0.01); + EXPECT_NEAR(i1.imag, 3.0f, 0.01); + i1 -= complex(1.0, 1.0); + EXPECT_NEAR(i1.real, 2.0f, 0.01); + EXPECT_NEAR(i1.imag, 2.0f, 0.01); + i1 *= complex(3, 2); + EXPECT_NEAR(i1.real, 2.0f, 0.01); + EXPECT_NEAR(i1.imag, 10.0f, 0.01); + i1 /= complex(3, 2); + EXPECT_NEAR(i1.real, 2.0f, 0.01); + EXPECT_NEAR(i1.imag, 2.0f, 0.01); +} + +TEST(complex, print) { + complex a(1.0f); + std::cout << a << std::endl; + + complex b(1.0); + std::cout << b << std::endl; +} + +TEST(complex, isinf) { + // *********** complex ************* + complex a; + a.real = float(INFINITY); + 
EXPECT_EQ(std::isinf(a), true); + a.imag = float(INFINITY); + EXPECT_EQ(std::isinf(a), true); + + complex b = float(INFINITY); + EXPECT_EQ(std::isinf(b), true); + + complex c(float(INFINITY), 0); + EXPECT_EQ(std::isinf(c), true); + + // *********** complex ************* + complex a1; + a1.real = double(INFINITY); + EXPECT_EQ(std::isinf(a1), true); + a1.imag = double(INFINITY); + EXPECT_EQ(std::isinf(a1), true); + + complex b1 = double(INFINITY); + EXPECT_EQ(std::isinf(b1), true); + + complex c1(double(INFINITY), 0); + EXPECT_EQ(std::isinf(c1), true); +} + +TEST(complex, isnan) { + // *********** complex ************* + complex a; + a.real = float(NAN); + EXPECT_EQ(std::isnan(a), true); + a.imag = float(NAN); + EXPECT_EQ(std::isnan(a), true); + + complex b = float(NAN); + EXPECT_EQ(std::isnan(b), true); + + complex c(float(NAN), 0); + EXPECT_EQ(std::isnan(c), true); + + // *********** complex ************* + complex a1; + a1.real = double(NAN); + EXPECT_EQ(std::isnan(a1), true); + a1.imag = double(NAN); + EXPECT_EQ(std::isnan(a1), true); + + complex b1 = double(NAN); + EXPECT_EQ(std::isnan(b1), true); + + complex c1(double(NAN), 0); + EXPECT_EQ(std::isnan(c1), true); +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/complex_test.cu b/paddle/fluid/platform/complex_test.cu new file mode 100644 index 0000000000..b46d1b7b27 --- /dev/null +++ b/paddle/fluid/platform/complex_test.cu @@ -0,0 +1,361 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/complex.h" + +#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/eigen_ext.h" +#include "paddle/fluid/platform/enforce.h" + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +namespace paddle { +namespace platform { + +TEST(complex, conversion_on_gpu) { + // *********** complex ************* + // thrust from and to complex + complex a(1.0f, 2.0f); + EXPECT_EQ(complex(thrust::complex(a)).real, 1.0); + EXPECT_EQ(complex(thrust::complex(a)).imag, 2.0); + + complex a1(1.0, 2.0); + EXPECT_EQ(complex(thrust::complex(a1)).real, 1.0); + EXPECT_EQ(complex(thrust::complex(a1)).imag, 2.0); + +#if defined(PADDLE_WITH_HIP) + EXPECT_EQ(hipFloatComplex(a).real(), 1.0); + EXPECT_EQ(hipFloatComplex(a).imag(), 2.0); + EXPECT_EQ(hipDoubleComplex(a).real(), 1.0); + EXPECT_EQ(hipDoubleComplex(a).imag(), 2.0); + + EXPECT_EQ(hipFloatComplex(a1).real(), 1.0); + EXPECT_EQ(hipFloatComplex(a1).imag(), 2.0); + EXPECT_EQ(hipDoubleComplex(a1).real(), 1.0); + EXPECT_EQ(hipDoubleComplex(a1).imag(), 2.0); +#else + EXPECT_EQ(cuCrealf(cuFloatComplex(a)), 1.0); + EXPECT_EQ(cuCimagf(cuFloatComplex(a)), 2.0); + EXPECT_EQ(cuCreal(cuDoubleComplex(a)), 1.0); + EXPECT_EQ(cuCimag(cuDoubleComplex(a)), 2.0); + + EXPECT_EQ(cuCrealf(cuFloatComplex(a1)), 1.0); + EXPECT_EQ(cuCimagf(cuFloatComplex(a1)), 2.0); + EXPECT_EQ(cuCreal(cuDoubleComplex(a1)), 1.0); + EXPECT_EQ(cuCimag(cuDoubleComplex(a1)), 2.0); +#endif + + EXPECT_EQ(complex().real, 0.0f); + EXPECT_EQ(complex().imag, 0.0f); + + EXPECT_EQ(complex(1.0f, 1.0f).real, 1.0f); + EXPECT_EQ(complex(1.0f, 1.0f).imag, 1.0f); + EXPECT_EQ(complex(0.0f, 1.0f).real, 0.0f); + EXPECT_EQ(complex(0.0f, 1.0f).imag, 1.0f); 
+ + EXPECT_EQ(complex(1.0f).real, 1.0f); + EXPECT_EQ(complex(1.0f).imag, 0.0f); + + // int to complex + EXPECT_EQ(complex(1).real, 1.0f); + EXPECT_EQ(complex(0).real, 0.0f); + EXPECT_EQ(complex(2).real, 2.0f); + EXPECT_EQ(complex(-2).real, -2.0f); + + // bool to complex + EXPECT_EQ(complex(true).real, 1.0f); + EXPECT_EQ(complex(true).imag, 0.0f); + + // complex to complex + EXPECT_EQ(complex(complex(1.0, 2.0)).real, 1.0f); + EXPECT_EQ(complex(complex(1.0, 2.0)).imag, 2.0f); + + // std::complex to complex + EXPECT_EQ(complex(std::complex(1.0f, 2.0f)).real, 1.0f); + EXPECT_EQ(complex(std::complex(1.0f, 2.0f)).imag, 2.0f); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).real, 1.0f); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).imag, 2.0f); + + // Assignment operator + complex c = 1.0f; + EXPECT_EQ(c.real, 1.0f); + EXPECT_EQ(c.imag, 0.0f); + c = complex(2.0, 2.0); + EXPECT_EQ(c.real, 2.0f); + EXPECT_EQ(c.imag, 2.0f); + + // Conversion operator + EXPECT_EQ(static_cast(complex(0.5f)), 0.5f); + EXPECT_NEAR(static_cast(complex(0.33333)), 0.33333, 0.01); + EXPECT_EQ(static_cast(complex(-1)), -1); + EXPECT_EQ(static_cast(complex(true)), true); + + // *********** complex ************* + // double to complex + EXPECT_EQ(complex().real, 0.0); + EXPECT_EQ(complex().imag, 0.0); + + EXPECT_EQ(complex(1.0, 1.0).real, 1.0); + EXPECT_EQ(complex(1.0, 1.0).imag, 1.0); + EXPECT_EQ(complex(0.0, 1.0).real, 0.0); + EXPECT_EQ(complex(0.0, 1.0).imag, 1.0); + + EXPECT_EQ(complex(1.0).real, 1.0); + EXPECT_EQ(complex(1.0).imag, 0.0); + + // int to complex + EXPECT_EQ(complex(1).real, 1.0); + EXPECT_EQ(complex(0).real, 0.0); + EXPECT_EQ(complex(2).real, 2.0); + EXPECT_EQ(complex(-2).real, -2.0); + + // bool to complex + EXPECT_EQ(complex(true).real, 1.0); + EXPECT_EQ(complex(true).imag, 0.0); + + // complex to complex + EXPECT_EQ(complex(complex(1.0f, 2.0f)).real, 1.0); + EXPECT_EQ(complex(complex(1.0f, 2.0f)).imag, 2.0); + + // std::complex to complex + EXPECT_EQ(complex(std::complex(1.0, 
2.0)).real, 1.0); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).imag, 2.0); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).real, 1.0); + EXPECT_EQ(complex(std::complex(1.0, 2.0)).imag, 2.0); + + // Assignment operator + complex c1 = 1.0; + EXPECT_EQ(c1.real, 1.0); + EXPECT_EQ(c1.imag, 0.0); + c1 = complex(2.0, 2.0); + EXPECT_EQ(c1.real, 2.0); + EXPECT_EQ(c1.imag, 2.0); + + // Conversion operator + EXPECT_EQ(static_cast(complex(0.5)), 0.5); + EXPECT_NEAR(static_cast(complex(0.33333)), 0.33333, 0.01); + EXPECT_EQ(static_cast(complex(-1)), -1); + EXPECT_EQ(static_cast(complex(true)), true); +} + +TEST(complex, comparison_cpu) { + // *********** complex ************* + EXPECT_TRUE(complex(1.0f) == complex(1.0f)); + EXPECT_TRUE(complex(1.0f, 2.0f) == complex(1.0f, 2.0f)); + EXPECT_FALSE(complex(-1.0f) == complex(-0.5f)); + EXPECT_TRUE(complex(1.0f) != complex(0.5f)); + EXPECT_FALSE(complex(-1.0f) != complex(-1.0f)); + EXPECT_TRUE(complex(1.0f) < complex(2.0f)); + EXPECT_FALSE(complex(-1.0f) < complex(-1.0f)); + EXPECT_TRUE(complex(1.0f) <= complex(1.0f)); + EXPECT_TRUE(complex(2.0f) > complex(1.0f)); + EXPECT_FALSE(complex(-2.0f) > complex(-2.0f)); + EXPECT_TRUE(complex(2.0f) >= complex(2.0f)); + + // *********** complex ************* + EXPECT_TRUE(complex(1.0) == complex(1.0)); + EXPECT_TRUE(complex(1.0, 2.0) == complex(1.0, 2.0)); + EXPECT_FALSE(complex(-1.0) == complex(-0.5f)); + EXPECT_TRUE(complex(1.0) != complex(0.5f)); + EXPECT_FALSE(complex(-1.0) != complex(-1.0)); + EXPECT_TRUE(complex(1.0) < complex(2.0)); + EXPECT_FALSE(complex(-1.0) < complex(-1.0)); + EXPECT_TRUE(complex(1.0) <= complex(1.0)); + EXPECT_TRUE(complex(2.0) > complex(1.0)); + EXPECT_FALSE(complex(-2.0) > complex(-2.0)); + EXPECT_TRUE(complex(2.0) >= complex(2.0)); +} + +TEST(complex, arithmetic_cpu) { + // *********** complex ************* + complex a = complex(1, 1) + complex(1, 1); + EXPECT_NEAR(a.real, 2, 0.001); + EXPECT_NEAR(a.imag, 2, 0.001); + + complex b = complex(-5, -5) + complex(5, 5); +
EXPECT_EQ(b.real, 0); + EXPECT_EQ(b.imag, 0); + + complex c = + complex(0.33333f, 0.33333f) + complex(0.66667f, 0.66667f); + EXPECT_NEAR(c.real, 1.0f, 0.01); + EXPECT_NEAR(c.imag, 1.0f, 0.01); + + complex d = complex(3) - complex(5); + EXPECT_EQ(d.real, -2); + EXPECT_EQ(d.imag, 0); + + complex e = + complex(0.66667f, 0.66667f) - complex(0.33333f, 0.33333f); + EXPECT_NEAR(e.real, 0.33334f, 0.01); + EXPECT_NEAR(e.imag, 0.33334f, 0.01); + + complex f = complex(0.33f, 0.33f) * complex(0.2f, 0.2f); + EXPECT_NEAR(f.real, 0.0f, 0.01); + EXPECT_NEAR(f.imag, 0.132f, 0.01); + + complex g = complex(0.33f, 0.33f) / complex(0.2f, 0.2f); + EXPECT_NEAR(g.real, 1.65f, 0.01); + EXPECT_NEAR(g.imag, 0.0f, 0.01); + + complex h = -complex(0.33f, 0.33f); + EXPECT_NEAR(h.real, -0.33f, 0.01); + EXPECT_NEAR(h.imag, -0.33f, 0.01); + h = -complex(-0.33f, -0.33f); + EXPECT_NEAR(h.real, 0.33f, 0.01); + EXPECT_NEAR(h.imag, 0.33f, 0.01); + + complex i = complex(1.0, 1.0); + i += complex(2.0, 2.0); + EXPECT_NEAR(i.real, 3.0f, 0.01); + EXPECT_NEAR(i.imag, 3.0f, 0.01); + i -= complex(1.0, 1.0); + EXPECT_NEAR(i.real, 2.0f, 0.01); + EXPECT_NEAR(i.imag, 2.0f, 0.01); + i *= complex(3, 2); + EXPECT_NEAR(i.real, 2.0f, 0.01); + EXPECT_NEAR(i.imag, 10.0f, 0.01); + i /= complex(3, 2); + EXPECT_NEAR(i.real, 2.0f, 0.01); + EXPECT_NEAR(i.imag, 2.0f, 0.01); + + // *********** complex ************* + complex a1 = complex(1, 1) + complex(1, 1); + EXPECT_NEAR(a1.real, 2, 0.001); + EXPECT_NEAR(a1.imag, 2, 0.001); + + complex b1 = complex(-5, -5) + complex(5, 5); + EXPECT_EQ(b1.real, 0); + EXPECT_EQ(b1.imag, 0); + + complex c1 = + complex(0.33333f, 0.33333f) + complex(0.66667f, 0.66667f); + EXPECT_NEAR(c1.real, 1.0f, 0.01); + EXPECT_NEAR(c1.imag, 1.0f, 0.01); + + complex d1 = complex(3) - complex(5); + EXPECT_EQ(d1.real, -2); + EXPECT_EQ(d1.imag, 0); + + complex e1 = + complex(0.66667f, 0.66667f) - complex(0.33333f, 0.33333f); + EXPECT_NEAR(e1.real, 0.33334f, 0.01); + EXPECT_NEAR(e1.imag, 0.33334f, 0.01); + + 
complex f1 = + complex(0.33f, 0.33f) * complex(0.2f, 0.2f); + EXPECT_NEAR(f1.real, 0.0f, 0.01); + EXPECT_NEAR(f1.imag, 0.132f, 0.01); + + complex g1 = + complex(0.33f, 0.33f) / complex(0.2f, 0.2f); + EXPECT_NEAR(g1.real, 1.65f, 0.01); + EXPECT_NEAR(g1.imag, 0.0f, 0.01); + + complex h1 = -complex(0.33f, 0.33f); + EXPECT_NEAR(h1.real, -0.33f, 0.01); + EXPECT_NEAR(h1.imag, -0.33f, 0.01); + h1 = -complex(-0.33f, -0.33f); + EXPECT_NEAR(h1.real, 0.33f, 0.01); + EXPECT_NEAR(h1.imag, 0.33f, 0.01); + + complex i1 = complex(1.0, 1.0); + i1 += complex(2.0, 2.0); + EXPECT_NEAR(i1.real, 3.0f, 0.01); + EXPECT_NEAR(i1.imag, 3.0f, 0.01); + i1 -= complex(1.0, 1.0); + EXPECT_NEAR(i1.real, 2.0f, 0.01); + EXPECT_NEAR(i1.imag, 2.0f, 0.01); + i1 *= complex(3, 2); + EXPECT_NEAR(i1.real, 2.0f, 0.01); + EXPECT_NEAR(i1.imag, 10.0f, 0.01); + i1 /= complex(3, 2); + EXPECT_NEAR(i1.real, 2.0f, 0.01); + EXPECT_NEAR(i1.imag, 2.0f, 0.01); +} + +TEST(complex, print) { + complex a(1.0f); + std::cout << a << std::endl; + + complex b(1.0); + std::cout << b << std::endl; +} + +TEST(complex, isinf) { + // *********** complex ************* + complex a; + a.real = float(INFINITY); + EXPECT_EQ(std::isinf(a), true); + a.imag = float(INFINITY); + EXPECT_EQ(std::isinf(a), true); + + complex b = float(INFINITY); + EXPECT_EQ(std::isinf(b), true); + + complex c(float(INFINITY), 0); + EXPECT_EQ(std::isinf(c), true); + + // *********** complex ************* + complex a1; + a1.real = double(INFINITY); + EXPECT_EQ(std::isinf(a1), true); + a1.imag = double(INFINITY); + EXPECT_EQ(std::isinf(a1), true); + + complex b1 = double(INFINITY); + EXPECT_EQ(std::isinf(b1), true); + + complex c1(double(INFINITY), 0); + EXPECT_EQ(std::isinf(c1), true); +} + +TEST(complex, isnan) { + // *********** complex ************* + complex a; + a.real = float(NAN); + EXPECT_EQ(std::isnan(a), true); + a.imag = float(NAN); + EXPECT_EQ(std::isnan(a), true); + + complex b = float(NAN); + EXPECT_EQ(std::isnan(b), true); + + complex 
c(float(NAN), 0); + EXPECT_EQ(std::isnan(c), true); + + // *********** complex ************* + complex a1; + a1.real = double(NAN); + EXPECT_EQ(std::isnan(a1), true); + a1.imag = double(NAN); + EXPECT_EQ(std::isnan(a1), true); + + complex b1 = double(NAN); + EXPECT_EQ(std::isnan(b1), true); + + complex c1(double(NAN), 0); + EXPECT_EQ(std::isnan(c1), true); +} + +} // namespace platform +} // namespace paddle +#endif \ No newline at end of file diff --git a/paddle/fluid/platform/eigen_ext.h b/paddle/fluid/platform/eigen_ext.h index 0db4cc71b1..4eea87e909 100644 --- a/paddle/fluid/platform/eigen_ext.h +++ b/paddle/fluid/platform/eigen_ext.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/platform/bfloat16.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex128.h" #include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/float16.h" @@ -27,6 +28,8 @@ namespace Eigen { using complex64 = paddle::platform::complex64; using complex128 = paddle::platform::complex128; using float16 = paddle::platform::float16; +template +using complex = paddle::platform::complex; template struct NumTraits; @@ -105,6 +108,50 @@ struct NumTraits : GenericNumTraits> { static inline int digits10() { return NumTraits::digits10(); } }; +template <> +struct NumTraits> : GenericNumTraits> { + typedef float Real; + typedef typename NumTraits::Literal Literal; + enum { + IsComplex = 1, + RequireInitialization = NumTraits::RequireInitialization, + ReadCost = 2 * NumTraits::ReadCost, + AddCost = 2 * NumTraits::AddCost, + MulCost = 4 * NumTraits::MulCost + 2 * NumTraits::AddCost + }; + + EIGEN_DEVICE_FUNC + static inline Real epsilon() { return NumTraits::epsilon(); } + EIGEN_DEVICE_FUNC + static inline Real dummy_precision() { + return NumTraits::dummy_precision(); + } + EIGEN_DEVICE_FUNC + static inline int digits10() { return NumTraits::digits10(); } +}; + +template <> +struct NumTraits> : GenericNumTraits> { + typedef double Real; + typedef 
typename NumTraits::Literal Literal; + enum { + IsComplex = 1, + RequireInitialization = NumTraits::RequireInitialization, + ReadCost = 2 * NumTraits::ReadCost, + AddCost = 2 * NumTraits::AddCost, + MulCost = 4 * NumTraits::MulCost + 2 * NumTraits::AddCost + }; + + EIGEN_DEVICE_FUNC + static inline Real epsilon() { return NumTraits::epsilon(); } + EIGEN_DEVICE_FUNC + static inline Real dummy_precision() { + return NumTraits::dummy_precision(); + } + EIGEN_DEVICE_FUNC + static inline int digits10() { return NumTraits::digits10(); } +}; + template <> struct NumTraits : GenericNumTraits { enum { @@ -354,6 +401,138 @@ HOSTDEVICE inline double abs(const complex128& a) { return paddle::platform::abs(a); } +//////////// complex methods ///////////// + +template <> +HOSTDEVICE inline bool(isnan)(const complex& a) { + return (paddle::platform::isnan)(a); +} + +template <> +HOSTDEVICE inline bool(isinf)(const complex& a) { + return (paddle::platform::isinf)(a); +} + +template <> +HOSTDEVICE inline bool(isfinite)(const complex& a) { + return (paddle::platform::isfinite)(a); +} + +template <> +HOSTDEVICE inline complex exp(const complex& a) { + float com = ::expf(a.real); + float res_real = com * ::cosf(a.imag); + float res_imag = com * ::sinf(a.imag); + return complex(res_real, res_imag); +} + +template <> +HOSTDEVICE inline complex log(const complex& a) { + return paddle::platform::log(a); +} + +template <> +HOSTDEVICE inline complex tanh(const complex& a) { + return paddle::platform::tanh(a); +} + +template <> +HOSTDEVICE inline complex sqrt(const complex& a) { + return paddle::platform::sqrt(a); +} + +template <> +HOSTDEVICE inline complex ceil(const complex& a) { + return complex(::ceilf(a.real), ::ceilf(a.imag)); +} + +template <> +HOSTDEVICE inline complex floor(const complex& a) { + return complex(::floorf(a.real), ::floorf(a.imag)); +} + +template <> +HOSTDEVICE inline complex round(const complex& a) { + return complex(::roundf(a.real), ::roundf(a.imag)); +} +
+template <> +HOSTDEVICE inline complex pow(const complex& a, + const complex& b) { + return paddle::platform::pow(a, b); +} + +template <> +HOSTDEVICE inline float abs(const complex& a) { + return paddle::platform::abs(a); +} + +//////////// complex methods ///////////// + +template <> +HOSTDEVICE inline bool(isnan)(const complex& a) { + return (paddle::platform::isnan)(a); +} + +template <> +HOSTDEVICE inline bool(isinf)(const complex& a) { + return (paddle::platform::isinf)(a); +} + +template <> +HOSTDEVICE inline bool(isfinite)(const complex& a) { + return (paddle::platform::isfinite)(a); +} + +template <> +HOSTDEVICE inline complex exp(const complex& a) { + double com = ::exp(a.real); // double-precision math + double res_real = com * ::cos(a.imag); + double res_imag = com * ::sin(a.imag); + return complex(res_real, res_imag); +} + +template <> +HOSTDEVICE inline complex log(const complex& a) { + return paddle::platform::log(a); +} + +template <> +HOSTDEVICE inline complex tanh(const complex& a) { + return paddle::platform::tanh(a); +} + +template <> +HOSTDEVICE inline complex sqrt(const complex& a) { + return paddle::platform::sqrt(a); +} + +template <> +HOSTDEVICE inline complex ceil(const complex& a) { + return complex(::ceil(a.real), ::ceil(a.imag)); +} + +template <> +HOSTDEVICE inline complex floor(const complex& a) { + return complex(::floor(a.real), ::floor(a.imag)); +} + +template <> +HOSTDEVICE inline complex round(const complex& a) { + return complex(::round(a.real), ::round(a.imag)); +} + +template <> +HOSTDEVICE inline complex pow(const complex& a, + const complex& b) { + return paddle::platform::pow(a, b); +} + +template <> +HOSTDEVICE inline double abs(const complex& a) { + return paddle::platform::abs(a); +} + //////////// float16 methods ///////////// template <> diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 416361d06a..2095b49974 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -122,6
+122,43 @@ struct npy_format_descriptor { static constexpr auto name = _("complext128"); }; +// we register paddle::platform::complex64 as numpy.complex64. +template <> +struct npy_format_descriptor> { + static py::dtype dtype() { + handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX64); + return reinterpret_borrow(ptr); + } + + static std::string format() { + // Note: "F" represents complex64. + // Details at: + // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx + // for k, v in np.sctypeDict.iteritems(): + // print '{0:14s} : {1:40s}'.format(str(k), v) + return "F"; + } + static constexpr auto name = _("complext64"); +}; + +template <> +struct npy_format_descriptor> { + static py::dtype dtype() { + handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX128); + return reinterpret_borrow(ptr); + } + + static std::string format() { + // Note: "D" represents complex128. + // Details at: + // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx + // for k, v in np.sctypeDict.iteritems(): + // print '{0:14s} : {1:40s}'.format(str(k), v) + return "D"; + } + static constexpr auto name = _("complext128"); +}; + } // namespace detail } // namespace pybind11 @@ -170,6 +207,8 @@ DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex64); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex128); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex); DECLARE_VALID_DTYPE_TO_PY_ARRAY(float); DECLARE_VALID_DTYPE_TO_PY_ARRAY(double); DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool); @@ -192,6 +231,10 @@ inline std::string TensorDTypeToPyDTypeStr( return "F"; \ } else if (std::is_same::value) { \ return "D"; \ + } else if (std::is_same>::value) { \ + return "F"; \ + } else if 
(std::is_same>::value) { \ + return "D"; \ } else { \ constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker::kValue; \ PADDLE_ENFORCE_EQ( \ @@ -373,6 +416,14 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj, } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); + } else if (py::isinstance>>( + array)) { + SetTensorFromPyArrayT, P>( + self, array, place, zero_copy); + } else if (py::isinstance>>( + array)) { + SetTensorFromPyArrayT, P>( + self, array, place, zero_copy); } else if (py::isinstance>(array)) { // since there is still no support for bfloat16 in NumPy, // uint16 is used for casting bfloat16 diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index c51080e4e3..cb0581d671 100644 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -421,6 +421,7 @@ CPU_PARALLEL_JOB = [ 'buffered_allocator_test', 'broadcast_op_test', 'bfloat16_test', + 'complex_test', 'beam_search_decode_op_test', 'auto_growth_best_fit_allocator_test', 'assign_op_test', -- GitLab