Unverified commit e930c576, authored by xiongkun, committed by GitHub

[EinsumOp] Einsum support complex grad (#47514)

* Einsum Support Complex

* code fix

* add unittest for complex grad with einsum

* set rtol=1e-4

* fix
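
In short, the backward kernel now conjugates the incoming gradient before running the gradient einsums and conjugates each result afterwards, i.e. roughly dA = conj(einsum(eq_dA, B, conj(dOut))) and dB = conj(einsum(eq_dB, conj(dOut), A)), where eq_dA/eq_dB stand for whatever gradient equations the kernel derives internally (the names here are only illustrative). For real dtypes the extra Conj calls are no-ops in value, so existing behaviour should be unchanged; for complex dtypes this yields the conjugate-gradient convention, which the new unit test below checks against hand-computed values.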
Parent 3592ba8c
......@@ -15,6 +15,7 @@
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
#include "paddle/phi/kernels/tile_kernel.h"
#include "paddle/utils/string/string_helper.h"
......@@ -177,11 +178,12 @@ void EinsumGradKernel(const Context& dev_ctx,
auto operands_for_A = std::vector<const DenseTensor*>();
auto operands_for_B = std::vector<const DenseTensor*>();
DenseTensor dA, dB;
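// For complex dtypes, conjugate the incoming gradient here and conjugate the
// results again after PerformTileAndReduction below, so the kernel effectively
// returns conj(einsum(..., conj(dC))); for real dtypes Conj is a no-op in value.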
auto out_grad_conj = Conj<T, Context>(dev_ctx, out_grad);
// dA = einsum(B, dC)
operands_for_A.push_back(x[1]);
operands_for_A.push_back(&out_grad);
operands_for_A.push_back(&out_grad_conj);
// dB = einsum(dC, A)
operands_for_B.push_back(&out_grad);
operands_for_B.push_back(&out_grad_conj);
operands_for_B.push_back(x[0]);
DenseTensor before_tile;
......@@ -219,6 +221,7 @@ void EinsumGradKernel(const Context& dev_ctx,
ellipsis_dims[0],
ops[0],
dA);
*(x_grad[0]) = Conj<T, Context>(dev_ctx, *x_grad[0]);
}
if (x_grad[1]) {
*(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
......@@ -228,6 +231,7 @@ void EinsumGradKernel(const Context& dev_ctx,
ellipsis_dims[1],
ops[1],
dB);
*(x_grad[1]) = Conj<T, Context>(dev_ctx, *x_grad[1]);
}
}
}
......
......@@ -580,7 +580,7 @@ class TestSimpleUndiagonal(unittest.TestCase):
A = paddle.to_tensor(np.array([1.0, 2.0]))
A_expect = paddle.to_tensor([[1.0, 0.0], [0.0, 2.0]])
A_actual = paddle.einsum('i->ii', A)
np.array_equal(A_expect.numpy(), A_actual.numpy())
assert np.array_equal(A_expect.numpy(), A_actual.numpy())
class TestSimpleUndiagonal2(unittest.TestCase):
......@@ -594,7 +594,68 @@ class TestSimpleUndiagonal2(unittest.TestCase):
B = paddle.to_tensor(np.array([1.0, 1.0]))
A_expect = paddle.to_tensor([[2.0, 0.0], [0.0, 4.0]])
A_actual = paddle.einsum('i,j->ii', A, B)
np.array_equal(A_expect.numpy(), A_actual.numpy())
assert np.array_equal(A_expect.numpy(), A_actual.numpy())
class TestSimpleComplexGrad(unittest.TestCase):
"""
EinsumOp supports complex grad, but op_test does not support numeric gradient checking for complex dtypes, so the gradient is compared against precomputed values (see the NumPy sketch after this test for how d_expect can be reproduced).
"""
def test_shape(self):
paddle.disable_static()
A = paddle.to_tensor(
[
[
[-1.08644637 + 1.30794563j],
[-0.89606513 + 1.84546043j],
[-0.30629937 + 0.82911495j],
],
[
[-1.33993366 - 0.02329881j],
[-1.20658558 - 0.20856395j],
[-0.64172681 - 0.91661975j],
],
]
)
B = paddle.to_tensor(
[
[[-1.07474258 + 0.39477287j], [-0.08614349 - 0.38770082j]],
[[1.17583854 + 0.58840176j], [-1.63509173 - 1.43329882j]],
[[1.228194 - 0.32357468j], [1.07638625 + 1.25298469j]],
]
)
dOut = paddle.to_tensor(
[
[[-0.73074259 - 0.1632133j], [1.42848507 - 0.96410727j]],
[[0.94465389 - 0.34264733j], [-0.26400278 + 0.04890404j]],
]
)
d_expect = paddle.to_tensor(
[
[
[0.971658 + 1.100766j],
[-1.909121 + 3.861908j],
[-0.515092 - 3.264529j],
],
[
[-1.146746 - 0.111233j],
[1.270721 - 1.417091j],
[1.048197 + 0.268260j],
],
]
)
A.stop_gradient = False
B.stop_gradient = False
Out = paddle.einsum('iox,ojx->ijx', A, B)
dA = paddle.grad(Out, A, dOut)[0]
np.testing.assert_allclose(
dA.numpy(), d_expect.numpy(), rtol=1e-6, atol=0
)
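For reference, d_expect above can be reproduced with a small NumPy sketch (not part of the test file; the einsum string 'ojx,ijx->iox' is worked out by hand for this test's equation and is illustrative, not the kernel's internal gradient equation):

import numpy as np

# Operand and upstream-gradient values copied from the test above.
B = np.array(
    [
        [[-1.07474258 + 0.39477287j], [-0.08614349 - 0.38770082j]],
        [[1.17583854 + 0.58840176j], [-1.63509173 - 1.43329882j]],
        [[1.228194 - 0.32357468j], [1.07638625 + 1.25298469j]],
    ]
)
dOut = np.array(
    [
        [[-0.73074259 - 0.1632133j], [1.42848507 - 0.96410727j]],
        [[0.94465389 - 0.34264733j], [-0.26400278 + 0.04890404j]],
    ]
)
# Conjugate-gradient rule for Out = einsum('iox,ojx->ijx', A, B):
#   dA = conj(einsum('ojx,ijx->iox', B, conj(dOut)))
#      = einsum('ojx,ijx->iox', conj(B), dOut)
dA_ref = np.conj(np.einsum('ojx,ijx->iox', B, np.conj(dOut)))
print(np.round(dA_ref, 6))  # agrees with d_expect above up to rounding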
if __name__ == "__main__":
......