diff --git a/paddle/phi/kernels/cpu/isclose_kernel.cc b/paddle/phi/kernels/cpu/isclose_kernel.cc
index dca21494b3ee951c73177eb5bce628bc9a6cfc2a..33457921df61e2c536a41db8fbdfb8dd6ebaba91 100644
--- a/paddle/phi/kernels/cpu/isclose_kernel.cc
+++ b/paddle/phi/kernels/cpu/isclose_kernel.cc
@@ -18,5 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/isclose_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    isclose, CPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
+PD_REGISTER_KERNEL(isclose,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::IscloseKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/isclose_kernel.cu b/paddle/phi/kernels/gpu/isclose_kernel.cu
index cfae8d0bbda29a529434f0c29c8dc85ca6131c6b..1242269242e0bfa24477258e2b7b9fd2ecfddf71 100644
--- a/paddle/phi/kernels/gpu/isclose_kernel.cu
+++ b/paddle/phi/kernels/gpu/isclose_kernel.cu
@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(isclose,
                    phi::IscloseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/impl/isclose_kernel_impl.h b/paddle/phi/kernels/impl/isclose_kernel_impl.h
index de59cb0c32ca132e52ab3b0b708c7c4793bba3ff..93dfb7790b4abdaff1bc73760827351ef4a72804 100644
--- a/paddle/phi/kernels/impl/isclose_kernel_impl.h
+++ b/paddle/phi/kernels/impl/isclose_kernel_impl.h
@@ -19,6 +19,7 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -86,6 +87,40 @@ struct IscloseFunctor<phi::CPUContext, T> {
   }
 };
 
+// CPU isclose for complex inputs. Element i is true when a == b or
+// abs(a - b) <= atol + rtol * abs(b) (with a 1e-15 slack on that comparison);
+// if either operand is NaN the result is equal_nan && both operands are NaN.
+template <typename T>
+struct IscloseFunctor<phi::CPUContext, phi::dtype::complex<T>> {
+  void operator()(const phi::CPUContext& ctx,
+                  const DenseTensor& in,
+                  const DenseTensor& other,
+                  const double rtol,
+                  const double atol,
+                  bool equal_nan,
+                  DenseTensor* output) {
+    auto* in_a = in.data<phi::dtype::complex<T>>();
+    auto* in_b = other.data<phi::dtype::complex<T>>();
+    auto* out_data = ctx.template Alloc<bool>(output);
+    const int64_t num = in.numel();
+    for (int64_t i = 0; i < num; i++) {
+      const phi::dtype::complex<T> a = in_a[i], b = in_b[i];
+      bool val;
+      if (std::isnan(a) || std::isnan(b)) {
+        val = equal_nan && std::isnan(a) == std::isnan(b);
+      } else {
+        T left = abs(a - b);
+        T right = atol + rtol * abs(b);
+        T diff = abs(left - right);
+        val = a == b || left <= right || diff <= 1e-15;
+      }
+      // Store on every path: writing only in the non-NaN branch would leave
+      // NaN comparisons with a stale value that ignores equal_nan.
+      out_data[i] = val;
+    }
+  }
+};
+
 #if defined(__NVCC__) || defined(__HIPCC__)
 template <typename T>
 __global__ void IscloseCUDAKernel(const T* in_data,
@@ -113,7 +148,59 @@ __global__ void IscloseCUDAKernel(const T* in_data,
     // if (!val) *out_data = false;
   }
 }
+// complex<float> specialization: same test as above, evaluated in float.
+template <>
+__global__ void IscloseCUDAKernel<phi::dtype::complex<float>>(
+    const phi::dtype::complex<float>* in_data,
+    const phi::dtype::complex<float>* other_data,
+    const double rtol,
+    const double atol,
+    bool equal_nan,
+    int num,
+    bool* out_data) {
+  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  bool val;
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    const phi::dtype::complex<float> a = in_data[i];
+    const phi::dtype::complex<float> b = other_data[i];
+    if (isnan(a) || isnan(b)) {
+      val = equal_nan && isnan(a) == isnan(b);
+    } else {
+      float left = abs(a - b);
+      float right = atol + rtol * abs(b);
+      float diff = abs(left - right);
+      val = a == b || left <= right || diff <= 1e-15;
+    }
+    out_data[i] = val;
+  }
+}
 
