From 6a7dfdd0c574a90a31241cdebf11f2f380ff5148 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 11 Jul 2022 17:04:05 +0800 Subject: [PATCH] einsum support complex (#44212) einsum support complex and add unittest. --- paddle/phi/kernels/cpu/einsum_kernel.cc | 10 ++++++++-- paddle/phi/kernels/gpu/einsum_kernel.cu | 4 +++- .../paddle/fluid/tests/unittests/test_einsum_v2.py | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/cpu/einsum_kernel.cc b/paddle/phi/kernels/cpu/einsum_kernel.cc index 401d2fd158a..901c1fed628 100644 --- a/paddle/phi/kernels/cpu/einsum_kernel.cc +++ b/paddle/phi/kernels/cpu/einsum_kernel.cc @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL( - einsum, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, float, double) {} +PD_REGISTER_KERNEL(einsum, + CPU, + ALL_LAYOUT, + phi::EinsumKernelRaw, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/einsum_kernel.cu b/paddle/phi/kernels/gpu/einsum_kernel.cu index d1f4c659038..b3706710c40 100644 --- a/paddle/phi/kernels/gpu/einsum_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_kernel.cu @@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(einsum, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/python/paddle/fluid/tests/unittests/test_einsum_v2.py b/python/paddle/fluid/tests/unittests/test_einsum_v2.py index 97f3eef51a5..224f44d7486 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_v2.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_v2.py @@ -541,5 +541,19 @@ class TestBF16(unittest.TestCase): self.assertEqual(C.item(), 8.0) +class TestComplex(unittest.TestCase): + """ + EinsumOp support Complex type + """ + + def test_shape(self): + a = paddle.rand([4, 4]) + b = paddle.rand([4, 4]) + c = paddle.einsum('xy,yz->xz', a, b) + a = paddle.cast(a, 'complex64') + b = paddle.cast(b, 'complex64') + c = paddle.einsum('xy,yz->xz', a, b) + + if __name__ == "__main__": unittest.main() -- GitLab