From e930c5763e063c1e009798a487b87ebb2d6092eb Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 1 Nov 2022 15:58:58 +0800 Subject: [PATCH] [EinsumOp] Einsum support complex grad (#47514) * Einsum Support Complex * code fix * add unittest for complex grad with einsum * set rtol=1e-4 * fix --- paddle/phi/kernels/impl/einsum_grad_impl.h | 8 ++- .../fluid/tests/unittests/test_einsum_v2.py | 65 ++++++++++++++++++- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h index bf27f3ef2b..816badcd79 100644 --- a/paddle/phi/kernels/impl/einsum_grad_impl.h +++ b/paddle/phi/kernels/impl/einsum_grad_impl.h @@ -15,6 +15,7 @@ #include "paddle/fluid/platform/profiler.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/complex_kernel.h" #include "paddle/phi/kernels/impl/einsum_impl.h" #include "paddle/phi/kernels/tile_kernel.h" #include "paddle/utils/string/string_helper.h" @@ -177,11 +178,12 @@ void EinsumGradKernel(const Context& dev_ctx, auto operands_for_A = std::vector(); auto operands_for_B = std::vector(); DenseTensor dA, dB; + auto out_grad_conj = Conj(dev_ctx, out_grad); // dA = einsum(B, dC) operands_for_A.push_back(x[1]); - operands_for_A.push_back(&out_grad); + operands_for_A.push_back(&out_grad_conj); // dB = einsum(dC, A) - operands_for_B.push_back(&out_grad); + operands_for_B.push_back(&out_grad_conj); operands_for_B.push_back(x[0]); DenseTensor before_tile; @@ -219,6 +221,7 @@ void EinsumGradKernel(const Context& dev_ctx, ellipsis_dims[0], ops[0], dA); + *(x_grad[0]) = Conj(dev_ctx, *x_grad[0]); } if (x_grad[1]) { *(x_grad[1]) = PerformTileAndReduction(dev_ctx, @@ -228,6 +231,7 @@ void EinsumGradKernel(const Context& dev_ctx, ellipsis_dims[1], ops[1], dB); + *(x_grad[1]) = Conj(dev_ctx, *x_grad[1]); } } } diff --git a/python/paddle/fluid/tests/unittests/test_einsum_v2.py b/python/paddle/fluid/tests/unittests/test_einsum_v2.py index c7d2f9c76b..f45f9ace1c 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_v2.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_v2.py @@ -580,7 +580,7 @@ class TestSimpleUndiagonal(unittest.TestCase): A = paddle.to_tensor(np.array([1.0, 2.0])) A_expect = paddle.to_tensor([[1.0, 0.0], [0.0, 2.0]]) A_actual = paddle.einsum('i->ii', A) - np.array_equal(A_expect.numpy(), A_actual.numpy()) + assert np.array_equal(A_expect.numpy(), A_actual.numpy()) class TestSimpleUndiagonal2(unittest.TestCase): @@ -594,7 +594,68 @@ class TestSimpleUndiagonal2(unittest.TestCase): B = paddle.to_tensor(np.array([1.0, 1.0])) A_expect = paddle.to_tensor([[2.0, 0.0], [0.0, 4.0]]) A_actual = paddle.einsum('i,j->ii', A, B) - np.array_equal(A_expect.numpy(), A_actual.numpy()) + assert np.array_equal(A_expect.numpy(), A_actual.numpy()) + + +class TestSimpleComplexGrad(unittest.TestCase): + """ + EinsumOp support complex grad. but op_test don't support numeric grad for complex dtype. + """ + + def test_shape(self): + paddle.disable_static() + A = paddle.to_tensor( + [ + [ + [-1.08644637 + 1.30794563j], + [-0.89606513 + 1.84546043j], + [-0.30629937 + 0.82911495j], + ], + [ + [-1.33993366 - 0.02329881j], + [-1.20658558 - 0.20856395j], + [-0.64172681 - 0.91661975j], + ], + ] + ) + + B = paddle.to_tensor( + [ + [[-1.07474258 + 0.39477287j], [-0.08614349 - 0.38770082j]], + [[1.17583854 + 0.58840176j], [-1.63509173 - 1.43329882j]], + [[1.228194 - 0.32357468j], [1.07638625 + 1.25298469j]], + ] + ) + + dOut = paddle.to_tensor( + [ + [[-0.73074259 - 0.1632133j], [1.42848507 - 0.96410727j]], + [[0.94465389 - 0.34264733j], [-0.26400278 + 0.04890404j]], + ] + ) + + d_expect = paddle.to_tensor( + [ + [ + [0.971658 + 1.100766j], + [-1.909121 + 3.861908j], + [-0.515092 - 3.264529j], + ], + [ + [-1.146746 - 0.111233j], + [1.270721 - 1.417091j], + [1.048197 + 0.268260j], + ], + ] + ) + + A.stop_gradient = False + B.stop_gradient = False + Out = paddle.einsum('iox,ojx->ijx', A, B) + dA = paddle.grad(Out, A, dOut)[0] + np.testing.assert_allclose( + dA.numpy(), d_expect.numpy(), rtol=1e-6, atol=0 + ) if __name__ == "__main__": -- GitLab