diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc
index 8cd8d94d6b3893698146396651d82509e96b4406..8fdde1ccdc058be3ada3736a15f7ec249e8b868b 100644
--- a/paddle/fluid/operators/einsum_op.cc
+++ b/paddle/fluid/operators/einsum_op.cc
@@ -18,7 +18,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/phi/core/ddim.h"
-#include "paddle/phi/kernels/impl/einsum_impl.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -85,7 +85,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker {
 namespace ops = paddle::operators;
 
 DECLARE_INFER_SHAPE_FUNCTOR(einsum,
                             EinsumInferShapeFunctor,
-                            PD_INFER_META(phi::EinsumInferShape));
+                            PD_INFER_META(phi::EinsumInferMeta));
 
 REGISTER_OPERATOR(einsum,
                   ops::EinsumOp,
                   ops::EinsumOpMaker,
                   EinsumInferShapeFunctor,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 367129cd7267660e4d6c1009f13d395c3227794f..eda461be95a406a19a6049fda57acab8d19ada01 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -28,6 +28,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/strided_slice.h"
 #include "paddle/phi/kernels/funcs/unfold_functor.h"
 #include "paddle/phi/kernels/funcs/unsqueeze.h"
+#include "paddle/phi/kernels/impl/einsum_impl.h"
 
 namespace phi {
 
@@ -398,6 +399,45 @@ void EighInferMeta(const MetaTensor& x,
   out_v->set_dims(input_dim);
 }
 
+void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
+                     const std::string& equation,
+                     MetaTensor* out) {
+  // collect the following information to prepare einsum.
+  LabelMap labelshape(0);
+  LabelMap labeltype(LabelType::Reduction);
+  std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
+  std::vector<char> all_labels;
+  std::vector<int> broadcast_dims;
+  std::vector<int> output_dims;
+  std::vector<std::vector<int>> ellipsis_dims(2);
+
+  std::vector<DDim> input_dims;
+  for (auto& i : inputs) {
+    input_dims.push_back(i->dims());
+  }
+  std::string right;
+  ParseEinsumEquation(equation,
+                      input_dims,
+                      &labelshape,
+                      &labeltype,
+                      &all_labels,
+                      &label2perms,
+                      &ellipsis_dims,
+                      &broadcast_dims,
+                      &output_dims,
+                      &right);
+
+  VLOG(3) << "Einsum Infershape: input dims:"
+          << paddle::string::join_strings(input_dims, "\n");
+  VLOG(3) << "Einsum Infershape: equation:" << equation;
+  VLOG(3) << "Einsum Infershape: all_labels:"
+          << paddle::string::join_strings(all_labels, ",");
+  VLOG(3) << "Einsum Infershape: output dims:"
+          << paddle::string::join_strings(output_dims, ",");
+  VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype);
+  VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
+}
+
 void ExpandInferMeta(const MetaTensor& x,
                      const IntArray& shape,
                      MetaTensor* out) {
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 97fa932eed584d8941cd9795497a6d138c1b3616..559857bd6ce9bd94c6d94e0631f2cd326edd710e 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -80,6 +80,10 @@ void EighInferMeta(const MetaTensor& x,
                    MetaTensor* out_w,
                    MetaTensor* out_v);
 
+void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
+                     const std::string& equation,
+                     MetaTensor* out);
+
 void ExpandInferMeta(const MetaTensor& x,
                      const IntArray& shape,
                      MetaTensor* out);
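The `EinsumInferShape` helper moves out of the kernel header and becomes `phi::EinsumInferMeta`, living with the other unary infer-meta functions. For intuition, here is a minimal Python sketch of the shape rule that `ParseEinsumEquation` enforces, in the simple case of no ellipsis and an explicit output; `infer_output_shape` is a hypothetical illustration, not a Paddle API:

```python
# Hypothetical sketch of einsum shape inference (no ellipsis, explicit rhs).
# Each label binds one dimension; a label repeated across operands must bind
# the same dim everywhere, and the output shape is read off the rhs labels.
def infer_output_shape(equation, *shapes):
    lhs, rhs = equation.replace(" ", "").split("->")
    label_dims = {}
    for labels, shape in zip(lhs.split(","), shapes):
        assert len(labels) == len(shape), "one label per dimension"
        for label, dim in zip(labels, shape):
            # the dim of the same label must match across operands
            assert label_dims.setdefault(label, dim) == dim
    return [label_dims[label] for label in rhs]

assert infer_output_shape("ijk,jk->ik", (3, 4, 5), (4, 5)) == [3, 5]
assert infer_output_shape("ij,jk->ik", (2, 3), (3, 7)) == [2, 7]
```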
diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h
index d4be007a07fc0ed88f966ccfbd309bc7687a4ca1..73940a45cbde2b5b5f301b6a1f7d7c328c1b53c1 100644
--- a/paddle/phi/kernels/impl/einsum_impl.h
+++ b/paddle/phi/kernels/impl/einsum_impl.h
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 #pragma once
-#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/matmul_kernel.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -21,6 +20,7 @@
 #include "paddle/utils/string/string_helper.h"
 
 namespace phi {
+
 // check the validation of the Einsum equation.
 // 1. the label must between 'a' - 'z'.
 // 2. the dim of the same label must be same.
@@ -302,45 +302,6 @@ inline static void ParseEinsumEquation(
   }
 }
 
-inline void EinsumInferShape(const std::vector<const MetaTensor*>& inputs,
-                             const std::string& equation,
-                             MetaTensor* out) {
-  // collect the following informations to prepare einsum.
-  LabelMap labelshape(0);
-  LabelMap labeltype(LabelType::Reduction);
-  std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
-  std::vector<char> all_labels;
-  std::vector<int> broadcast_dims;
-  std::vector<int> output_dims;
-  std::vector<std::vector<int>> ellipsis_dims(2);
-
-  std::vector<DDim> input_dims;
-  for (auto& i : inputs) {
-    input_dims.push_back(i->dims());
-  }
-  std::string right;
-  ParseEinsumEquation(equation,
-                      input_dims,
-                      &labelshape,
-                      &labeltype,
-                      &all_labels,
-                      &label2perms,
-                      &ellipsis_dims,
-                      &broadcast_dims,
-                      &output_dims,
-                      &right);
-
-  VLOG(3) << "Einsum Infershape: input dims:"
-          << paddle::string::join_strings(input_dims, "\n");
-  VLOG(3) << "Einsum Infershape: equation:" << equation;
-  VLOG(3) << "Einsum Infershape: all_labels:"
-          << paddle::string::join_strings(all_labels, ",");
-  VLOG(3) << "Einsum Infershape: output dims:"
-          << paddle::string::join_strings(output_dims, ",");
-  VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype);
-  VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
-}
-
 template <typename T>
 std::vector<T> GetLabelIndexByType(const std::vector<char>& all_labels,
                                    const LabelMap& type,
@@ -394,6 +355,13 @@ DenseTensor PerformReduction(const Context& dev_ctx,
   return Sum<T, Context>(dev_ctx, tensor, indices, tensor.dtype(), true);
 }
 
+inline bool is_no_need_transpose(const std::vector<int>& axis) {
+  for (size_t i = 0; i < axis.size(); ++i) {
+    if (i != static_cast<size_t>(axis[i])) return false;
+  }
+  return true;
+}
+
 template <typename T, typename Context>
 DenseTensor PerformTranspose(const Context& dev_ctx,
                              const DenseTensor& tensor,
@@ -401,12 +369,6 @@ DenseTensor PerformTranspose(const Context& dev_ctx,
                              const std::vector<char>& all_labels,
                              const std::vector<int>& ellipsis,
                              const LabelMap& label2type) {
-  auto is_no_need_transpose = [](std::vector<int>& axis) {
-    for (size_t i = 0; i < axis.size(); ++i) {
-      if (i != size_t(axis[i])) return false;
-    }
-    return true;
-  };
   auto axis = GetLabelIndexByType(
       all_labels, label2type, label2perm, ellipsis, LabelType::ALL_TYPE);
   VLOG(5) << "PerformTranspose: " << paddle::string::join_strings(axis, ",");
@@ -496,9 +458,9 @@ void TransposeToOutput(const Context& dev_ctx,
       axis.push_back(it - all_labels.begin() + offset);
     }
   }
+  if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans);
   VLOG(5) << "call TransposeToOutput: with axis: "
           << paddle::string::join_strings(axis, ",");
-  if (axis.size() == 0) return output->ShareBufferWith(to_trans);
   return TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
 }
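Alongside the infer-shape move, the identity-permutation check is hoisted out of the `PerformTranspose` lambda into a shared `is_no_need_transpose` helper, and `TransposeToOutput` now skips the transpose kernel for any identity permutation; the old `axis.size() == 0` test only caught the empty case. A pure-Python sketch of the predicate (illustration only, not Paddle code):

```python
# An axis list is a no-op permutation exactly when axis[i] == i for every
# position; in that case the kernel shares the input buffer instead of
# launching a transpose.
def is_no_need_transpose(axis):
    return all(i == a for i, a in enumerate(axis))

assert is_no_need_transpose([0, 1, 2])      # identity: share the buffer
assert not is_no_need_transpose([0, 2, 1])  # real permutation: transpose
assert is_no_need_transpose([])             # empty axis is trivially identity
```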
diff --git a/python/paddle/fluid/tests/unittests/test_einsum_v2.py b/python/paddle/fluid/tests/unittests/test_einsum_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..63acaf63969139324f2f7a60784707285b9cb3a3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_einsum_v2.py
@@ -0,0 +1,468 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import contextlib
+import unittest
+import paddle
+from paddle.fluid import core
+
+import os
+os.environ['FLAGS_new_einsum'] = "1"
+
+
+def error_trans(func, *args, **kargs):
+    """
+    Translate a C++ exception into a Python exception, because
+    einsum_v2 raises different exceptions than einsum_v1.
+    """
+    try:
+        out = func(*args, **kargs)
+    except ValueError as e:
+        if "Same label have different shapes" in str(e):
+            raise AssertionError("Invalid operands: label i "
+                                 "corresponds to non-broadcastable dimensions.")
+
+
+class TestErrors(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def test_diagonalize_errors(self):
+        a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
+        a = paddle.to_tensor(a)
+        with self.assertRaisesRegex(AssertionError,
+                                    ('Duplicate labels are not supported.')):
+            paddle.einsum('...ii->...i', a)
+        with self.assertRaisesRegex(AssertionError,
+                                    ('Duplicate labels are not supported.')):
+            paddle.einsum('i...i', a)
+        with self.assertRaisesRegex(AssertionError,
+                                    ('Duplicate labels are not supported.')):
+            paddle.einsum('i...i->i...', a)
+
+    def test_param_errors(self):
+        a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
+        a = paddle.to_tensor(a)
+        with self.assertRaisesRegex(
+                AssertionError,
+            ("Required at least one operand in Einsum API, but received 0 ")):
+            paddle.einsum('ijk')
+        with self.assertRaisesRegex(AssertionError, (
+                'Invalid equation: multiple `->` were found.')):
+            paddle.einsum('i -> j -> k', a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: the number of operands is 2, "
+                "but found 3 segments in the label equation.")):
+            paddle.einsum('i,j,k', a, a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: the number of operands is 2, "
+                "but found 1 segments in the label equation.")):
+            paddle.einsum('ij -> k', a, a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: the number of operands is 1, "
+                "but found 2 segments in the label equation.")):
+            paddle.einsum('i, -> k', a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: the label string '' misses dimensions.")):
+            paddle.einsum('->', a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: the label string 'i' misses dimensions.")):
+            paddle.einsum('i', a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: _ is not a valid label, "
+                "which should be letters.")):
+            paddle.einsum('i_', a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: `.` is found outside of an ellipsis.")):
+            paddle.einsum('i..j', a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: `.` is found outside of an ellipsis.")):
+            paddle.einsum('...k...', a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: missing ellipsis in output labels.")):
+            paddle.einsum('i...->i', a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid equation: duplicate output labels are found.")):
+            paddle.einsum('i...->i...i', a)
+        with self.assertRaisesRegex(AssertionError, (
+                "Invalid operands: label i "
+                "corresponds to non-broadcastable dimensions.")):
+            error_trans(paddle.einsum, 'ij...,ji...', a, a)
+
+
+class TestEinsum(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        np.random.seed(12345)
+
+        cls.TEST_SAMPLES = {
+            "a": np.random.rand(1, 1),
+            "b": np.random.rand(1),
+            "x": np.random.rand(5),
+            "y": np.random.rand(7),
+            "A": np.random.rand(4, 5),
+            "B": np.random.rand(2, 5),
+            "C": np.random.rand(3, 7),
+            "D": np.random.rand(3, 4, 5),
+            "E": np.random.rand(3, 5, 2),
+            "F": np.random.rand(2, 4, 5, 3),
+            "G": np.random.rand(4, 2, 5),
+            "H": np.random.rand(3, 2, 4),
+            "I": np.random.rand(2, 2),
+            "J": np.random.rand(1, 3, 5),
+            "K": np.random.rand(1, 2, 3, 4),
+        }
+
+    def _get_place(self, force_to_use_cpu=False):
+        if force_to_use_cpu:
+            return core.CPUPlace()
+        else:
+            if core.is_compiled_with_cuda():
+                return core.CUDAPlace(0)
+            return core.CPUPlace()
+
+    def check_output_equal(self, actual, expect, rtol=1.e-5, atol=1.e-8):
+        error_msg = 'Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}'
+        self.assertTrue(
+            np.allclose(
+                actual, expect, rtol=rtol, atol=atol),
+            error_msg.format(paddle.get_device(), expect, actual,
+                             self.__class__.__name__))
+
+    def setUp(self):
+        self.sample = {"paradigm": "i->", "data": ["x"]}
+
+    def test_forward(self):
+        operands = [
+            TestEinsum.TEST_SAMPLES[operand] for operand in self.sample["data"]
+        ]
+        expected_result = np.einsum(self.sample["paradigm"], *operands)
+        equation = self.sample["paradigm"]
+
+        with paddle.fluid.dygraph.guard(
+                self._get_place(force_to_use_cpu=False)):
+            pd_operands = [paddle.to_tensor(operand) for operand in operands]
+            result = paddle.einsum(equation, *pd_operands)
+            self.check_output_equal(result.numpy(), expected_result)
+
+        with paddle.fluid.dygraph.guard(self._get_place(force_to_use_cpu=True)):
+            pd_operands = [paddle.to_tensor(operand) for operand in operands]
+            result = paddle.einsum(equation, *pd_operands)
+            self.check_output_equal(result.numpy(), expected_result)
+
+
+class TestEinsumVectorDot(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "i,i->", "data": ["x", "x"]}
+
+
+class TestEinsumVectorMul(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "i,i->i", "data": ["x", "x"]}
+
+
+class TestEinsumVectorOuter(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "i,j->ij", "data": ["x", "y"]}
+
+
+class TestEinsumMatrixTranspose(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ij->ji", "data": ["A"]}
+
+
+class TestEinsumMatrixRowSum(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ij->j", "data": ["A"]}
+
+
+class TestEinsumMatrixColSum(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ij->i", "data": ["A"]}
+
+
+class TestEinsumMatrixEleMul(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ij,ij->ij", "data": ["A", "A"]}
+
+
+class TestEinsumDegenerateMatrixVecMul(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ij,j", "data": ["a", "b"]}
+
+
+class TestEinsumMatrixVecMul(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ij,j->i", "data": ["A", "x"]}
+
+
+class TestEinsumMatrixMul(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ij,kj->ik", "data": ["A", "B"]}
+
+
+class TestEinsumMatrixOuter(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ij,kl->ijkl", "data": ["A", "C"]}
+
+
+class
 TestEinsumTensorBMM(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "bij,bjk->bik", "data": ["D", "E"]}
+
+
+class TestEinsumTensorContract1(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ijk,jk->i", "data": ["D", "A"]}
+
+
+class TestEinsumTensorContract2(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ijk,lk->ijl", "data": ["D", "B"]}
+
+
+class TestEinsumTensorContract3(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "abcd,dfg->abcfg", "data": ["F", "D"]}
+
+
+class TestEinsumTensorContract4(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ijk,jk->ik", "data": ["D", "A"]}
+
+
+class TestEinsumTensorContract5(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ijk,jk->ij", "data": ["D", "A"]}
+
+
+class TestEinsumTensorContract6(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ik, ijk->j", "data": ["A", "G"]}
+
+
+class TestEinsumTensorContract7(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ijk, ik->jk", "data": ["G", "A"]}
+
+
+class TestEinsumEllipsis1(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "i...->...", "data": ["G"]}
+
+
+class TestEinsumEllipsis2(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ij,...i->j...", "data": ["A", "H"]}
+
+
+class TestEinsumEllipsis3(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "k...,jk", "data": ["F", "I"]}
+
+
+class TestEinsumTestEinsumBilinear(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "bn,anm,bm->ba", "data": ["B", "E", "I"]}
+
+
+class TestEinsumTestEinsumOthers1(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ijkl, lmn->kmn", "data": ["F", "H"]}
+
+
+class TestEinsumTestEinsumOthers2(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "ijkl, lmn->ijn", "data": ["F", "H"]}
+
+
+class TestEinsumBatch1(TestEinsum):
+    def setUp(self):
+        self.sample = {"paradigm": "blq,bhlk->bhlqk", "data": ["J", "K"]}
+
+
+class TestNumpyTests(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def _get_place(self, force_to_use_cpu=False):
+        if force_to_use_cpu:
+            return core.CPUPlace()
+        else:
+            if core.is_compiled_with_cuda():
+                return core.CUDAPlace(0)
+            return core.CPUPlace()
+
+    def check_output_equal(self, actual, expect, rtol=1.e-5, atol=1.e-8):
+        error_msg = 'Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}'
+        self.assertTrue(
+            np.allclose(
+                actual, expect, rtol=rtol, atol=atol),
+            error_msg.format(paddle.get_device(), expect, actual,
+                             self.__class__.__name__))
+
+    def check_output(self, eqn, *ops):
+        expect = np.einsum(eqn, *ops)
+        with paddle.fluid.dygraph.guard(
+                self._get_place(force_to_use_cpu=False)):
+            pd_operands = [paddle.to_tensor(op) for op in ops]
+            actual = paddle.einsum(eqn, *pd_operands)
+            self.check_output_equal(actual.numpy(), expect)
+
+    def test_sums(self):
+        for n in range(1, 17):
+            a = np.arange(n).astype('float')
+            self.check_output("i->", a)
+
+        for n in range(1, 17):
+            a = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
+            self.check_output("...i->...", a)
+
+        for n in range(1, 17):
+            a = np.arange(2 * n).reshape(2, n).astype('float')
+            self.check_output("i...->...", a)
+
+        for n in range(1, 17):
+            a = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
+            self.check_output("i...->...", a)
+
+        for n in range(1, 17):
+            a = np.arange(3 * n).reshape(3, n).astype('float')
+            b = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
+            self.check_output("..., ...", a, b)
+
+        for n in range(1, 17):
+            a = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
+            b = np.arange(n).astype('float')
+            self.check_output("...i, ...i", a, b)
+
+        for n in range(1, 11):
+            a = np.arange(n * 3 * 2).reshape(n, 3, 2).astype('float')
+            b = np.arange(n).astype('float')
+            self.check_output("i..., i...", a, b)
+
+        for n in range(1, 17):
+            a = (np.arange(3) + 1).astype('float')
+            b = (np.arange(n) + 1).astype('float')
+            self.check_output("i,j", a, b)
+
+        for n in range(1, 17):
+            a = np.arange(4 * n).reshape(4, n).astype('float')
+            b = np.arange(n).astype('float')
+            self.check_output("ij, j", a, b)
+
+        for n in range(1, 17):
+            a = np.arange(4 * n).reshape(4, n).astype('float')
+            b = np.arange(n).astype('float')
+            self.check_output("ji,j", a.T, b.T)
+
+        for n in range(1, 17):
+            a = np.arange(4 * n).reshape(4, n).astype('float')
+            b = np.arange(n * 6).reshape(n, 6).astype('float')
+            self.check_output("ij,jk", a, b)
+
+        a = np.arange(12).reshape(3, 4).astype('float')
+        b = np.arange(20).reshape(4, 5).astype('float')
+        c = np.arange(30).reshape(5, 6).astype('float')
+        self.check_output("ij,jk,kl", a, b, c)
+
+        a = np.arange(60).reshape(3, 4, 5).astype('float')
+        b = np.arange(24).reshape(4, 3, 2).astype('float')
+        self.check_output("ijk, jil -> kl", a, b)
+
+        for n in range(1, 25):
+            a = np.arange(n).astype('float')
+            self.check_output("...,...", a, a)
+            self.check_output("i,i", a, a)
+
+        # TODO(@xiongkun): explicit broadcast in EinsumOp is not supported;
+        # it is not recommended to use einsum like this.
+        #p = np.ones((10, 2)).astype('float')
+        #q = np.ones((1, 2)).astype('float')
+        #self.check_output('ij,ij->j', p, q)
+
+        # TODO(@xiongkun): explicit label broadcast in EinsumOp is not
+        # supported; it is not recommended to use einsum like this.
+        #x = np.array([2., 3.]).astype('float')
+        #y = np.array([4.]).astype('float')
+        #self.check_output("i, i", x, y)
+
+        # TODO(@xiongkun): explicit label broadcast in EinsumOp is not
+        # supported; it is not recommended to use einsum like this.
+        #p = np.ones((1, 5)) / 2
+        #q = np.ones((5, 5)) / 2
+        #self.check_output("...ij,...jk->...ik", p, p)
+        #self.check_output("...ij,...jk->...ik", p, q)
+
+        x = np.eye(2).astype('float')
+        y = np.ones(2).astype('float')
+        self.check_output("ji,i->", x, y)
+        self.check_output("i,ij->", y, x)
+        self.check_output("ij,i->", x, y)
+
+    def test_large_nops(self):
+        pass
+        # TODO(@xiongkun): explicit broadcast in EinsumOp is not supported;
+        # it is not recommended to use einsum like this.
+        #a = np.arange(4 * 3 * 1 * 4).reshape(4, 3, 1, 4).astype('float')
+        #self.check_output('a...b,b...c,c...d', a, a, a)
+        #self.check_output('a...b,b...c,c...a', a, a, a)
+        #self.check_output('a...b,b...c,c...a', a, a, a)
+        #self.check_output('...ab,...ba,...ab,...ab', a, a, a, a)
+
+    def test_static_graph(self):
+        paddle.enable_static()
+        fluid = paddle.fluid
+        if fluid.core.is_compiled_with_cuda():
+            self.place = fluid.CUDAPlace(0)
+        else:
+            self.place = fluid.CPUPlace()
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            a = paddle.static.data(
+                name='a', shape=[3, None, None, None], dtype='float')
+            b = paddle.static.data(
+                name='b', shape=[2, None, None, None], dtype='float')
+            c = paddle.static.data(
+                name='c', shape=[None, None, 2, None], dtype='float')
+            d = paddle.static.data(
+                name='d', shape=[None, None, 5], dtype='float')
+            e = paddle.static.data(
+                name='e', shape=[None, 2, None], dtype='float')
+
+            outs = []
+            outs.append(paddle.einsum("ibnd,jbnd->bnij", a, b))
+            outs.append(paddle.einsum('...ik, ...j', c, d))
+            outs.append(paddle.einsum('...kj, ...ik', d, e))
+            outs.append(paddle.einsum('ijk..., ikj', c, e))
+            outs.append(paddle.einsum('ijk..., ikj->...ij', c, e))
+        exe = fluid.Executor(self.place)
+        exe.run(startup)
+        a = np.arange(72).reshape(3, 2, 3, 4).astype('float')
+        b = np.arange(48).reshape(2, 2, 3, 4).astype('float')
+        c = np.arange(48).reshape(2, 3, 2, 4).astype('float')
+        d = np.arange(30).reshape(2, 3, 5).astype('float')
+        e = np.arange(12).reshape(2, 2, 3).astype('float')
+        feeds = {'a': a, 'b': b, 'c': c, 'd': d, 'e': e}
+        actual = exe.run(main, feed=feeds, fetch_list=[outs])
+        expect = []
+        expect.append(np.einsum("ibnd,jbnd->bnij", a, b))
+        expect.append(np.einsum('...ik, ...j', c, d))
+        expect.append(np.einsum('...kj, ...ik', d, e))
+        expect.append(np.einsum('ijk..., ikj', c, e))
+        expect.append(np.einsum('ijk..., ikj->...ij', c, e))
+        for a, e in zip(actual, expect):
+            self.check_output_equal(a, e)
+
+
+if __name__ == "__main__":
+    unittest.main()
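The tests validate the new op against `np.einsum`, including the implied-output form with no `->`. The `rhs_inference` helper added to `einsum.py` below reproduces NumPy's rule: labels that occur exactly once survive, sorted alphabetically, after any ellipsis. A condensed standalone restatement for illustration (not the Paddle code itself):

```python
import collections

# Implied-output rule: keep labels appearing exactly once, alphabetically,
# with a leading ellipsis if the inputs carried one.
def infer_rhs(lhs):
    cnt = collections.Counter(lhs)
    free = sorted(k for k in cnt if cnt[k] == 1 and k not in '.,')
    return ("..." if "..." in lhs else "") + "".join(free)

assert infer_rhs("ij,jk") == "ik"    # j is contracted; i and k survive sorted
assert infer_rhs("i...,i") == "..."  # i repeats, so only the ellipsis remains
```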
diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py
index dd11477532d24daec0991cff1f4a99dd0d41641a..713a611f9f39a11146d04f1e8385991221c0e51d 100644
--- a/python/paddle/tensor/einsum.py
+++ b/python/paddle/tensor/einsum.py
@@ -24,6 +24,10 @@ from ..fluid.framework import _in_legacy_dygraph
 from paddle import _C_ops
 from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
 from ..fluid.layer_helper import LayerHelper
+from ..fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
+import collections
+import string
+import opt_einsum
 
 from paddle.common_ops_import import dygraph_only
 
@@ -664,7 +668,138 @@ def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
     return plan
 
 
+def preprocess(equation, *operands):
+    """
+    Check the equation, raise errors, and generate the default
+    right-hand-side labels when none are given.
+    """
+    equation = equation.replace(" ", "")
+    nop = len(operands)
+    assert nop > 0, "Required at least one operand in Einsum API, but received %s " % nop
+
+    # Part the equation to left hand side and right hand side
+    lhs, *rhs = equation.lower().split('->')
+    assert len(rhs) < 2, "Invalid equation: multiple `->` were found."
+
+    labels = parse_labels(lhs, operands)
+    # Note: we distinguish between 'ij->' and 'ij' by setting rhs to '' and None
+    rhs = rhs[0] if rhs else None
+    if rhs is None:
+        rhs = rhs_inference(lhs)
+
+    assert len(lhs.split(',')) == len(operands), (
+        f"Invalid equation: the number of operands is {len(operands)}, "
+        f"but found {len(lhs.split(','))} segments in the label equation.")
+
+    assert not ('...' in lhs and '...' not in rhs
+                ), f'Invalid equation: missing ellipsis in output labels.'
+
+    assert not (len(list(filter(has_duplicated_labels, lhs.split(',')))) > 0
+                ), f'Duplicate labels are not supported.'
+
+    assert not has_duplicated_labels(
+        rhs), f'Invalid equation: duplicate output labels are found.'
+
+    return lhs, rhs, labels
+
+
+def parse_fake_shape(equation, operands, labels):
+    """
+    This shape is only used for operand planning and may differ from the
+    original shape. For example:
+        '...' is replaced by 1
+        -1 is replaced by 1
+    Results
+    -------
+    list of shape
+    """
+    shaped = collections.namedtuple('shaped', ['shape'])
+
+    def fake_shape(label, op):
+        assert len(op.shape) == len(
+            label
+        ), "length of shape and length of label must be the same, but received %d != %d" % (
+            len(op.shape), len(label))
+        fakes = [s for i, (l, s) in enumerate(zip(label, op.shape)) if l != '.']
+        fakes = list(map(abs, fakes))  # make -1 -> 1
+        if '.' in label:
+            fakes.insert(label.index('.'), 1)
+        return shaped(fakes)
+
+    out = list(map(fake_shape, labels, operands))
+    return out
+
+
+def rhs_inference(lhs):
+    def is_free(key):
+        return cnt.get(key) == 1 and key not in ['.', ',']
+
+    cnt = collections.Counter(lhs)
+    rhs = "..." if '...' in lhs else ""
+    rhs = rhs + "".join(filter(is_free, sorted(cnt.elements())))
+    return rhs
+
+
+def gen_equation_for_opteinsum(lhs, rhs):
+    """
+    1. gen rhs if rhs is None
+    2. replace '...' with a single unused label
+    """
+
+    def get_used_label(counter):
+        used = set(counter.elements())
+        for c in string.ascii_lowercase:
+            if c not in used: return c
+        raise ValueError(
+            "All labels from `a` to `z` are used; there is no unused label left for einsum optimization."
+        )
+
+    cnt = collections.Counter(lhs)
+    broadcast_label = get_used_label(cnt)
+    if rhs is None:
+        rhs = rhs_inference(lhs)
+    lhs = lhs.replace("...", broadcast_label)
+    rhs = rhs.replace("...", broadcast_label)
+    return lhs + "->" + rhs, broadcast_label
+
+
+def einsum_v2(equation, *operands):
+    """
+    einsum v2 implementation.
+    1. Implemented by the C++ EinsumOp.
+    2. V2 creates an EinsumOp to do the calculation, so only a little
+       verification work is left in Python.
+    3. V2 uses opt_einsum.contract_path to optimize the multi-operand einsum.
+    """
+    n_op = len(operands)
+    lhs, rhs, labels = preprocess(equation, *operands)
+
+    if n_op <= 2:
+        return gen_einsum_op(lhs + '->' + rhs, *operands)
+
+    shapes = parse_fake_shape(lhs, operands, labels)
+    opt_equation, broadcast_label = gen_equation_for_opteinsum(lhs, rhs)
+    _, cons = opt_einsum.contract_path(opt_equation, *shapes, einsum_call=True)
+    var_list = list(operands)
+    for path in cons:
+        (a, b), _, eq, *__ = path
+        assert a > b, "Assume the first popped index is the larger one; opt_einsum guarantees this order."
+        var_s = [var_list.pop(a), var_list.pop(b)]
+        eq = eq.replace(broadcast_label, "...")
+        var_list.append(gen_einsum_op(eq, *var_s))
+    assert len(
+        var_list
+    ) == 1, "There must be exactly one element left in the list, but received %d." % len(
+        var_list)
+    return var_list[0]
+
+
+def gen_einsum_op(equation, *operands):
+    """
+    EinsumOp Python Interface:
+    """
+    assert len(operands) <= 2, "Only support two operands in EinsumOp."
+    if in_dygraph_mode():
+        return _C_ops.final_state_einsum(operands, equation)
+    if _in_legacy_dygraph():
+        # dygraph
+        return _C_ops.einsum(operands, 'equation', equation)
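For more than two operands, `einsum_v2` asks `opt_einsum.contract_path` for a pairwise contraction order over the fake shapes, then replays each step with the binary `EinsumOp`. The sketch below replays the same contraction list with `np.einsum` standing in for `gen_einsum_op`; it assumes `opt_einsum` 3.3.0 as pinned in requirements.txt:

```python
import numpy as np
import opt_einsum

x, y, z = np.random.rand(3, 4), np.random.rand(4, 5), np.random.rand(5, 6)

# einsum_call=True returns the pairwise contraction list instead of a report.
_, steps = opt_einsum.contract_path('ij,jk,kl->il', x, y, z, einsum_call=True)

operands = [x, y, z]
for (a, b), _, eq, *_rest in steps:
    # opt_einsum emits the larger index first, so popping a then b is safe.
    pair = [operands.pop(a), operands.pop(b)]
    operands.append(np.einsum(eq, *pair))  # the PR calls gen_einsum_op here

assert len(operands) == 1
np.testing.assert_allclose(operands[0], np.einsum('ij,jk,kl->il', x, y, z))
```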
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 35976b6f8715ca8d483393ac60941801883efa08..f078aae9bb6b163c616a03c789ba743e69713ebb 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -561,6 +561,16 @@
     func : eigh
   backward : eigh_grad
 
+- api : einsum
+  args : (Tensor[] x, str equation)
+  output : Tensor
+  infer_meta :
+    func : EinsumInferMeta
+    param : [x, equation]
+  kernel :
+    func : einsum
+  backward : einsum_grad
+
 - api : elementwise_pow
   args : (Tensor x, Tensor y)
   output : Tensor(out)
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 1e58c19728adcc720615ae362c0f2732b8e75fdb..e044447f87c226d911e06b771c809f98a9828f0e 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -475,6 +475,16 @@
   data_transform:
     skip_transform : out_w, out_w_grad
 
+- backward_api : einsum_grad
+  forward : einsum (Tensor[] x, str equation) -> Tensor(out)
+  args : (Tensor[] x, Tensor out_grad, str equation)
+  output : Tensor[](x_grad){x.size()}
+  infer_meta :
+    func : UnchangedMultiInferMeta
+    param : [x]
+  kernel :
+    func : einsum_grad
+
 - backward_api : elementwise_pow_grad
   forward : elementwise_pow(Tensor x, Tensor y) -> Tensor(out)
   args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1)
diff --git a/python/requirements.txt b/python/requirements.txt
index 5f2b788a81a0ad5b8150ee065602e7b643591ea2..e7fc6cd651cb0f7f8a8907fade9177fd8bb17f4b 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -6,3 +6,4 @@ six
 decorator
 astor
 paddle_bfloat==0.1.2
+opt_einsum==3.3.0
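Taken together, the yaml entries register the forward `einsum` kernel and an `einsum_grad` backward that produces one gradient per input (`Tensor[](x_grad){x.size()}`), and `opt_einsum` becomes a hard dependency. An illustrative end-to-end usage sketch, mirroring the flag setup in the tests above (not an official example from this PR):

```python
import os
os.environ['FLAGS_new_einsum'] = "1"  # opt into the v2 path, as in the tests

import paddle

x = paddle.rand([3, 4])
x.stop_gradient = False
y = paddle.rand([4, 5])
y.stop_gradient = False

out = paddle.einsum('ij,jk->ik', x, y)  # forward: binary EinsumOp
out.sum().backward()                    # backward: einsum_grad kernel
print(x.grad.shape, y.grad.shape)       # [3, 4] [4, 5] -- one grad per input
```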