diff --git a/paddle/phi/kernels/cpu/einsum_kernel.cc b/paddle/phi/kernels/cpu/einsum_kernel.cc index 401d2fd158a5d1b1456c632dc5c59aebd78d9c8b..901c1fed628d33df490d6fab289247ff8f955f3a 100644 --- a/paddle/phi/kernels/cpu/einsum_kernel.cc +++ b/paddle/phi/kernels/cpu/einsum_kernel.cc @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL( - einsum, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, float, double) {} +PD_REGISTER_KERNEL(einsum, + CPU, + ALL_LAYOUT, + phi::EinsumKernelRaw, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/einsum_kernel.cu b/paddle/phi/kernels/gpu/einsum_kernel.cu index d1f4c6590387a81464a3fdceec0442934e8b2940..b3706710c40e33518a25199c09c5c636ece29e01 100644 --- a/paddle/phi/kernels/gpu/einsum_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_kernel.cu @@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(einsum, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/python/paddle/fluid/tests/unittests/test_einsum_v2.py b/python/paddle/fluid/tests/unittests/test_einsum_v2.py index 97f3eef51a5bfe9f2c33fee18b60ad4099fd5648..224f44d74864b90f0107587953698017d4c0fb7c 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_v2.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_v2.py @@ -541,5 +541,19 @@ class TestBF16(unittest.TestCase): self.assertEqual(C.item(), 8.0) +class TestComplex(unittest.TestCase): + """ + EinsumOp support Complex type + """ + + def test_shape(self): + a = paddle.rand([4, 4]) + b = paddle.rand([4, 4]) + c = paddle.einsum('xy,yz->xz', a, b) + a = paddle.cast(a, 'complex64') + b = paddle.cast(b, 'complex64') + c = paddle.einsum('xy,yz->xz', a, b) + + if __name__ == "__main__": unittest.main()