未验证 提交 6a7dfdd0 编写于 作者: X xiongkun 提交者: GitHub

einsum support complex (#44212)

einsum support complex and add unittest.
上级 449ea33d
......@@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(
einsum, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, float, double) {}
PD_REGISTER_KERNEL(einsum,
CPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(einsum,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -541,5 +541,19 @@ class TestBF16(unittest.TestCase):
self.assertEqual(C.item(), 8.0)
class TestComplex(unittest.TestCase):
"""
EinsumOp support Complex type
"""
def test_shape(self):
a = paddle.rand([4, 4])
b = paddle.rand([4, 4])
c = paddle.einsum('xy,yz->xz', a, b)
a = paddle.cast(a, 'complex64')
b = paddle.cast(b, 'complex64')
c = paddle.einsum('xy,yz->xz', a, b)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册