未验证 提交 738bf20e 编写于 作者: C chentianyu03 提交者: GitHub

Add complex template type (#32857)

* add complex template file

* add numtraits for complex template

* add complex template type register

* modify specify template of complex

* modify specify template of complex

* modify specify template of complex

* modify specify template of complex

* make TensorCheckerVisitor support complex type

* fix operator= error

* add complex template

* add complex template type

* add complex template type to pyarray transform

* add complex template type to pyarray transform

* remove complex type for dlpack register

* set dlpack supprot complex type

* set dlpack supprot complex type

* set dlpack supprot complex type

* remove explict for complex constructor

* add complex unit test file
上级 8854786a
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/eigen_ext.h"
......@@ -30,6 +31,8 @@ struct bfloat16;
struct complex128;
struct complex64;
struct float16;
template <typename T>
struct complex;
} // namespace platform
} // namespace paddle
......@@ -61,6 +64,10 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>, \
COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
......@@ -69,6 +76,10 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>, \
COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
COMPLEX128); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
......
......@@ -163,6 +163,11 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex128 : omp_out += \
omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex < \
float > : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex < \
double > : omp_out += omp_in)
#endif
template <typename T>
......@@ -268,12 +273,69 @@ void CheckNanInf<paddle::platform::complex128>(
op_type));
}
}
template <>
void CheckNanInf<paddle::platform::complex<float>>(
const paddle::platform::complex<float>* value, const size_t numel,
int print_num, const std::string& op_type, const std::string& var_name) {
float real_sum = 0.0f;
#pragma omp parallel for reduction(+ : real_sum)
for (size_t i = 0; i < numel; ++i) {
real_sum += (value[i].real - value[i].real);
}
float imag_sum = 0.0f;
#pragma omp parallel for reduction(+ : imag_sum)
for (size_t i = 0; i < numel; ++i) {
imag_sum += (value[i].imag - value[i].imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// hot fix for compile failed in gcc4.8
// here also need print detail info of nan or inf later
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
op_type));
}
}
template <>
void CheckNanInf<paddle::platform::complex<double>>>
(const paddle::platform::complex<double>* value, const size_t numel,
int print_num, const std::string& op_type, const std::string& var_name) {
double real_sum = 0.0;
#pragma omp parallel for reduction(+ : real_sum)
for (size_t i = 0; i < numel; ++i) {
real_sum += (value[i].real - value[i].real);
}
double imag_sum = 0.0;
#pragma omp parallel for reduction(+ : imag_sum)
for (size_t i = 0; i < numel; ++i) {
imag_sum += (value[i].imag - value[i].imag);
}
if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) ||
std::isinf(imag_sum)) {
// hot fix for compile failed in gcc4.8
// here also need print detail info of nan or inf later
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
op_type));
}
}
#endif
template <>
template <typename T>
void TensorCheckerVisitor<platform::CPUDeviceContext>::apply(
typename std::enable_if<std::is_floating_point<T>::value>::type*) const {
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
// use env strategy control in future, -1=print_all.
int print_num = 3;
CheckNanInf(tensor_.data<T>(), tensor_.numel(), print_num, op_type_,
......
......@@ -123,7 +123,11 @@ __global__ void CheckNanInfKernel(const T* value, const size_t numel,
template <>
template <typename T>
void TensorCheckerVisitor<platform::CUDADeviceContext>::apply(
typename std::enable_if<std::is_floating_point<T>::value>::type*) const {
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type*)
const {
int print_num = 3;
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
......
......@@ -46,8 +46,12 @@ struct TensorCheckerVisitor {
}
template <typename T>
void apply(typename std::enable_if<std::is_floating_point<T>::value>::type* =
0) const;
void apply(
typename std::enable_if<
std::is_floating_point<T>::value ||
std::is_same<T, ::paddle::platform::complex<float>>::value ||
std::is_same<T, ::paddle::platform::complex<double>>::value>::type* =
0) const;
std::string op_type_;
std::string var_name_;
......
......@@ -28,9 +28,19 @@ namespace internal {
template <typename T>
static ::DLDataType GetDLDataTypeCode() {
::DLDataType dtype;
if (std::is_same<T, platform::float16>::value ||
std::is_same<T, platform::bfloat16>::value ||
std::is_floating_point<T>::value) {
if (std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value ||
std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value) {
// The current dlpack library version is v0.2, and does not define
// kDLComplex value. But kDLComplex is defined by 5U in v0.4, so we set
// dtype.code to 5U directly here. After the dlpack library version being
// upgraded to v0.4, it should be written as follow.
// dtype.code = kDLComplex;
dtype.code = 5U;
} else if (std::is_same<T, platform::float16>::value ||
std::is_same<T, platform::bfloat16>::value ||
std::is_floating_point<T>::value) {
dtype.code = kDLFloat;
} else if (std::is_unsigned<T>::value) {
dtype.code = kDLUInt;
......
......@@ -28,6 +28,13 @@ namespace framework {
namespace { // NOLINT
template <typename T>
constexpr uint8_t GetDLDataTypeCode() {
if (std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value ||
std::is_same<T, platform::complex64>::value ||
std::is_same<T, platform::complex128>::value) {
return static_cast<uint8_t>(5);
}
return std::is_same<platform::float16, T>::value ||
std::is_floating_point<T>::value
? static_cast<uint8_t>(kDLFloat)
......
......@@ -65,16 +65,18 @@ class SplitFunctor {
} // namespace operators
} // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16); \
macro(::paddle::platform::complex64); \
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16); \
macro(::paddle::platform::complex<float>); \
macro(::paddle::platform::complex<double>); \
macro(::paddle::platform::complex64); \
macro(::paddle::platform::complex128)
......@@ -47,6 +47,10 @@ template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
template struct SetConstant<platform::CPUDeviceContext, platform::complex64>;
template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
template struct SetConstant<platform::CPUDeviceContext,
platform::complex<float>>;
template struct SetConstant<platform::CPUDeviceContext,
platform::complex<double>>;
#ifdef PADDLE_WITH_XPU
template struct SetConstant<platform::XPUDeviceContext, platform::float16>;
......@@ -59,6 +63,10 @@ template struct SetConstant<platform::XPUDeviceContext, int64_t>;
template struct SetConstant<platform::XPUDeviceContext, bool>;
template struct SetConstant<platform::XPUDeviceContext, platform::complex64>;
template struct SetConstant<platform::XPUDeviceContext, platform::complex128>;
template struct SetConstant<platform::XPUDeviceContext,
platform::complex<float>>;
template struct SetConstant<platform::XPUDeviceContext,
platform::complex<double>>;
#endif
#define DEFINE_CPU_TRANS(RANK) \
......@@ -74,6 +82,10 @@ template struct SetConstant<platform::XPUDeviceContext, platform::complex128>;
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, \
platform::complex<float>, RANK>; \
template struct Transpose<platform::CPUDeviceContext, \
platform::complex<double>, RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::complex64, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::complex128, \
......@@ -130,6 +142,8 @@ DEFINE_CPU_TRANS_NORMAL(uint8_t);
DEFINE_CPU_TRANS_NORMAL(int8_t);
DEFINE_CPU_TRANS_NORMAL(platform::complex64);
DEFINE_CPU_TRANS_NORMAL(platform::complex128);
DEFINE_CPU_TRANS_NORMAL(platform::complex<float>);
DEFINE_CPU_TRANS_NORMAL(platform::complex<double>);
struct TensorSetConstantCPU {
TensorSetConstantCPU(framework::Tensor* tensor, float value)
......
......@@ -43,6 +43,10 @@ template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
template struct SetConstant<platform::CUDADeviceContext, platform::complex64>;
template struct SetConstant<platform::CUDADeviceContext, platform::complex128>;
template struct SetConstant<platform::CUDADeviceContext,
platform::complex<float>>;
template struct SetConstant<platform::CUDADeviceContext,
platform::complex<double>>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
......@@ -52,6 +56,10 @@ template struct SetConstant<platform::CUDADeviceContext, platform::complex128>;
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, \
paddle::platform::complex<float>, RANK>; \
template struct Transpose<platform::CUDADeviceContext, \
paddle::platform::complex<double>, RANK>; \
template struct Transpose<platform::CUDADeviceContext, complex64, RANK>; \
template struct Transpose<platform::CUDADeviceContext, complex128, RANK>;
......@@ -145,6 +153,8 @@ DEFINE_GPU_TRANS_NORMAL(uint8_t);
DEFINE_GPU_TRANS_NORMAL(int8_t);
DEFINE_GPU_TRANS_NORMAL(complex64);
DEFINE_GPU_TRANS_NORMAL(complex128);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<float>);
DEFINE_GPU_TRANS_NORMAL(paddle::platform::complex<double>);
struct TensorSetConstantGPU {
TensorSetConstantGPU(const platform::DeviceContext& context,
......
......@@ -187,10 +187,12 @@ endif()
cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
cc_test(bfloat16_test SRCS bfloat16_test.cc DEPS lod_tensor)
cc_test(complex_test SRCS complex_test.cc DEPS lod_tensor)
IF(WITH_GPU)
nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
nv_test(bfloat16_gpu_test SRCS bfloat16_test.cu DEPS lod_tensor)
nv_test(complex_gpu_test SRCS complex_test.cu DEPS lod_tensor)
nv_test(test_limit_gpu_memory SRCS test_limit_gpu_memory.cu DEPS gpu_info flags)
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
ENDIF()
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <complex>
#include <cstring>
#include <iostream>
#include <limits>
#ifdef PADDLE_WITH_CUDA
#include <cuComplex.h>
#include <thrust/complex.h>
#endif // PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_HIP
#include <hip/hip_complex.h>
#include <thrust/complex.h> // NOLINT
#endif
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x) __declspec(align(x))
#endif
#if (defined(__CUDACC__) || defined(__HIPCC__))
#define HOSTDEVICE __host__ __device__
#define DEVICE __device__
#define HOST __host__
#else
#define HOSTDEVICE
#define DEVICE
#define HOST
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// todo
#define PADDLE_WITH_CUDA_OR_HIP_COMPLEX
#endif
namespace paddle {
namespace platform {
template <typename T>
struct PADDLE_ALIGN(sizeof(T) * 2) complex {
public:
T real;
T imag;
complex() = default;
complex(const complex<T>& o) = default;
complex& operator=(const complex<T>& o) = default;
complex(complex<T>&& o) = default;
complex& operator=(complex<T>&& o) = default;
~complex() = default;
HOSTDEVICE complex(T real, T imag) : real(real), imag(imag) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T1>
HOSTDEVICE inline explicit complex(const thrust::complex<T1>& c) {
real = c.real();
imag = c.imag();
}
template <typename T1>
HOSTDEVICE inline explicit operator thrust::complex<T1>() const {
return thrust::complex<T1>(real, imag);
}
#ifdef PADDLE_WITH_HIP
HOSTDEVICE inline explicit operator hipFloatComplex() const {
return make_hipFloatComplex(real, imag);
}
HOSTDEVICE inline explicit operator hipDoubleComplex() const {
return make_hipDoubleComplex(real, imag);
}
#else
HOSTDEVICE inline explicit operator cuFloatComplex() const {
return make_cuFloatComplex(real, imag);
}
HOSTDEVICE inline explicit operator cuDoubleComplex() const {
return make_cuDoubleComplex(real, imag);
}
#endif
#endif
template <typename T1,
typename std::enable_if<std::is_floating_point<T1>::value ||
std::is_integral<T1>::value,
int>::type = 0>
HOSTDEVICE complex(const T1& val) {
real = static_cast<T>(val);
imag = static_cast<T>(0.0);
}
template <typename T1 = T>
HOSTDEVICE explicit complex(
const std::enable_if_t<std::is_same<T1, float>::value, complex<double>>&
val) {
real = val.real;
imag = val.imag;
}
template <typename T1 = T>
HOSTDEVICE explicit complex(
const std::enable_if_t<std::is_same<T1, double>::value, complex<float>>&
val) {
real = val.real;
imag = val.imag;
}
template <typename T1>
HOSTDEVICE inline explicit operator std::complex<T1>() const {
return static_cast<std::complex<T1>>(std::complex<T>(real, imag));
}
template <typename T1>
HOSTDEVICE complex(const std::complex<T1>& val)
: real(val.real()), imag(val.imag()) {}
template <typename T1,
typename std::enable_if<std::is_floating_point<T1>::value ||
std::is_integral<T1>::value,
int>::type = 0>
HOSTDEVICE inline complex& operator=(const T1& val) {
real = static_cast<T>(val);
imag = static_cast<T>(0.0);
return *this;
}
HOSTDEVICE inline explicit operator bool() const {
return static_cast<bool>(this->real) || static_cast<bool>(this->imag);
}
HOSTDEVICE inline explicit operator int8_t() const {
return static_cast<int8_t>(this->real);
}
HOSTDEVICE inline explicit operator uint8_t() const {
return static_cast<uint8_t>(this->real);
}
HOSTDEVICE inline explicit operator int16_t() const {
return static_cast<int16_t>(this->real);
}
HOSTDEVICE inline explicit operator uint16_t() const {
return static_cast<uint16_t>(this->real);
}
HOSTDEVICE inline explicit operator int32_t() const {
return static_cast<int32_t>(this->real);
}
HOSTDEVICE inline explicit operator uint32_t() const {
return static_cast<uint32_t>(this->real);
}
HOSTDEVICE inline explicit operator int64_t() const {
return static_cast<int64_t>(this->real);
}
HOSTDEVICE inline explicit operator uint64_t() const {
return static_cast<uint64_t>(this->real);
}
HOSTDEVICE inline explicit operator float() const {
return static_cast<float>(this->real);
}
HOSTDEVICE inline explicit operator double() const {
return static_cast<double>(this->real);
}
};
template <typename T>
HOSTDEVICE inline complex<T> operator+(const complex<T>& a,
const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::complex<T>(a) + thrust::complex<T>(b));
#else
return complex<T>(a.real + b.real, a.imag + b.imag);
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> operator-(const complex<T>& a,
const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::complex<T>(a) - thrust::complex<T>(b));
#else
return complex<T>(a.real - b.real, a.imag - b.imag);
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> operator*(const complex<T>& a,
const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::complex<T>(a) * thrust::complex<T>(b));
#else
return complex<T>(a.real * b.real - a.imag * b.imag,
a.imag * b.real + b.imag * a.real);
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> operator/(const complex<T>& a,
const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::complex<T>(a) / thrust::complex<T>(b));
#else
T denominator = b.real * b.real + b.imag * b.imag;
return complex<T>((a.real * b.real + a.imag * b.imag) / denominator,
(a.imag * b.real - a.real * b.imag) / denominator);
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> operator-(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(-thrust::complex<T>(a.real, a.imag));
#else
complex<T> res;
res.real = -a.real;
res.imag = -a.imag;
return res;
#endif
}
template <typename T>
HOSTDEVICE inline complex<T>& operator+=(complex<T>& a, // NOLINT
const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
a = complex<T>(thrust::complex<T>(a.real, a.imag) +=
thrust::complex<T>(b.real, b.imag));
return a;
#else
a.real += b.real;
a.imag += b.imag;
return a;
#endif
}
template <typename T>
HOSTDEVICE inline complex<T>& operator-=(complex<T>& a, // NOLINT
const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
a = complex<T>(thrust::complex<T>(a.real, a.imag) -=
thrust::complex<T>(b.real, b.imag));
return a;
#else
a.real -= b.real;
a.imag -= b.imag;
return a;
#endif
}
template <typename T>
HOSTDEVICE inline complex<T>& operator*=(complex<T>& a, // NOLINT
const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
a = complex<T>(thrust::complex<T>(a.real, a.imag) *=
thrust::complex<T>(b.real, b.imag));
return a;
#else
a.real = a.real * b.real - a.imag * b.imag;
a.imag = a.imag * b.real + b.imag * a.real;
return a;
#endif
}
template <typename T>
HOSTDEVICE inline complex<T>& operator/=(complex<T>& a, // NOLINT
const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
a = complex<T>(thrust::complex<T>(a.real, a.imag) /=
thrust::complex<T>(b.real, b.imag));
return a;
#else
T denominator = b.real * b.real + b.imag * b.imag;
a.real = (a.real * b.real + a.imag * b.imag) / denominator;
a.imag = (a.imag * b.real - a.real * b.imag) / denominator;
return a;
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> raw_uint16_to_complex64(uint16_t a) {
complex<T> res;
res.real = a;
res.imag = 0.0;
return res;
}
template <typename T>
HOSTDEVICE inline bool operator==(const complex<T>& a, const complex<T>& b) {
return a.real == b.real && a.imag == b.imag;
}
template <typename T>
HOSTDEVICE inline bool operator!=(const complex<T>& a, const complex<T>& b) {
return a.real != b.real || a.imag != b.imag;
}
template <typename T>
HOSTDEVICE inline bool operator<(const complex<T>& a, const complex<T>& b) {
return a.real < b.real;
}
template <typename T>
HOSTDEVICE inline bool operator<=(const complex<T>& a, const complex<T>& b) {
return a.real <= b.real;
}
template <typename T>
HOSTDEVICE inline bool operator>(const complex<T>& a, const complex<T>& b) {
return a.real > b.real;
}
template <typename T>
HOSTDEVICE inline bool operator>=(const complex<T>& a, const complex<T>& b) {
return a.real >= b.real;
}
template <typename T>
HOSTDEVICE inline complex<T> max(const complex<T>& a, const complex<T>& b) {
return (a.real >= b.real) ? a : b;
}
template <typename T>
HOSTDEVICE inline complex<T> min(const complex<T>& a, const complex<T>& b) {
return (a.real < b.real) ? a : b;
}
template <typename T>
HOSTDEVICE inline bool(isnan)(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return ::isnan(a.real) || ::isnan(a.imag);
#else
return std::isnan(a.real) || std::isnan(a.imag);
#endif
}
template <typename T>
HOSTDEVICE inline bool isinf(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return ::isinf(a.real) || ::isinf(a.imag);
#else
return std::isinf(a.real) || std::isinf(a.imag);
#endif
}
template <typename T>
HOSTDEVICE inline bool isfinite(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return ::isfinite(a.real) || ::isfinite(a.imag);
#else
return std::isfinite(a.real) || std::isfinite(a.imag);
#endif
}
template <typename T>
HOSTDEVICE inline T abs(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return thrust::abs(thrust::complex<T>(a));
#else
return std::abs(std::complex<T>(a));
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> pow(const complex<T>& a, const complex<T>& b) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::pow(thrust::complex<T>(a), thrust::complex<T>(b)));
#else
return complex<T>(std::pow(std::complex<T>(a), std::complex<T>(b)));
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> sqrt(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::sqrt(thrust::complex<T>(a)));
#else
return complex<T>(std::sqrt(std::complex<T>(a)));
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> tanh(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::tanh(thrust::complex<T>(a)));
#else
return complex<T>(std::tanh(std::complex<T>(a)));
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> log(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::log(thrust::complex<T>(a)));
#else
return complex<T>(std::log(std::complex<T>(a)));
#endif
}
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const complex<T>& a) {
os << "real:" << a.real << " imag:" << a.imag;
return os;
}
} // namespace platform
} // namespace paddle
namespace std {
template <typename T>
struct is_pod<paddle::platform::complex<T>> {
static const bool value = true;
};
template <typename T>
struct is_floating_point<paddle::platform::complex<T>>
: std::integral_constant<bool, false> {};
template <typename T>
struct is_signed<paddle::platform::complex<T>> {
static const bool value = false;
};
template <typename T>
struct is_unsigned<paddle::platform::complex<T>> {
static const bool value = false;
};
template <typename T>
inline bool isnan(const paddle::platform::complex<T>& a) {
return paddle::platform::isnan(a);
}
template <typename T>
inline bool isinf(const paddle::platform::complex<T>& a) {
return paddle::platform::isinf(a);
}
template <typename T>
struct numeric_limits<paddle::platform::complex<T>> {
static const bool is_specialized = false;
static const bool is_signed = false;
static const bool is_integer = false;
static const bool is_exact = false;
static const bool has_infinity = false;
static const bool has_quiet_NaN = false;
static const bool has_signaling_NaN = false;
static const float_denorm_style has_denorm = denorm_absent;
static const bool has_denorm_loss = false;
static const std::float_round_style round_style = std::round_toward_zero;
static const bool is_iec559 = false;
static const bool is_bounded = false;
static const bool is_modulo = false;
static const int digits = 0;
static const int digits10 = 0;
static const int max_digits10 = 0;
static const int radix = 0;
static const int min_exponent = 0;
static const int min_exponent10 = 0;
static const int max_exponent = 0;
static const int max_exponent10 = 0;
static const bool traps = false;
static const bool tinyness_before = false;
static paddle::platform::complex<T> min() {
return paddle::platform::complex<T>(0.0, 0.0);
}
static paddle::platform::complex<T> lowest() {
return paddle::platform::complex<T>(0.0, 0.0);
}
static paddle::platform::complex<T> max() {
return paddle::platform::complex<T>(0.0, 0.0);
}
static paddle::platform::complex<T> epsilon() {
return paddle::platform::complex<T>(0.0, 0.0);
}
static paddle::platform::complex<T> round_error() {
return paddle::platform::complex<T>(0.0, 0.0);
}
static paddle::platform::complex<T> infinity() {
return paddle::platform::complex<T>(0.0, 0.0);
}
static paddle::platform::complex<T> quiet_NaN() {
return paddle::platform::complex<T>(0.0, 0.0);
}
static paddle::platform::complex<T> signaling_NaN() {
return paddle::platform::complex<T>(0.0, 0.0);
}
static paddle::platform::complex<T> denorm_min() {
return paddle::platform::complex<T>(0.0, 0.0);
}
};
} // namespace std
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/complex.h"
#include <complex>
#include "paddle/fluid/platform/eigen_ext.h"
#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
TEST(complex, conversion_cpu) {
// *********** complex<float> *************
// float to complex<float>
EXPECT_EQ(complex<float>().real, 0.0f);
EXPECT_EQ(complex<float>().imag, 0.0f);
EXPECT_EQ(complex<float>(1.0f, 1.0f).real, 1.0f);
EXPECT_EQ(complex<float>(1.0f, 1.0f).imag, 1.0f);
EXPECT_EQ(complex<float>(0.0f, 1.0f).real, 0.0f);
EXPECT_EQ(complex<float>(0.0f, 1.0f).imag, 1.0f);
EXPECT_EQ(complex<float>(1.0f).real, 1.0f);
EXPECT_EQ(complex<float>(1.0f).imag, 0.0f);
// int to complex<float>
EXPECT_EQ(complex<float>(1).real, 1.0f);
EXPECT_EQ(complex<float>(0).real, 0.0f);
EXPECT_EQ(complex<float>(2).real, 2.0f);
EXPECT_EQ(complex<float>(-2).real, -2.0f);
// bool to complex
EXPECT_EQ(complex<float>(true).real, 1.0f);
EXPECT_EQ(complex<float>(true).imag, 0.0f);
// complex<double> to complex<float>
EXPECT_EQ(complex<float>(complex<double>(1.0, 2.0)).real, 1.0f);
EXPECT_EQ(complex<float>(complex<double>(1.0, 2.0)).imag, 2.0f);
// std::complex<float> to complex<float>
EXPECT_EQ(complex<float>(std::complex<float>(1.0f, 2.0f)).real, 1.0f);
EXPECT_EQ(complex<float>(std::complex<float>(1.0f, 2.0f)).imag, 2.0f);
EXPECT_EQ(complex<float>(std::complex<double>(1.0, 2.0)).real, 1.0f);
EXPECT_EQ(complex<float>(std::complex<double>(1.0, 2.0)).imag, 2.0f);
// Assignment operator
complex<float> c = 1.0f;
EXPECT_EQ(c.real, 1.0f);
EXPECT_EQ(c.imag, 0.0f);
c = complex<float>(2.0, 2.0);
EXPECT_EQ(c.real, 2.0f);
EXPECT_EQ(c.imag, 2.0f);
// Conversion operator
EXPECT_EQ(static_cast<float>(complex<float>(0.5f)), 0.5f);
EXPECT_NEAR(static_cast<double>(complex<float>(0.33333)), 0.33333, 0.01);
EXPECT_EQ(static_cast<int>(complex<float>(-1)), -1);
EXPECT_EQ(static_cast<bool>(complex<float>(true)), true);
// *********** complex<double> *************
// double to complex<double>
EXPECT_EQ(complex<double>().real, 0.0);
EXPECT_EQ(complex<double>().imag, 0.0);
EXPECT_EQ(complex<double>(1.0, 1.0).real, 1.0);
EXPECT_EQ(complex<double>(1.0, 1.0).imag, 1.0);
EXPECT_EQ(complex<double>(0.0, 1.0).real, 0.0);
EXPECT_EQ(complex<double>(0.0, 1.0).imag, 1.0);
EXPECT_EQ(complex<double>(1.0).real, 1.0);
EXPECT_EQ(complex<double>(1.0).imag, 0.0);
// int to complex<double>
EXPECT_EQ(complex<double>(1).real, 1.0);
EXPECT_EQ(complex<double>(0).real, 0.0);
EXPECT_EQ(complex<double>(2).real, 2.0);
EXPECT_EQ(complex<double>(-2).real, -2.0);
// bool to complex
EXPECT_EQ(complex<double>(true).real, 1.0);
EXPECT_EQ(complex<double>(true).imag, 0.0);
// complex<float> to complex<double>
EXPECT_EQ(complex<double>(complex<float>(1.0f, 2.0f)).real, 1.0);
EXPECT_EQ(complex<double>(complex<float>(1.0f, 2.0f)).imag, 2.0);
// std::complex<float> to complex<double>
EXPECT_EQ(complex<double>(std::complex<double>(1.0, 2.0)).real, 1.0);
EXPECT_EQ(complex<double>(std::complex<double>(1.0, 2.0)).imag, 2.0);
EXPECT_EQ(complex<double>(std::complex<double>(1.0, 2.0)).real, 1.0);
EXPECT_EQ(complex<double>(std::complex<double>(1.0, 2.0)).imag, 2.0);
// Assignment operator
complex<double> c1 = 1.0;
EXPECT_EQ(c1.real, 1.0);
EXPECT_EQ(c1.imag, 0.0);
c1 = complex<double>(2.0, 2.0);
EXPECT_EQ(c1.real, 2.0);
EXPECT_EQ(c1.imag, 2.0);
// Conversion operator
EXPECT_EQ(static_cast<double>(complex<double>(0.5)), 0.5);
EXPECT_NEAR(static_cast<double>(complex<double>(0.33333)), 0.33333, 0.01);
EXPECT_EQ(static_cast<int>(complex<double>(-1)), -1);
EXPECT_EQ(static_cast<bool>(complex<double>(true)), true);
}
TEST(bfloat16, comparison_cpu) {
// *********** complex<float> *************
EXPECT_TRUE(complex<float>(1.0f) == complex<float>(1.0f));
EXPECT_TRUE(complex<float>(1.0f, 2.0f) == complex<float>(1.0f, 2.0f));
EXPECT_FALSE(complex<float>(-1.0f) == complex<float>(-0.5f));
EXPECT_TRUE(complex<float>(1.0f) != complex<float>(0.5f));
EXPECT_FALSE(complex<float>(-1.0f) != complex<float>(-1.0f));
EXPECT_TRUE(complex<float>(1.0f) < complex<float>(2.0f));
EXPECT_FALSE(complex<float>(-1.0f) < complex<float>(-1.0f));
EXPECT_TRUE(complex<float>(1.0f) <= complex<float>(1.0f));
EXPECT_TRUE(complex<float>(2.0f) > complex<float>(1.0f));
EXPECT_FALSE(complex<float>(-2.0f) > complex<float>(-2.0f));
EXPECT_TRUE(complex<float>(2.0f) >= complex<float>(2.0f));
// *********** complex<double> *************
EXPECT_TRUE(complex<double>(1.0) == complex<double>(1.0));
EXPECT_TRUE(complex<double>(1.0, 2.0) == complex<double>(1.0, 2.0));
EXPECT_FALSE(complex<double>(-1.0) == complex<double>(-0.5f));
EXPECT_TRUE(complex<double>(1.0) != complex<double>(0.5f));
EXPECT_FALSE(complex<double>(-1.0) != complex<double>(-1.0));
EXPECT_TRUE(complex<double>(1.0) < complex<double>(2.0));
EXPECT_FALSE(complex<double>(-1.0) < complex<double>(-1.0));
EXPECT_TRUE(complex<double>(1.0) <= complex<double>(1.0));
EXPECT_TRUE(complex<double>(2.0) > complex<double>(1.0));
EXPECT_FALSE(complex<double>(-2.0) > complex<double>(-2.0));
EXPECT_TRUE(complex<double>(2.0) >= complex<double>(2.0));
}
TEST(complex, arithmetic_cpu) {
// *********** complex<float> *************
complex<float> a = complex<float>(1, 1) + complex<float>(1, 1);
EXPECT_NEAR(a.real, 2, 0.001);
EXPECT_NEAR(a.imag, 2, 0.001);
complex<float> b = complex<float>(-5, -5) + complex<float>(5, 5);
EXPECT_EQ(b.real, 0);
EXPECT_EQ(b.imag, 0);
complex<float> c =
complex<float>(0.33333f, 0.33333f) + complex<float>(0.66667f, 0.66667f);
EXPECT_NEAR(c.real, 1.0f, 0.01);
EXPECT_NEAR(c.imag, 1.0f, 0.01);
complex<float> d = complex<float>(3) - complex<float>(5);
EXPECT_EQ(d.real, -2);
EXPECT_EQ(d.imag, 0);
complex<float> e =
complex<float>(0.66667f, 0.66667f) - complex<float>(0.33333f, 0.33333f);
EXPECT_NEAR(e.real, 0.33334f, 0.01);
EXPECT_NEAR(e.imag, 0.33334f, 0.01);
complex<float> f = complex<float>(0.33f, 0.33f) * complex<float>(0.2f, 0.2f);
EXPECT_NEAR(f.real, 0.0f, 0.01);
EXPECT_NEAR(f.imag, 0.132f, 0.01);
complex<float> g = complex<float>(0.33f, 0.33f) / complex<float>(0.2f, 0.2f);
EXPECT_NEAR(g.real, 1.65f, 0.01);
EXPECT_NEAR(g.imag, 0.0f, 0.01);
complex<float> h = -complex<float>(0.33f, 0.33f);
EXPECT_NEAR(h.real, -0.33f, 0.01);
EXPECT_NEAR(h.imag, -0.33f, 0.01);
h = -complex<float>(-0.33f, -0.33f);
EXPECT_NEAR(h.real, 0.33f, 0.01);
EXPECT_NEAR(h.imag, 0.33f, 0.01);
complex<float> i = complex<float>(1.0, 1.0);
i += complex<float>(2.0, 2.0);
EXPECT_NEAR(i.real, 3.0f, 0.01);
EXPECT_NEAR(i.imag, 3.0f, 0.01);
i -= complex<float>(1.0, 1.0);
EXPECT_NEAR(i.real, 2.0f, 0.01);
EXPECT_NEAR(i.imag, 2.0f, 0.01);
i *= complex<float>(3, 2);
EXPECT_NEAR(i.real, 2.0f, 0.01);
EXPECT_NEAR(i.imag, 10.0f, 0.01);
i /= complex<float>(3, 2);
EXPECT_NEAR(i.real, 2.0f, 0.01);
EXPECT_NEAR(i.imag, 2.0f, 0.01);
// *********** complex<double> *************
complex<double> a1 = complex<double>(1, 1) + complex<double>(1, 1);
EXPECT_NEAR(a1.real, 2, 0.001);
EXPECT_NEAR(a1.imag, 2, 0.001);
complex<double> b1 = complex<double>(-5, -5) + complex<double>(5, 5);
EXPECT_EQ(b1.real, 0);
EXPECT_EQ(b1.imag, 0);
complex<double> c1 =
complex<double>(0.33333f, 0.33333f) + complex<double>(0.66667f, 0.66667f);
EXPECT_NEAR(c1.real, 1.0f, 0.01);
EXPECT_NEAR(c1.imag, 1.0f, 0.01);
complex<double> d1 = complex<double>(3) - complex<double>(5);
EXPECT_EQ(d1.real, -2);
EXPECT_EQ(d1.imag, 0);
complex<double> e1 =
complex<double>(0.66667f, 0.66667f) - complex<double>(0.33333f, 0.33333f);
EXPECT_NEAR(e1.real, 0.33334f, 0.01);
EXPECT_NEAR(e1.imag, 0.33334f, 0.01);
complex<double> f1 =
complex<double>(0.33f, 0.33f) * complex<double>(0.2f, 0.2f);
EXPECT_NEAR(f1.real, 0.0f, 0.01);
EXPECT_NEAR(f1.imag, 0.132f, 0.01);
complex<double> g1 =
complex<double>(0.33f, 0.33f) / complex<double>(0.2f, 0.2f);
EXPECT_NEAR(g1.real, 1.65f, 0.01);
EXPECT_NEAR(g1.imag, 0.0f, 0.01);
complex<double> h1 = -complex<double>(0.33f, 0.33f);
EXPECT_NEAR(h1.real, -0.33f, 0.01);
EXPECT_NEAR(h1.imag, -0.33f, 0.01);
h1 = -complex<double>(-0.33f, -0.33f);
EXPECT_NEAR(h1.real, 0.33f, 0.01);
EXPECT_NEAR(h1.imag, 0.33f, 0.01);
complex<double> i1 = complex<double>(1.0, 1.0);
i1 += complex<double>(2.0, 2.0);
EXPECT_NEAR(i1.real, 3.0f, 0.01);
EXPECT_NEAR(i1.imag, 3.0f, 0.01);
i1 -= complex<double>(1.0, 1.0);
EXPECT_NEAR(i1.real, 2.0f, 0.01);
EXPECT_NEAR(i1.imag, 2.0f, 0.01);
i1 *= complex<double>(3, 2);
EXPECT_NEAR(i1.real, 2.0f, 0.01);
EXPECT_NEAR(i1.imag, 10.0f, 0.01);
i1 /= complex<double>(3, 2);
EXPECT_NEAR(i1.real, 2.0f, 0.01);
EXPECT_NEAR(i1.imag, 2.0f, 0.01);
}
TEST(complex, print) {
complex<float> a(1.0f);
std::cout << a << std::endl;
complex<double> b(1.0);
std::cout << b << std::endl;
}
TEST(complex, isinf) {
// *********** complex<float> *************
complex<float> a;
a.real = float(INFINITY);
EXPECT_EQ(std::isinf(a), true);
a.imag = float(INFINITY);
EXPECT_EQ(std::isinf(a), true);
complex<float> b = float(INFINITY);
EXPECT_EQ(std::isinf(b), true);
complex<float> c(float(INFINITY), 0);
EXPECT_EQ(std::isinf(c), true);
// *********** complex<double> *************
complex<double> a1;
a1.real = double(INFINITY);
EXPECT_EQ(std::isinf(a1), true);
a1.imag = double(INFINITY);
EXPECT_EQ(std::isinf(a1), true);
complex<double> b1 = double(INFINITY);
EXPECT_EQ(std::isinf(b1), true);
complex<double> c1(double(INFINITY), 0);
EXPECT_EQ(std::isinf(c1), true);
}
TEST(complex, isnan) {
// *********** complex<float> *************
complex<float> a;
a.real = float(NAN);
EXPECT_EQ(std::isnan(a), true);
a.imag = float(NAN);
EXPECT_EQ(std::isnan(a), true);
complex<float> b = float(NAN);
EXPECT_EQ(std::isnan(b), true);
complex<float> c(float(NAN), 0);
EXPECT_EQ(std::isnan(c), true);
// *********** complex<double> *************
complex<double> a1;
a1.real = double(NAN);
EXPECT_EQ(std::isnan(a1), true);
a1.imag = double(NAN);
EXPECT_EQ(std::isnan(a1), true);
complex<double> b1 = double(NAN);
EXPECT_EQ(std::isnan(b1), true);
complex<double> c1(double(NAN), 0);
EXPECT_EQ(std::isnan(c1), true);
}
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/complex.h"
#define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <thrust/complex.h>
#include <bitset>
#include <iostream>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/eigen_ext.h"
#include "paddle/fluid/platform/enforce.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
namespace paddle {
namespace platform {
TEST(complex, conversion_on_gpu) {
// *********** complex<float> *************
// thrust<float> from and to complex<float>
complex<float> a(1.0f, 2.0f);
EXPECT_EQ(complex<float>(thrust::complex<float>(a)).real, 1.0);
EXPECT_EQ(complex<float>(thrust::complex<float>(a)).imag, 2.0);
complex<double> a1(1.0, 2.0);
EXPECT_EQ(complex<double>(thrust::complex<double>(a1)).real, 1.0);
EXPECT_EQ(complex<double>(thrust::complex<double>(a1)).imag, 2.0);
#if defined(PADDLE_WITH_HIP)
EXPECT_EQ(hipFloatComplex(a).real(), 1.0);
EXPECT_EQ(hipFloatComplex(a).imag(), 2.0);
EXPECT_EQ(hipDoubleComplex(a).real(), 1.0);
EXPECT_EQ(hipDoubleComplex(a).imag(), 2.0);
EXPECT_EQ(hipFloatComplex(a1).real(), 1.0);
EXPECT_EQ(hipFloatComplex(a1).imag(), 2.0);
EXPECT_EQ(hipDoubleComplex(a1).real(), 1.0);
EXPECT_EQ(hipDoubleComplex(a1).imag(), 2.0);
#else
EXPECT_EQ(cuCrealf(cuFloatComplex(a)), 1.0);
EXPECT_EQ(cuCimagf(cuFloatComplex(a)), 2.0);
EXPECT_EQ(cuCreal(cuDoubleComplex(a)), 1.0);
EXPECT_EQ(cuCimag(cuDoubleComplex(a)), 2.0);
EXPECT_EQ(cuCrealf(cuFloatComplex(a1)), 1.0);
EXPECT_EQ(cuCimagf(cuFloatComplex(a1)), 2.0);
EXPECT_EQ(cuCreal(cuDoubleComplex(a1)), 1.0);
EXPECT_EQ(cuCimag(cuDoubleComplex(a1)), 2.0);
#endif
EXPECT_EQ(complex<float>().real, 0.0f);
EXPECT_EQ(complex<float>().imag, 0.0f);
EXPECT_EQ(complex<float>(1.0f, 1.0f).real, 1.0f);
EXPECT_EQ(complex<float>(1.0f, 1.0f).imag, 1.0f);
EXPECT_EQ(complex<float>(0.0f, 1.0f).real, 0.0f);
EXPECT_EQ(complex<float>(0.0f, 1.0f).imag, 1.0f);
EXPECT_EQ(complex<float>(1.0f).real, 1.0f);
EXPECT_EQ(complex<float>(1.0f).imag, 0.0f);
// int to complex<float>
EXPECT_EQ(complex<float>(1).real, 1.0f);
EXPECT_EQ(complex<float>(0).real, 0.0f);
EXPECT_EQ(complex<float>(2).real, 2.0f);
EXPECT_EQ(complex<float>(-2).real, -2.0f);
// bool to complex
EXPECT_EQ(complex<float>(true).real, 1.0f);
EXPECT_EQ(complex<float>(true).imag, 0.0f);
// complex<double> to complex<float>
EXPECT_EQ(complex<float>(complex<double>(1.0, 2.0)).real, 1.0f);
EXPECT_EQ(complex<float>(complex<double>(1.0, 2.0)).imag, 2.0f);
// std::complex<float> to complex<float>
EXPECT_EQ(complex<float>(std::complex<float>(1.0f, 2.0f)).real, 1.0f);
EXPECT_EQ(complex<float>(std::complex<float>(1.0f, 2.0f)).imag, 2.0f);
EXPECT_EQ(complex<float>(std::complex<double>(1.0, 2.0)).real, 1.0f);
EXPECT_EQ(complex<float>(std::complex<double>(1.0, 2.0)).imag, 2.0f);
// Assignment operator
complex<float> c = 1.0f;
EXPECT_EQ(c.real, 1.0f);
EXPECT_EQ(c.imag, 0.0f);
c = complex<float>(2.0, 2.0);
EXPECT_EQ(c.real, 2.0f);
EXPECT_EQ(c.imag, 2.0f);
// Conversion operator
EXPECT_EQ(static_cast<float>(complex<float>(0.5f)), 0.5f);
EXPECT_NEAR(static_cast<double>(complex<float>(0.33333)), 0.33333, 0.01);
EXPECT_EQ(static_cast<int>(complex<float>(-1)), -1);
EXPECT_EQ(static_cast<bool>(complex<float>(true)), true);
// *********** complex<double> *************
// double to complex<double>
EXPECT_EQ(complex<double>().real, 0.0);
EXPECT_EQ(complex<double>().imag, 0.0);
EXPECT_EQ(complex<double>(1.0, 1.0).real, 1.0);
EXPECT_EQ(complex<double>(1.0, 1.0).imag, 1.0);
EXPECT_EQ(complex<double>(0.0, 1.0).real, 0.0);
EXPECT_EQ(complex<double>(0.0, 1.0).imag, 1.0);
EXPECT_EQ(complex<double>(1.0).real, 1.0);
EXPECT_EQ(complex<double>(1.0).imag, 0.0);
// int to complex<double>
EXPECT_EQ(complex<double>(1).real, 1.0);
EXPECT_EQ(complex<double>(0).real, 0.0);
EXPECT_EQ(complex<double>(2).real, 2.0);
EXPECT_EQ(complex<double>(-2).real, -2.0);
// bool to complex
EXPECT_EQ(complex<double>(true).real, 1.0);
EXPECT_EQ(complex<double>(true).imag, 0.0);
// complex<float> to complex<double>
EXPECT_EQ(complex<double>(complex<float>(1.0f, 2.0f)).real, 1.0);
EXPECT_EQ(complex<double>(complex<float>(1.0f, 2.0f)).imag, 2.0);
// std::complex<float> to complex<double>
EXPECT_EQ(complex<double>(std::complex<double>(1.0, 2.0)).real, 1.0);
EXPECT_EQ(complex<double>(std::complex<double>(1.0, 2.0)).imag, 2.0);
EXPECT_EQ(complex<double>(std::complex<double>(1.0, 2.0)).real, 1.0);
EXPECT_EQ(complex<double>(std::complex<double>(1.0, 2.0)).imag, 2.0);
// Assignment operator
complex<double> c1 = 1.0;
EXPECT_EQ(c1.real, 1.0);
EXPECT_EQ(c1.imag, 0.0);
c1 = complex<double>(2.0, 2.0);
EXPECT_EQ(c1.real, 2.0);
EXPECT_EQ(c1.imag, 2.0);
// Conversion operator
EXPECT_EQ(static_cast<double>(complex<double>(0.5)), 0.5);
EXPECT_NEAR(static_cast<double>(complex<double>(0.33333)), 0.33333, 0.01);
EXPECT_EQ(static_cast<int>(complex<double>(-1)), -1);
EXPECT_EQ(static_cast<bool>(complex<double>(true)), true);
}
TEST(bfloat16, comparison_cpu) {
// *********** complex<float> *************
EXPECT_TRUE(complex<float>(1.0f) == complex<float>(1.0f));
EXPECT_TRUE(complex<float>(1.0f, 2.0f) == complex<float>(1.0f, 2.0f));
EXPECT_FALSE(complex<float>(-1.0f) == complex<float>(-0.5f));
EXPECT_TRUE(complex<float>(1.0f) != complex<float>(0.5f));
EXPECT_FALSE(complex<float>(-1.0f) != complex<float>(-1.0f));
EXPECT_TRUE(complex<float>(1.0f) < complex<float>(2.0f));
EXPECT_FALSE(complex<float>(-1.0f) < complex<float>(-1.0f));
EXPECT_TRUE(complex<float>(1.0f) <= complex<float>(1.0f));
EXPECT_TRUE(complex<float>(2.0f) > complex<float>(1.0f));
EXPECT_FALSE(complex<float>(-2.0f) > complex<float>(-2.0f));
EXPECT_TRUE(complex<float>(2.0f) >= complex<float>(2.0f));
// *********** complex<double> *************
EXPECT_TRUE(complex<double>(1.0) == complex<double>(1.0));
EXPECT_TRUE(complex<double>(1.0, 2.0) == complex<double>(1.0, 2.0));
EXPECT_FALSE(complex<double>(-1.0) == complex<double>(-0.5f));
EXPECT_TRUE(complex<double>(1.0) != complex<double>(0.5f));
EXPECT_FALSE(complex<double>(-1.0) != complex<double>(-1.0));
EXPECT_TRUE(complex<double>(1.0) < complex<double>(2.0));
EXPECT_FALSE(complex<double>(-1.0) < complex<double>(-1.0));
EXPECT_TRUE(complex<double>(1.0) <= complex<double>(1.0));
EXPECT_TRUE(complex<double>(2.0) > complex<double>(1.0));
EXPECT_FALSE(complex<double>(-2.0) > complex<double>(-2.0));
EXPECT_TRUE(complex<double>(2.0) >= complex<double>(2.0));
}
TEST(complex, arithmetic_cpu) {
// *********** complex<float> *************
complex<float> a = complex<float>(1, 1) + complex<float>(1, 1);
EXPECT_NEAR(a.real, 2, 0.001);
EXPECT_NEAR(a.imag, 2, 0.001);
complex<float> b = complex<float>(-5, -5) + complex<float>(5, 5);
EXPECT_EQ(b.real, 0);
EXPECT_EQ(b.imag, 0);
complex<float> c =
complex<float>(0.33333f, 0.33333f) + complex<float>(0.66667f, 0.66667f);
EXPECT_NEAR(c.real, 1.0f, 0.01);
EXPECT_NEAR(c.imag, 1.0f, 0.01);
complex<float> d = complex<float>(3) - complex<float>(5);
EXPECT_EQ(d.real, -2);
EXPECT_EQ(d.imag, 0);
complex<float> e =
complex<float>(0.66667f, 0.66667f) - complex<float>(0.33333f, 0.33333f);
EXPECT_NEAR(e.real, 0.33334f, 0.01);
EXPECT_NEAR(e.imag, 0.33334f, 0.01);
complex<float> f = complex<float>(0.33f, 0.33f) * complex<float>(0.2f, 0.2f);
EXPECT_NEAR(f.real, 0.0f, 0.01);
EXPECT_NEAR(f.imag, 0.132f, 0.01);
complex<float> g = complex<float>(0.33f, 0.33f) / complex<float>(0.2f, 0.2f);
EXPECT_NEAR(g.real, 1.65f, 0.01);
EXPECT_NEAR(g.imag, 0.0f, 0.01);
complex<float> h = -complex<float>(0.33f, 0.33f);
EXPECT_NEAR(h.real, -0.33f, 0.01);
EXPECT_NEAR(h.imag, -0.33f, 0.01);
h = -complex<float>(-0.33f, -0.33f);
EXPECT_NEAR(h.real, 0.33f, 0.01);
EXPECT_NEAR(h.imag, 0.33f, 0.01);
complex<float> i = complex<float>(1.0, 1.0);
i += complex<float>(2.0, 2.0);
EXPECT_NEAR(i.real, 3.0f, 0.01);
EXPECT_NEAR(i.imag, 3.0f, 0.01);
i -= complex<float>(1.0, 1.0);
EXPECT_NEAR(i.real, 2.0f, 0.01);
EXPECT_NEAR(i.imag, 2.0f, 0.01);
i *= complex<float>(3, 2);
EXPECT_NEAR(i.real, 2.0f, 0.01);
EXPECT_NEAR(i.imag, 10.0f, 0.01);
i /= complex<float>(3, 2);
EXPECT_NEAR(i.real, 2.0f, 0.01);
EXPECT_NEAR(i.imag, 2.0f, 0.01);
// *********** complex<double> *************
complex<double> a1 = complex<double>(1, 1) + complex<double>(1, 1);
EXPECT_NEAR(a1.real, 2, 0.001);
EXPECT_NEAR(a1.imag, 2, 0.001);
complex<double> b1 = complex<double>(-5, -5) + complex<double>(5, 5);
EXPECT_EQ(b1.real, 0);
EXPECT_EQ(b1.imag, 0);
complex<double> c1 =
complex<double>(0.33333f, 0.33333f) + complex<double>(0.66667f, 0.66667f);
EXPECT_NEAR(c1.real, 1.0f, 0.01);
EXPECT_NEAR(c1.imag, 1.0f, 0.01);
complex<double> d1 = complex<double>(3) - complex<double>(5);
EXPECT_EQ(d1.real, -2);
EXPECT_EQ(d1.imag, 0);
complex<double> e1 =
complex<double>(0.66667f, 0.66667f) - complex<double>(0.33333f, 0.33333f);
EXPECT_NEAR(e1.real, 0.33334f, 0.01);
EXPECT_NEAR(e1.imag, 0.33334f, 0.01);
complex<double> f1 =
complex<double>(0.33f, 0.33f) * complex<double>(0.2f, 0.2f);
EXPECT_NEAR(f1.real, 0.0f, 0.01);
EXPECT_NEAR(f1.imag, 0.132f, 0.01);
complex<double> g1 =
complex<double>(0.33f, 0.33f) / complex<double>(0.2f, 0.2f);
EXPECT_NEAR(g1.real, 1.65f, 0.01);
EXPECT_NEAR(g1.imag, 0.0f, 0.01);
complex<double> h1 = -complex<double>(0.33f, 0.33f);
EXPECT_NEAR(h1.real, -0.33f, 0.01);
EXPECT_NEAR(h1.imag, -0.33f, 0.01);
h1 = -complex<double>(-0.33f, -0.33f);
EXPECT_NEAR(h1.real, 0.33f, 0.01);
EXPECT_NEAR(h1.imag, 0.33f, 0.01);
complex<double> i1 = complex<double>(1.0, 1.0);
i1 += complex<double>(2.0, 2.0);
EXPECT_NEAR(i1.real, 3.0f, 0.01);
EXPECT_NEAR(i1.imag, 3.0f, 0.01);
i1 -= complex<double>(1.0, 1.0);
EXPECT_NEAR(i1.real, 2.0f, 0.01);
EXPECT_NEAR(i1.imag, 2.0f, 0.01);
i1 *= complex<double>(3, 2);
EXPECT_NEAR(i1.real, 2.0f, 0.01);
EXPECT_NEAR(i1.imag, 10.0f, 0.01);
i1 /= complex<double>(3, 2);
EXPECT_NEAR(i1.real, 2.0f, 0.01);
EXPECT_NEAR(i1.imag, 2.0f, 0.01);
}
TEST(complex, print) {
complex<float> a(1.0f);
std::cout << a << std::endl;
complex<double> b(1.0);
std::cout << b << std::endl;
}
TEST(complex, isinf) {
// *********** complex<float> *************
complex<float> a;
a.real = float(INFINITY);
EXPECT_EQ(std::isinf(a), true);
a.imag = float(INFINITY);
EXPECT_EQ(std::isinf(a), true);
complex<float> b = float(INFINITY);
EXPECT_EQ(std::isinf(b), true);
complex<float> c(float(INFINITY), 0);
EXPECT_EQ(std::isinf(c), true);
// *********** complex<double> *************
complex<double> a1;
a1.real = double(INFINITY);
EXPECT_EQ(std::isinf(a1), true);
a1.imag = double(INFINITY);
EXPECT_EQ(std::isinf(a1), true);
complex<double> b1 = double(INFINITY);
EXPECT_EQ(std::isinf(b1), true);
complex<double> c1(double(INFINITY), 0);
EXPECT_EQ(std::isinf(c1), true);
}
TEST(complex, isnan) {
// *********** complex<float> *************
complex<float> a;
a.real = float(NAN);
EXPECT_EQ(std::isnan(a), true);
a.imag = float(NAN);
EXPECT_EQ(std::isnan(a), true);
complex<float> b = float(NAN);
EXPECT_EQ(std::isnan(b), true);
complex<float> c(float(NAN), 0);
EXPECT_EQ(std::isnan(c), true);
// *********** complex<double> *************
complex<double> a1;
a1.real = double(NAN);
EXPECT_EQ(std::isnan(a1), true);
a1.imag = double(NAN);
EXPECT_EQ(std::isnan(a1), true);
complex<double> b1 = double(NAN);
EXPECT_EQ(std::isnan(b1), true);
complex<double> c1(double(NAN), 0);
EXPECT_EQ(std::isnan(c1), true);
}
} // namespace platform
} // namespace paddle
#endif
\ No newline at end of file
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
......@@ -27,6 +28,8 @@ namespace Eigen {
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
using float16 = paddle::platform::float16;
template <typename T>
using complex = paddle::platform::complex<T>;
template <typename T>
struct NumTraits;
......@@ -105,6 +108,50 @@ struct NumTraits<complex128> : GenericNumTraits<std::complex<double>> {
static inline int digits10() { return NumTraits<Real>::digits10(); }
};
template <>
struct NumTraits<complex<float>> : GenericNumTraits<std::complex<float>> {
typedef float Real;
typedef typename NumTraits<float>::Literal Literal;
enum {
IsComplex = 1,
RequireInitialization = NumTraits<float>::RequireInitialization,
ReadCost = 2 * NumTraits<float>::ReadCost,
AddCost = 2 * NumTraits<Real>::AddCost,
MulCost = 4 * NumTraits<Real>::MulCost + 2 * NumTraits<Real>::AddCost
};
EIGEN_DEVICE_FUNC
static inline Real epsilon() { return NumTraits<Real>::epsilon(); }
EIGEN_DEVICE_FUNC
static inline Real dummy_precision() {
return NumTraits<Real>::dummy_precision();
}
EIGEN_DEVICE_FUNC
static inline int digits10() { return NumTraits<Real>::digits10(); }
};
template <>
struct NumTraits<complex<double>> : GenericNumTraits<std::complex<double>> {
typedef double Real;
typedef typename NumTraits<double>::Literal Literal;
enum {
IsComplex = 1,
RequireInitialization = NumTraits<double>::RequireInitialization,
ReadCost = 2 * NumTraits<double>::ReadCost,
AddCost = 2 * NumTraits<Real>::AddCost,
MulCost = 4 * NumTraits<Real>::MulCost + 2 * NumTraits<Real>::AddCost
};
EIGEN_DEVICE_FUNC
static inline Real epsilon() { return NumTraits<Real>::epsilon(); }
EIGEN_DEVICE_FUNC
static inline Real dummy_precision() {
return NumTraits<Real>::dummy_precision();
}
EIGEN_DEVICE_FUNC
static inline int digits10() { return NumTraits<Real>::digits10(); }
};
template <>
struct NumTraits<float16> : GenericNumTraits<float16> {
enum {
......@@ -354,6 +401,138 @@ HOSTDEVICE inline double abs(const complex128& a) {
return paddle::platform::abs(a);
}
//////////// complex<float> methods /////////////
template <>
HOSTDEVICE inline bool(isnan)(const complex<float>& a) {
return (paddle::platform::isnan)(a);
}
template <>
HOSTDEVICE inline bool(isinf)(const complex<float>& a) {
return (paddle::platform::isinf)(a);
}
template <>
HOSTDEVICE inline bool(isfinite)(const complex<float>& a) {
return (paddle::platform::isfinite)(a);
}
template <>
HOSTDEVICE inline complex<float> exp(const complex<float>& a) {
float com = ::expf(a.real);
float res_real = com * ::cosf(a.imag);
float res_imag = com * ::sinf(a.imag);
return complex<float>(res_real, res_imag);
}
template <>
HOSTDEVICE inline complex<float> log(const complex<float>& a) {
return paddle::platform::log(a);
}
template <>
HOSTDEVICE inline complex<float> tanh(const complex<float>& a) {
return paddle::platform::tanh(a);
}
template <>
HOSTDEVICE inline complex<float> sqrt(const complex<float>& a) {
return paddle::platform::sqrt(a);
}
template <>
HOSTDEVICE inline complex<float> ceil(const complex<float>& a) {
return complex<float>(::ceilf(a.real), ::ceilf(a.imag));
}
template <>
HOSTDEVICE inline complex<float> floor(const complex<float>& a) {
return complex<float>(::floorf(a.real), ::floor(a.imag));
}
template <>
HOSTDEVICE inline complex<float> round(const complex<float>& a) {
return complex<float>(::roundf(a.real), ::roundf(a.imag));
}
template <>
HOSTDEVICE inline complex<float> pow(const complex<float>& a,
const complex<float>& b) {
return paddle::platform::pow(a, b);
}
template <>
HOSTDEVICE inline float abs(const complex<float>& a) {
return paddle::platform::abs(a);
}
//////////// complex<double> methods /////////////
template <>
HOSTDEVICE inline bool(isnan)(const complex<double>& a) {
return (paddle::platform::isnan)(a);
}
template <>
HOSTDEVICE inline bool(isinf)(const complex<double>& a) {
return (paddle::platform::isinf)(a);
}
template <>
HOSTDEVICE inline bool(isfinite)(const complex<double>& a) {
return (paddle::platform::isfinite)(a);
}
template <>
HOSTDEVICE inline complex<double> exp(const complex<double>& a) {
double com = ::expf(a.real);
double res_real = com * ::cosf(a.imag);
double res_imag = com * ::sinf(a.imag);
return complex<double>(res_real, res_imag);
}
template <>
HOSTDEVICE inline complex<double> log(const complex<double>& a) {
return paddle::platform::log(a);
}
template <>
HOSTDEVICE inline complex<double> tanh(const complex<double>& a) {
return paddle::platform::tanh(a);
}
template <>
HOSTDEVICE inline complex<double> sqrt(const complex<double>& a) {
return paddle::platform::sqrt(a);
}
template <>
HOSTDEVICE inline complex<double> ceil(const complex<double>& a) {
return complex<double>(::ceilf(a.real), ::ceilf(a.imag));
}
template <>
HOSTDEVICE inline complex<double> floor(const complex<double>& a) {
return complex<double>(::floorf(a.real), ::floor(a.imag));
}
template <>
HOSTDEVICE inline complex<double> round(const complex<double>& a) {
return complex<double>(::roundf(a.real), ::roundf(a.imag));
}
template <>
HOSTDEVICE inline complex<double> pow(const complex<double>& a,
const complex<double>& b) {
return paddle::platform::pow(a, b);
}
template <>
HOSTDEVICE inline double abs(const complex<double>& a) {
return paddle::platform::abs(a);
}
//////////// float16 methods /////////////
template <>
......
......@@ -122,6 +122,43 @@ struct npy_format_descriptor<paddle::platform::complex128> {
static constexpr auto name = _("complext128");
};
// we register paddle::platform::complex64 as numpy.complex64.
template <>
struct npy_format_descriptor<paddle::platform::complex<float>> {
static py::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX64);
return reinterpret_borrow<py::dtype>(ptr);
}
static std::string format() {
// Note: "F" represents complex64.
// Details at:
// https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx
// for k, v in np.sctypeDict.iteritems():
// print '{0:14s} : {1:40s}'.format(str(k), v)
return "F";
}
static constexpr auto name = _("complext64");
};
template <>
struct npy_format_descriptor<paddle::platform::complex<double>> {
static py::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX128);
return reinterpret_borrow<py::dtype>(ptr);
}
static std::string format() {
// Note: "D" represents complex128.
// Details at:
// https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx
// for k, v in np.sctypeDict.iteritems():
// print '{0:14s} : {1:40s}'.format(str(k), v)
return "D";
}
static constexpr auto name = _("complext128");
};
} // namespace detail
} // namespace pybind11
......@@ -170,6 +207,8 @@ DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex64);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex128);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex<float>);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex<double>);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(double);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool);
......@@ -192,6 +231,10 @@ inline std::string TensorDTypeToPyDTypeStr(
return "F"; \
} else if (std::is_same<T, platform::complex128>::value) { \
return "D"; \
} else if (std::is_same<T, platform::complex<float>>::value) { \
return "F"; \
} else if (std::is_same<T, platform::complex<double>>::value) { \
return "D"; \
} else { \
constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \
PADDLE_ENFORCE_EQ( \
......@@ -373,6 +416,14 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
} else if (py::isinstance<py::array_t<paddle::platform::complex128>>(array)) {
SetTensorFromPyArrayT<paddle::platform::complex128, P>(self, array, place,
zero_copy);
} else if (py::isinstance<py::array_t<paddle::platform::complex<float>>>(
array)) {
SetTensorFromPyArrayT<paddle::platform::complex<float>, P>(
self, array, place, zero_copy);
} else if (py::isinstance<py::array_t<paddle::platform::complex<double>>>(
array)) {
SetTensorFromPyArrayT<paddle::platform::complex<double>, P>(
self, array, place, zero_copy);
} else if (py::isinstance<py::array_t<uint16_t>>(array)) {
// since there is still no support for bfloat16 in NumPy,
// uint16 is used for casting bfloat16
......
......@@ -421,6 +421,7 @@ CPU_PARALLEL_JOB = [
'buffered_allocator_test',
'broadcast_op_test',
'bfloat16_test',
'complex_test',
'beam_search_decode_op_test',
'auto_growth_best_fit_allocator_test',
'assign_op_test',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册