Unverified · Commit d53972fd authored by iSerendipity, committed by GitHub

【complex op】No.7 add complex support for isclose (#56723)

* add complex support for isclose

* add complex test for isclose

* fix template compile issue

* fix cuda compilation error

* fix type typo

* fix error for complex's abs

* add complex dtype into input

* fix ut
Parent 3f5d0083
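All of the new kernels below evaluate the same elementwise predicate. As a reference, here is a minimal NumPy sketch of that comparison for two scalars (the helper name isclose_complex is mine, not part of the patch; it mirrors the comparison logic in the kernels, including the extra 1e-15 slack on the comparison and the convention that a complex value counts as NaN when either component is NaN):

    import numpy as np

    def isclose_complex(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        # A complex value counts as NaN if either component is NaN.
        a_nan = np.isnan(a.real) or np.isnan(a.imag)
        b_nan = np.isnan(b.real) or np.isnan(b.imag)
        if a_nan or b_nan:
            return equal_nan and a_nan == b_nan
        left = abs(a - b)             # complex abs -> real magnitude
        right = atol + rtol * abs(b)  # tolerance scales with |b|
        # The kernels also allow a tiny slack on the comparison itself.
        return a == b or left <= right or abs(left - right) <= 1e-15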
@@ -18,5 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/isclose_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    isclose, CPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
+PD_REGISTER_KERNEL(isclose,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::IscloseKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(isclose,
                    phi::IscloseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
@@ -19,6 +19,7 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -86,6 +87,40 @@ struct IscloseFunctor<phi::CPUContext, T> {
   }
 };
 
+template <typename T>
+struct IscloseFunctor<phi::CPUContext, phi::dtype::complex<T>> {
+  void operator()(const phi::CPUContext& ctx,
+                  const DenseTensor& in,
+                  const DenseTensor& other,
+                  const double rtol,
+                  const double atol,
+                  bool equal_nan,
+                  DenseTensor* output) {
+    auto* in_a = in.data<phi::dtype::complex<T>>();
+    auto* in_b = other.data<phi::dtype::complex<T>>();
+    auto* out_data = ctx.template Alloc<bool>(output);
+    auto num = in.numel();
+    // *out_data = true;
+    for (int i = 0; i < num; i++) {
+      out_data[i] = true;
+    }
+    for (int i = 0; i < num; i++) {
+      const phi::dtype::complex<T> a = in_a[i], b = in_b[i];
+      bool val;
+      if (std::isnan(a) || std::isnan(b)) {
+        val = equal_nan && std::isnan(a) == std::isnan(b);
+      } else {
+        T left = abs(a - b);
+        T right = atol + rtol * abs(b);
+        T diff = abs(left - right);
+        val = a == b || left <= right || diff <= 1e-15;
+        // *out_data &= val;
+        out_data[i] = val;
+      }
+    }
+  }
+};
+
 #if defined(__NVCC__) || defined(__HIPCC__)
 template <typename T>
 __global__ void IscloseCUDAKernel(const T* in_data,
@@ -113,7 +148,59 @@ __global__ void IscloseCUDAKernel(const T* in_data,
     // if (!val) *out_data = false;
   }
 }
+
+template <>
+__global__ void IscloseCUDAKernel<phi::dtype::complex<float>>(
+    const phi::dtype::complex<float>* in_data,
+    const phi::dtype::complex<float>* other_data,
+    const double rtol,
+    const double atol,
+    bool equal_nan,
+    int num,
+    bool* out_data) {
+  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  bool val;
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    const phi::dtype::complex<float> a = in_data[i];
+    const phi::dtype::complex<float> b = other_data[i];
+    if (isnan(a) || isnan(b)) {
+      val = equal_nan && isnan(a) == isnan(b);
+    } else {
+      float left = abs(a - b);
+      float right = atol + rtol * abs(b);
+      float diff = abs(left - right);
+      val = a == b || left <= right || diff <= 1e-15;
+    }
+    out_data[i] = val;
+    // if (!val) *out_data = false;
+  }
+}
+
+template <>
+__global__ void IscloseCUDAKernel<phi::dtype::complex<double>>(
+    const phi::dtype::complex<double>* in_data,
+    const phi::dtype::complex<double>* other_data,
+    const double rtol,
+    const double atol,
+    bool equal_nan,
+    int num,
+    bool* out_data) {
+  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  bool val;
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    const phi::dtype::complex<double> a = in_data[i];
+    const phi::dtype::complex<double> b = other_data[i];
+    if (isnan(a) || isnan(b)) {
+      val = equal_nan && isnan(a) == isnan(b);
+    } else {
+      double left = abs(a - b);
+      double right = atol + rtol * abs(b);
+      double diff = abs(left - right);
+      val = a == b || left <= right || diff <= 1e-15;
+    }
+    out_data[i] = val;
+    // if (!val) *out_data = false;
+  }
+}
+
 template <typename T>
 struct GetTensorValue<phi::GPUContext, T> {
   T operator()(const phi::GPUContext& dev_ctx,
...
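Note that abs() of a complex difference is a real magnitude, which is why left, right, and diff above are plain float/double even though the inputs are complex. A worked check with the same values the new unit tests use further down, cross-checked against np.isclose, which applies the same |a - b| <= atol + rtol * |b| rule:

    import numpy as np

    a = np.complex64(10.1 + 0.1j)
    b = np.complex64(10 + 0j)
    print(abs(a - b))           # |0.1 + 0.1j| ~= 0.14142
    print(0.0 + 0.01 * abs(b))  # 0.1 -> 0.14142 > 0.1, so not close
    print(np.isclose(a, b, rtol=0.01, atol=0.0))  # False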
@@ -1320,8 +1320,8 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     two tensors are elementwise equal within a tolerance.
 
     Args:
-        x(Tensor): The input tensor, it's data type should be float16, float32, float64.
-        y(Tensor): The input tensor, it's data type should be float16, float32, float64.
+        x(Tensor): The input tensor, it's data type should be float16, float32, float64, complex64, complex128.
+        y(Tensor): The input tensor, it's data type should be float16, float32, float64, complex64, complex128.
         rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
         atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` .
@@ -1359,10 +1359,16 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
         return _C_ops.isclose(x, y, rtol, atol, equal_nan)
     else:
         check_variable_and_dtype(
-            x, "input", ['float16', 'float32', 'float64'], 'isclose'
+            x,
+            "input",
+            ['float16', 'float32', 'float64', 'complex64', 'complex128'],
+            'isclose',
         )
         check_variable_and_dtype(
-            y, "input", ['float16', 'float32', 'float64'], 'isclose'
+            y,
+            "input",
+            ['float16', 'float32', 'float64', 'complex64', 'complex128'],
+            'isclose',
         )
         check_type(rtol, 'rtol', float, 'isclose')
         check_type(atol, 'atol', float, 'isclose')
......
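For reference, a dynamic-graph usage sketch of the extended API (illustrative values; assumes a build with this patch applied, so that complex inputs pass the dtype check and dispatch to the new kernels):

    import paddle

    x = paddle.to_tensor([10.1 + 0.1j], dtype='complex64')
    y = paddle.to_tensor([10.0 + 0.0j], dtype='complex64')
    # |x - y| ~= 0.1414 exceeds atol + rtol * |y| = 0.1, so not close.
    print(paddle.isclose(x, y, rtol=0.01, atol=0.0))  # [False]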
@@ -259,6 +259,69 @@ class TestIscloseOpFloat64(TestIscloseOp):
         self.check_output()
 
 
+class TestIscloseOpCp64(unittest.TestCase):
+    def test_cp64(self):
+        x_data = (
+            np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
+        ).astype(np.complex64)
+        y_data = (
+            np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
+        ).astype(np.complex64)
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype=np.complex64)
+            y = paddle.static.data(shape=[10, 10], name='y', dtype=np.complex64)
+            out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestIscloseOpCp128(unittest.TestCase):
+    def test_cp128(self):
+        x_data = (
+            np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
+        ).astype(np.complex128)
+        y_data = (
+            np.random.rand(10, 10) + 1.0j * np.random.rand(10, 10)
+        ).astype(np.complex128)
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(
+                shape=[10, 10], name='x', dtype=np.complex128
+            )
+            y = paddle.static.data(
+                shape=[10, 10], name='y', dtype=np.complex128
+            )
+            out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestIscloseOpComplex64(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1 + 0.1j]).astype(np.complex64)
+        self.other = np.array([10 + 0j]).astype(np.complex64)
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+
+class TestIscloseOpComplex128(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1 + 0.1j]).astype(np.complex128)
+        self.other = np.array([10 + 0j]).astype(np.complex128)
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestIscloseOpLargeDimInput(TestIscloseOp):
     def set_args(self):
         self.input = np.array(np.zeros([2048, 1024])).astype("float64")
......
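The cases above exercise finite values; the equal_nan path for complex inputs can be sanity-checked against NumPy, which, like the kernels above, treats a complex number as NaN when either component is NaN (a sketch, not part of the patch):

    import numpy as np

    z = np.array([complex(np.nan, 0.0)], dtype=np.complex64)
    print(np.isclose(z, z, equal_nan=False))  # [False]: NaN != NaN
    print(np.isclose(z, z, equal_nan=True))   # [True]: both are NaN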