+// complex<double> specialization: same test, evaluated in double.
+template <>
+__global__ void IscloseCUDAKernel<phi::dtype::complex<double>>(
+    const phi::dtype::complex<double>* in_data,
+    const phi::dtype::complex<double>* other_data,
+    const double rtol,
+    const double atol,
+    bool equal_nan,
+    int num,
+    bool* out_data) {
+  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  bool val;
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    const phi::dtype::complex<double> a = in_data[i];
+    const phi::dtype::complex<double> b = other_data[i];
+    if (isnan(a) || isnan(b)) {
+      val = equal_nan && isnan(a) == isnan(b);
+    } else {
+      double left = abs(a - b);
+      double right = atol + rtol * abs(b);
+      double diff = abs(left - right);
+      val = a == b || left <= right || diff <= 1e-15;
+    }
+    out_data[i] = val;
+  }
+}
 template <typename T>
 struct GetTensorValue<phi::GPUContext, T> {
   T operator()(const phi::GPUContext& dev_ctx,
diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py
index 6632738695d7eb51de3871faab6de45df94e9a33..58a3b1fc0ea89766e45c2d7a3266abb91de79e49 100755
--- a/python/paddle/tensor/logic.py
+++ b/python/paddle/tensor/logic.py
@@ -1320,8 +1320,8 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     two tensors are elementwise equal within a tolerance.
 
     Args:
-        x(Tensor): The input tensor, it's data type should be float16, float32, float64.
-        y(Tensor): The input tensor, it's data type should be float16, float32, float64.
+        x(Tensor): The input tensor, its data type should be float16, float32, float64, complex64, complex128.
+        y(Tensor): The input tensor, its data type should be float16, float32, float64, complex64, complex128.
         rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
         atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` .
@@ -1359,10 +1359,16 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
         return _C_ops.isclose(x, y, rtol, atol, equal_nan)
     else:
         check_variable_and_dtype(
-            x, "input", ['float16', 'float32', 'float64'], 'isclose'
+            x,
+            "input",
+            ['float16', 'float32', 'float64', 'complex64', 'complex128'],
+            'isclose',
         )
         check_variable_and_dtype(
-            y, "input", ['float16', 'float32', 'float64'], 'isclose'
+            y,
+            "input",
+            ['float16', 'float32', 'float64', 'complex64', 'complex128'],
+            'isclose',
         )
         check_type(rtol, 'rtol', float, 'isclose')
         check_type(atol, 'atol', float, 'isclose')
diff --git a/test/legacy_test/test_isclose_op.py b/test/legacy_test/test_isclose_op.py
index 3cb84a2b3c1e3dc57f4684cb9100edfadbefbf96..2074a160c5b3d9693944b18acc711c4a40202855 100644
--- a/test/legacy_test/test_isclose_op.py
+++ b/test/legacy_test/test_isclose_op.py
@@ -259,6 +259,69 @@ class TestIscloseOpFloat64(TestIscloseOp):
         self.check_output()
 
 
+class TestIscloseOpCp64(unittest.TestCase):
+    def test_cp64(self):
+        x_data = (
+            np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
+        ).astype(np.complex64)
+        y_data = (
+            np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
+        ).astype(np.complex64)
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype=np.complex64)
+            y = paddle.static.data(shape=[10, 10], name='y', dtype=np.complex64)
+            out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestIscloseOpCp128(unittest.TestCase):
+    def test_cp128(self):
+        x_data = (
+            np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
+        ).astype(np.complex128)
+        y_data = (
+            np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
+        ).astype(np.complex128)
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(
+                shape=[10, 10], name='x', dtype=np.complex128
+            )
+            y = paddle.static.data(
+                shape=[10, 10], name='y', dtype=np.complex128
+            )
+            out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestIscloseOpComplex64(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1 + 0.1j]).astype(np.complex64)
+        self.other = np.array([10 + 0j]).astype(np.complex64)
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+
+class TestIscloseOpComplex128(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1 + 0.1j]).astype(np.complex128)
+        self.other = np.array([10 + 0j]).astype(np.complex128)
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestIscloseOpLargeDimInput(TestIscloseOp):
     def set_args(self):
         self.input = np.array(np.zeros([2048, 1024])).astype("float64")