未验证 提交 b9828bdf 编写于 作者: Y Yibing Liu 提交者: GitHub

Use faster algorithm for complex matmul, test=develop (#26231)

上级 5113aae6
......@@ -56,20 +56,20 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
# [1.+5.j 5.+9.j]
"""
# x = a + bi, y = c + di
# mm(x, y) = mm(a, c) - mm(b, d) + (mm(a, d) + mm(b, c))i
# P1 = ac; P2 = (a + b)(c + d); P3 = bd; then mm(x, y) = (P1-P3) + (P2-P1-P3)j
complex_variable_exists([x, y], "matmul")
a, b = (x.real, x.imag) if is_complex(x) else (x, None)
c, d = (y.real, y.imag) if is_complex(y) else (y, None)
ac = layers.matmul(a, c, transpose_x, transpose_y, alpha, name)
P1 = layers.matmul(a, c, transpose_x, transpose_y, alpha, name)
if is_real(b) and is_real(d):
bd = layers.matmul(b, d, transpose_x, transpose_y, alpha, name)
real = ac - bd
imag = layers.matmul(a, d, transpose_x, transpose_y, alpha, name) + \
layers.matmul(b, c, transpose_x, transpose_y, alpha, name)
P2 = layers.matmul(a + b, c + d, transpose_x, transpose_y, alpha, name)
P3 = layers.matmul(b, d, transpose_x, transpose_y, alpha, name)
real = P1 - P3
imag = P2 - P1 - P3
elif is_real(b):
real = ac
real = P1
imag = layers.matmul(b, c, transpose_x, transpose_y, alpha, name)
else:
real = ac
real = P1
imag = layers.matmul(a, d, transpose_x, transpose_y, alpha, name)
return ComplexVariable(real, imag)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册