diff --git a/python/paddle/fluid/tests/unittests/test_einsum_0d_tensor.py b/python/paddle/fluid/tests/unittests/test_einsum_0d_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..a401b1300bfd6164aa2f0cd6eec16bf15d500dcd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_einsum_0d_tensor.py @@ -0,0 +1,197 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy as np +from numpy.testing import assert_allclose + +import paddle + +os.environ['NVIDIA_TF32_OVERRIDE'] = "0" + + +class Test0DCase0(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def tearDown(self): + paddle.enable_static() + + def test_func(self): + x = paddle.rand([]) + x.stop_gradient = False + y = paddle.rand([]) + y.stop_gradient = False + z = paddle.einsum("...,...->...", x, y) + assert_allclose( + z.numpy(), + np.einsum('...,...->...', x.numpy(), y.numpy()), + atol=1e-6, + ) + z.mean().backward() + assert z.shape == [] + assert x.grad.shape == [] + assert y.grad.shape == [] + + +class Test0DCase1(Test0DCase0): + def test_func(self): + x = paddle.rand([]) + x.stop_gradient = False + y = paddle.rand([2, 2]) + y.stop_gradient = False + z = paddle.einsum("...,ij->...", x, y) + assert_allclose( + z.numpy(), np.einsum('...,ij->...', x.numpy(), y.numpy()), atol=1e-6 + ) + z.mean().backward() + assert z.shape == [] + assert x.grad.shape == [] + assert y.grad.shape == [2, 2] + + +class Test0DCase2(Test0DCase0): + def test_func(self): + x = paddle.rand([2, 2]) + x.stop_gradient = False + y = paddle.rand([2, 2]) + y.stop_gradient = False + z = paddle.einsum("ij,ij->", x, y) + assert_allclose( + z.numpy(), np.einsum('ij,ij->', x.numpy(), y.numpy()), atol=1e-6 + ) + z.mean().backward() + assert z.shape == [] + assert x.grad.shape == [2, 2] + assert y.grad.shape == [2, 2] + + +class Test0DCase3(Test0DCase0): + def test_func(self): + x = paddle.rand([2, 2]) + x.stop_gradient = True + y = paddle.rand([2, 2]) + y.stop_gradient = False + z = paddle.einsum("ij,ij->", x, y) + assert_allclose( + z.numpy(), np.einsum('ij,ij->', x.numpy(), y.numpy()), atol=1e-6 + ) + z.mean().backward() + assert z.shape == [] + assert x.grad is None + assert y.grad.shape == [2, 2] + + +class Test0DCase4(Test0DCase0): + def test_func(self): + x = paddle.rand([]) + x.stop_gradient = False + z = paddle.einsum("...->...", x) + assert_allclose(z.numpy(), np.einsum('...->...', x.numpy()), atol=1e-6) + z.mean().backward() + assert z.shape == [] + assert x.grad.shape == [] + assert x.grad.numpy() == 1.0 + + +class Test0DCase5(Test0DCase0): + def test_func(self): + x = paddle.rand([2, 2]) + x.stop_gradient = False + y = paddle.rand([2, 2]) + y.stop_gradient = False + z = paddle.einsum("i...j, i...j->...", x, y) + assert_allclose( + z.numpy(), + np.einsum('i...j, i...j->...', x.numpy(), y.numpy()), + atol=1e-6, + ) + z.mean().backward() + assert z.shape == [] + assert x.grad.shape == [2, 2] + assert y.grad.shape == [2, 2] + + +class Test0DCase6(Test0DCase0): + def test_func(self): + x = paddle.rand([2, 2]) + x.stop_gradient = False + z = paddle.einsum("ij->", x) + assert_allclose(z.numpy(), np.einsum('ij->', x.numpy()), atol=1e-6) + z.mean().backward() + assert z.shape == [] + assert x.grad.shape == [2, 2] + + +class Test0DCase7(Test0DCase0): + def test_func(self): + """ + 3 operands. + """ + x = paddle.rand([2, 2]) + y = paddle.rand([]) + z = paddle.rand([]) + x.stop_gradient = False + y.stop_gradient = False + z.stop_gradient = False + o = paddle.einsum("ij...,...,...->...", x, y, z) + assert_allclose( + o.numpy(), + np.einsum("ij...,...,...->...", x.numpy(), y.numpy(), z.numpy()), + atol=1e-6, + ) + o.mean().backward() + assert o.shape == [] + assert x.grad.shape == [2, 2] + assert y.grad.shape == [] + assert z.grad.shape == [] + + +class Test0DCase8(Test0DCase0): + def test_func(self): + """ + 3 operands. + """ + x = paddle.rand([2, 2]) + y = paddle.rand([]) + z = paddle.rand([]) + e = paddle.rand([3, 1]) + x.stop_gradient = False + y.stop_gradient = False + z.stop_gradient = False + e.stop_gradient = False + o = paddle.einsum("ij...,...,..., km->...", x, y, z, e) + assert_allclose( + o.numpy(), + np.einsum( + "ij...,...,...,km->...", + x.numpy(), + y.numpy(), + z.numpy(), + e.numpy(), + ), + atol=1e-6, + ) + o.mean().backward() + assert o.shape == [] + assert x.grad.shape == [2, 2] + assert y.grad.shape == [] + assert z.grad.shape == [] + assert e.grad.shape == [3, 1] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 4a69cd76304bbcb10ca9d59da0744f92b44dc380..b98a3094cafa6beac0a256184a9862967f7b229f 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -59,9 +59,7 @@ def parse_op_labels(labelstr, operand): labelstr.replace('...', '', 1).find('.') == -1 ), "Invalid equation: `.` is found outside of an ellipsis." - # Check shape. Note, in Paddle a tensor rank is always nonzero ndims = len(operand.shape) - assert ndims > 0 full_labelstr = labelstr.replace('...', '.' * (ndims - len(labelstr) + 3)) @@ -743,20 +741,25 @@ def parse_fake_shape(equation, operands, labels): list of shape """ + origin_labels = map(lambda x: x.strip(), equation.split(',')) shaped = collections.namedtuple('shaped', ['shape']) - def fake_shape(label, op): + def fake_shape(ori_label, label, op): + """ + 1. ori_label is the original labels, not aligned by '....' + 2. if the '...' is evalulated to empty list, there is no '.' in label + """ assert len(op.shape) == len(label), ( "length of shape and length of label must be the same, but received %d != %d" % (len(op.shape), len(label)) ) fakes = [s for i, (l, s) in enumerate(zip(label, op.shape)) if l != '.'] fakes = list(map(abs, fakes)) # make -1 -> 1 - if '.' in label: - fakes.insert(label.index('.'), 1) + if '.' in ori_label: + fakes.insert(ori_label.index('.'), 1) return shaped(fakes) - out = list(map(fake_shape, labels, operands)) + out = list(map(fake_shape, origin_labels, labels, operands)) return out @@ -782,7 +785,7 @@ def gen_equation_for_opteinsum(lhs, rhs): if c not in used: return c raise ValueError( - "You have used all `a` - `z`, there can't find a unused for einsum optimization" + "You have used all `a` - `z`, there can't find a unused char for einsum optimization" ) cnt = collections.Counter(lhs)