未验证 提交 8c8667f0 编写于 作者: T Tongxin Bai 提交者: GitHub

New Einsum API (#33821)

* OP dot: refactor CPU kernels and get better loop performance.

* Minor fix on code format.

* Fixed minor errors.

* Add new API: einsum

* Update the Einsum unit test.

One case failed with matmul_v2, where the dtype is int64:

a = np.arange(2 * 3 * 1).reshape(2, 3, 1)
b = np.arange(1)
paddle.einsum("...i, ...i", a, b)

* Test cases in test_einsum test floating point dtypes only.

As of now Paddle only supports float/double dtypes in matmul, which is
one of building blocks of this Einsum implementation. We decide not to
test einsum against other dtypes.

* Polish format.

* More formatting.

* Format...

* Einsum: improve test coverage.

* Einsum: bug fixes and more testcases for testing error messages

* Einsum: fix format..

* Einsum: fixed typo and format.

* Einsum: format again...

* Einsum: applied suggested changes.

* Einsum API: improve API documentation.

* Einsum API: apply suggested changes.

* Einsum API: Add dygraph only note.

* Einsum API: Add dygraph only note.

* Einsum API: fixed unittest.
上级 ff4bdac3
...@@ -106,7 +106,8 @@ extern void* mklml_dso_handle; ...@@ -106,7 +106,8 @@ extern void* mklml_dso_handle;
__macro(vmsErf); \ __macro(vmsErf); \
__macro(vmdErf); \ __macro(vmdErf); \
__macro(MKL_Free_Buffers); \ __macro(MKL_Free_Buffers); \
__macro(MKL_Set_Num_Threads) __macro(MKL_Set_Num_Threads); \
__macro(MKL_Get_Max_Threads);
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
...@@ -238,6 +238,8 @@ from .tensor.search import sort # noqa: F401 ...@@ -238,6 +238,8 @@ from .tensor.search import sort # noqa: F401
from .tensor.to_string import set_printoptions # noqa: F401 from .tensor.to_string import set_printoptions # noqa: F401
from .tensor.einsum import einsum # noqa: F401
from .framework.random import seed # noqa: F401 from .framework.random import seed # noqa: F401
from .framework.random import get_cuda_rng_state # noqa: F401 from .framework.random import get_cuda_rng_state # noqa: F401
from .framework.random import set_cuda_rng_state # noqa: F401 from .framework.random import set_cuda_rng_state # noqa: F401
...@@ -509,4 +511,5 @@ __all__ = [ # noqa ...@@ -509,4 +511,5 @@ __all__ = [ # noqa
'standard_normal', 'standard_normal',
'diagonal', 'diagonal',
'broadcast_tensors', 'broadcast_tensors',
'einsum'
] ]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import contextlib
import unittest
import paddle
from paddle.fluid import core
class TestErrors(unittest.TestCase):
def setUp(self):
pass
def test_diagonalize_errors(self):
a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
a = paddle.to_tensor(a)
with self.assertRaisesRegex(AssertionError, (
'Diagonal and trace not implemented yet.')):
paddle.einsum('...ii->...i', a)
with self.assertRaisesRegex(AssertionError, (
'Diagonal and trace not implemented yet.')):
paddle.einsum('i...i', a)
with self.assertRaisesRegex(AssertionError, (
'Diagonal and trace not implemented yet.')):
paddle.einsum('i...i->i...', a)
def test_param_errors(self):
a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
a = paddle.to_tensor(a)
with self.assertRaisesRegex(AssertionError,
('At least one operand is expected.')):
paddle.einsum('ijk')
with self.assertRaisesRegex(AssertionError, (
'Invalid equation: multiple `->` were found.')):
paddle.einsum('i -> j -> k', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the number of operands is 2, "
"but found 3 segments in the label equation.")):
paddle.einsum('i,j,k', a, a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the number of operands is 2, "
"but found 1 segments in the label equation.")):
paddle.einsum('ij -> k', a, a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the number of operands is 1, "
"but found 2 segments in the label equation.")):
paddle.einsum('i, -> k', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the label string '' misses dimensions.")):
paddle.einsum('->', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the label string 'i' misses dimensions.")):
paddle.einsum('i', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: _ is not a valid label, "
"which should be letters.")):
paddle.einsum('i_', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: `.` is found outside of an ellipsis.")):
paddle.einsum('i..j', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: `.` is found outside of an ellipsis.")):
paddle.einsum('...k...', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: missing ellipsis in output labels.")):
paddle.einsum('i...->i', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: duplicate output labels are found.")):
paddle.einsum('i...->i...i', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid operands: label i "
"corresponds to non-broadcastable dimensions.")):
paddle.einsum('ij...,ji...', a, a)
class TestEinsum(unittest.TestCase):
@classmethod
def setUpClass(cls):
np.random.seed(12345)
cls.TEST_SAMPLES = {
"x": np.random.rand(5),
"y": np.random.rand(7),
"A": np.random.rand(4, 5),
"B": np.random.rand(2, 5),
"C": np.random.rand(3, 7),
"D": np.random.rand(3, 4, 5),
"E": np.random.rand(3, 5, 2),
"F": np.random.rand(2, 4, 5, 3),
"G": np.random.rand(4, 2, 5),
"H": np.random.rand(3, 2, 4),
"I": np.random.rand(2, 2),
"J": np.random.rand(1, 3, 5),
"K": np.random.rand(1, 2, 3, 4),
}
def _get_place(self, force_to_use_cpu=False):
if force_to_use_cpu:
return core.CPUPlace()
else:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
return core.CPUPlace()
def check_output_equal(self, actual, expect, rtol=1.e-5, atol=1.e-8):
error_msg = 'Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}'
self.assertTrue(
np.allclose(
actual, expect, rtol=rtol, atol=atol),
error_msg.format(paddle.get_device(), expect, actual,
self.__class__.__name__))
def setUp(self):
self.sample = {"paradigm": "i->", "data": ["x"]}
def test_forward(self):
operands = [
TestEinsum.TEST_SAMPLES[operand] for operand in self.sample["data"]
]
expected_result = np.einsum(self.sample["paradigm"], *operands)
equation = self.sample["paradigm"]
with paddle.fluid.dygraph.guard(
self._get_place(force_to_use_cpu=False)):
pd_operands = [paddle.to_tensor(operand) for operand in operands]
result = paddle.einsum(equation, *pd_operands)
self.check_output_equal(result.numpy(), expected_result)
with paddle.fluid.dygraph.guard(self._get_place(force_to_use_cpu=True)):
pd_operands = [paddle.to_tensor(operand) for operand in operands]
result = paddle.einsum(equation, *pd_operands)
self.check_output_equal(result.numpy(), expected_result)
class TestEinsumVectorDot(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i,i->", "data": ["x", "x"]}
class TestEinsumVectorMul(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i,i->i", "data": ["x", "x"]}
class TestEinsumVectorOuter(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i,j->ij", "data": ["x", "y"]}
class TestEinsumMatrixTranspose(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij->ji", "data": ["A"]}
class TestEinsumMatrixRowSum(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij->j", "data": ["A"]}
class TestEinsumMatrixColSum(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij->i", "data": ["A"]}
class TestEinsumMatrixEleMul(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,ij->ij", "data": ["A", "A"]}
class TestEinsumMatrixVecMul(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,j->i", "data": ["A", "x"]}
class TestEinsumMatrixMul(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,kj->ik", "data": ["A", "B"]}
class TestEinsumMatrixOuter(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,kl->ijkl", "data": ["A", "C"]}
class TestEinsumTensorBMM(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "bij,bjk->bik", "data": ["D", "E"]}
class TestEinsumTensorContract1(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk,jk->i", "data": ["D", "A"]}
class TestEinsumTensorContract2(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk,lk->ijl", "data": ["D", "B"]}
class TestEinsumTensorContract3(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "abcd,dfg->abcfg", "data": ["F", "D"]}
class TestEinsumTensorContract4(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk,jk->ik", "data": ["D", "A"]}
class TestEinsumTensorContract5(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk,jk->ij", "data": ["D", "A"]}
class TestEinsumTensorContract6(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ik, ijk->j", "data": ["A", "G"]}
class TestEinsumTensorContract7(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk, ik->jk", "data": ["G", "A"]}
class TestEinsumEllipsis1(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i...->...", "data": ["G"]}
class TestEinsumEllipsis2(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,...i->j...", "data": ["A", "H"]}
class TestEinsumEllipsis3(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "k...,jk", "data": ["F", "I"]}
class TestEinsumTestEinsumBilinear(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "bn,anm,bm->ba", "data": ["B", "E", "I"]}
class TestEinsumTestEinsumOthers1(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijkl, lmn->kmn", "data": ["F", "H"]}
class TestEinsumTestEinsumOthers2(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijkl, lmn->ijn", "data": ["F", "H"]}
class TestEinsumBatch1(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "blq,bhlk->bhlqk", "data": ["J", "K"]}
class TestNumpyTests(unittest.TestCase):
def setUp(self):
pass
def _get_place(self, force_to_use_cpu=False):
if force_to_use_cpu:
return core.CPUPlace()
else:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
return core.CPUPlace()
def check_output_equal(self, actual, expect, rtol=1.e-5, atol=1.e-8):
error_msg = 'Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}'
self.assertTrue(
np.allclose(
actual, expect, rtol=rtol, atol=atol),
error_msg.format(paddle.get_device(), expect, actual,
self.__class__.__name__))
def check_output(self, eqn, *ops):
expect = np.einsum(eqn, *ops)
with paddle.fluid.dygraph.guard(
self._get_place(force_to_use_cpu=False)):
pd_operands = [paddle.to_tensor(op) for op in ops]
actual = paddle.einsum(eqn, *pd_operands)
self.check_output_equal(actual.numpy(), expect)
def test_sums(self):
for n in range(1, 17):
a = np.arange(n).astype('float')
self.check_output("i->", a)
for n in range(1, 17):
a = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
self.check_output("...i->...", a)
for n in range(1, 17):
a = np.arange(2 * n).reshape(2, n).astype('float')
self.check_output("i...->...", a)
for n in range(1, 17):
a = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
self.check_output("i...->...", a)
for n in range(1, 17):
a = np.arange(3 * n).reshape(3, n).astype('float')
b = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
self.check_output("..., ...", a, b)
for n in range(1, 17):
a = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
b = np.arange(n).astype('float')
self.check_output("...i, ...i", a, b)
for n in range(1, 11):
a = np.arange(n * 3 * 2).reshape(n, 3, 2).astype('float')
b = np.arange(n).astype('float')
self.check_output("i..., i...", a, b)
for n in range(1, 17):
a = (np.arange(3) + 1).astype('float')
b = (np.arange(n) + 1).astype('float')
self.check_output("i,j", a, b)
for n in range(1, 17):
a = np.arange(4 * n).reshape(4, n).astype('float')
b = np.arange(n).astype('float')
self.check_output("ij, j", a, b)
for n in range(1, 17):
a = np.arange(4 * n).reshape(4, n).astype('float')
b = np.arange(n).astype('float')
self.check_output("ji,j", a.T, b.T)
for n in range(1, 17):
a = np.arange(4 * n).reshape(4, n).astype('float')
b = np.arange(n * 6).reshape(n, 6).astype('float')
self.check_output("ij,jk", a, b)
a = np.arange(12).reshape(3, 4).astype('float')
b = np.arange(20).reshape(4, 5).astype('float')
c = np.arange(30).reshape(5, 6).astype('float')
self.check_output("ij,jk,kl", a, b, c)
a = np.arange(60).reshape(3, 4, 5).astype('float')
b = np.arange(24).reshape(4, 3, 2).astype('float')
self.check_output("ijk, jil -> kl", a, b)
for n in range(1, 25):
a = np.arange(n).astype('float')
self.check_output("...,...", a, a)
self.check_output("i,i", a, a)
p = np.ones((10, 2)).astype('float')
q = np.ones((1, 2)).astype('float')
self.check_output('ij,ij->j', p, q)
x = np.array([2., 3.]).astype('float')
y = np.array([4.]).astype('float')
self.check_output("i, i", x, y)
p = np.ones((1, 5)) / 2
q = np.ones((5, 5)) / 2
self.check_output("...ij,...jk->...ik", p, p)
self.check_output("...ij,...jk->...ik", p, q)
x = np.eye(2).astype('float')
y = np.ones(2).astype('float')
self.check_output("ji,i->", x, y)
self.check_output("i,ij->", y, x)
self.check_output("ij,i->", x, y)
def test_large_nops(self):
a = np.arange(4 * 3 * 1 * 4).reshape(4, 3, 1, 4).astype('float')
self.check_output('a...b,b...c,c...d', a, a, a)
self.check_output('a...b,b...c,c...a', a, a, a)
self.check_output('a...b,b...c,c...a', a, a, a)
self.check_output('...ab,...ba,...ab,...ab', a, a, a, a)
if __name__ == "__main__":
unittest.main()
...@@ -204,6 +204,8 @@ from .array import array_read # noqa: F401 ...@@ -204,6 +204,8 @@ from .array import array_read # noqa: F401
from .array import array_write # noqa: F401 from .array import array_write # noqa: F401
from .array import create_array # noqa: F401 from .array import create_array # noqa: F401
from .einsum import einsum # noqa: F401
#this list used in math_op_patch.py for _binary_creator_ #this list used in math_op_patch.py for _binary_creator_
tensor_method_func = [ #noqa tensor_method_func = [ #noqa
'matmul', 'matmul',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import re
from ..fluid.layers import reshape, transpose
from .linalg import matmul
from .manipulation import squeeze, unsqueeze
from .math import multiply
from .math import sum as paddle_sum
from paddle.common_ops_import import dygraph_only
__all__ = []
def parse_op_labels(labelstr, operand):
'''
Parse labels for an input operand.
Parameters
----------
labelstr:
the input label string
operand:
the input operand
Returns
-------
the input operand's full label string in which all anonymous dimensions are
labeled in dots.
'''
# Sanity checks
for c in labelstr.replace('.', ''):
assert c.isalpha(), (
f"Invalid equation: {c} is not a valid label, which should be letters."
)
assert labelstr.replace('...', '', 1).find('.') == -1, (
f"Invalid equation: `.` is found outside of an ellipsis.")
# Check shape. Note, in Paddle a tensor rank is always nonzero
ndims = len(operand.shape)
assert ndims > 0
full_labelstr = labelstr.replace('...', '.' * (ndims - len(labelstr) + 3))
assert len(full_labelstr) == ndims, (
f"Invalid equation: the label string '{labelstr}' misses dimensions.")
return full_labelstr
def parse_labels(labelstr, operands):
'''
Parse label strings for all input operands.
Parameters
----------
labelstr:
The equation's label string
operands:
The input operands
Returns
-------
list of full label strings for all input operands
'''
nop_labels = labelstr.split(',')
assert len(nop_labels) == len(operands), (
f"Invalid equation: the number of operands is {len(operands)}, "
f"but found {len(nop_labels)} segments in the label equation.")
return list(map(parse_op_labels, nop_labels, operands))
def validate_rhs(rhs, input_labels, n_bcast_dims):
'''
Check whether the equation's right hand side is valid
'''
# Sanity check.
if n_bcast_dims > 0:
assert '...' in rhs, (
f"Invalid equation: missing ellipsis in output labels.")
rhs = rhs.replace('...', '')
rhs_set = set(rhs)
# Hidden assumption: availble labels don't include '.'
assert '.' not in input_labels
# Verify that output labels all come from the set of input labels
non_input_labels = rhs_set.difference(input_labels)
assert not non_input_labels, (
f"Invalid equation: "
f"output label {sorted(non_input_labels)} not used by any input.")
# Verify that output labels are not duplicate
assert len(rhs) == len(rhs_set), (
f"Invalid equation: duplicate output labels are found.")
# '''
# Tests if the two operands can perform a broadcast operation on the given ranges of dimensions.
# We follow the Numpy broadcasting convention which states that, by lining up the shape arrays
# starting from the right most dimension, all the aligned dimensions either have equal sizes or
# one of them is sized one.
# Parameters
# ----------
# args:
# *args unpacks into operand one's axes range, shape, operand two's axes range, shape
# f:
# if available, is used as a callback for postprocessing the aligned operand dimensions.
# '''
# xran, xshape, yran, yshape = args
#
# xran_inv, yran_inv = xran[::-1], yran[::-1]
#
# for xi, yi in zip(xran_inv, yran_inv):
# xs, ys = xshape[xi], yshape[yi]
# cond = xs == ys or xs == 1 or ys == 1
# if not cond:
# return False
#
# if not f:
# return True
#
# # Apply the callback to each aligned dimension pair
# for xi, yi in zip(xran_inv, yran_inv):
# f(xi, yi)
def build_view(in_labels, out_labels):
'''
Build an inverse map of dimension indices. Three conditions must hold for
the result to be meaningful.
First, no duplicate letter labels in each label string.
Second, the number of dots in dimout_labels >= that in in_labels.
Third, dots are contiguous in each label string.
Parameters
----------
in_labels:
The dimension labels to map to
out_labels:
The dimension labels to map from
Returns
-------
The inverse map from out_labels to in_labels. The length of the inverse map equals that of
out_labels. -1 is filled if there's no matching intput dimension for a specific label.
Examples
--------
in_labels = 'ij..', out_labels = '..ji'
inv_map = [2, 3, 1, 0]
in_labels = 'ij..', out_labels = '..kji'
inv_map = [2, 3, -1, 1, 0]
'''
inv_map = [-1] * len(out_labels)
# First build the broadcast dimension mapping
# Find the broadcast index range in out_labels
r = re.search(r'\.+', out_labels)
if r:
start, end = r.start(), r.end()
s = re.search(r'\.+', in_labels)
# fill the broadcast dimension indices from right to left.
if s:
for ax, dim in zip(
range(start, end)[::-1], range(s.start(), s.end())[::-1]):
inv_map[ax] = dim
# Now work on non-broadcast dimensions
if r:
it = itertools.chain(range(start), range(end, len(out_labels)))
else:
it = iter(range(len(out_labels)))
for i in it:
inv_map[i] = in_labels.find(out_labels[i])
return inv_map
def build_global_view(nop_labels, rhs, n_bcast_dims):
'''
Build the global view, which is a layout of all dimension labels
plus an index table that maps from the layout to the dimensions
in each operand. In the global view, the dimensions are arranged
such that output ones are put on the left and contraction ones
are put on the right.
Parameters
----------
nop_labels:
The input full label strings of all input operands
rhs:
The equation right hand side
n_bcast_dims:
The maxium number of broadcast dimensions
Returns
-------
A tuple of g_labels, g_view, g_nout, g_count
g_labels:
the layout of all labels in a string
g_view:
the index table
g_nout:
the number of output dimensions
g_count:
the counter array for dimension contractions
'''
# Put all labels in alphabetical order
concat = sorted(''.join(nop_labels).replace('.', ''))
labels, count = [], []
for a, b in zip(['.'] + concat, concat):
if a != b:
labels.append(b)
count.append(1)
else:
count[-1] += 1
if rhs != None:
validate_rhs(rhs, labels, n_bcast_dims)
g_labels_out = rhs.replace('...', '.' * n_bcast_dims)
else:
g_labels_out = '.' * n_bcast_dims + ''.join(
l for l, c in zip(labels, count) if c == 1)
for i in range(len(count))[::-1]:
if labels[i] in g_labels_out:
labels.pop(i)
count.pop(i)
g_labels_sum = ''.join(labels)
g_labels = g_labels_out + g_labels_sum
g_view = list(map(lambda i: build_view(i, g_labels), nop_labels))
g_nout = len(g_labels_out)
g_count = count
return g_labels, g_view, g_nout, g_count
def build_global_shape(g_view, g_labels, op_shapes):
'''
The global shape is the shape of all dimensions rearranged and broadcasting
to the global view. It's a reference data structure for einsum planning.
Parameters
----------
g_view:
the global view
op_shapes:
the shapes of the all operands
Returns
-------
g_shape:
the global shape vector
g_masks:
list of shape masks for each operand. A dimension's shape mask is a boolean
indicating whether its size > 1, in other words, it's not squeezable
'''
view_shapes = []
g_masks = []
for view, op_shape in zip(g_view, op_shapes):
view_shapes.append([op_shape[dim] if dim > -1 else 1 for dim in view])
g_shape = [set(sizes_per_ax) - {1} for sizes_per_ax in zip(*view_shapes)]
non_bcastable = [ax for ax, sizes in enumerate(g_shape) if len(sizes) > 1]
assert not non_bcastable, (
f"Invalid operands: label {g_labels[non_bcastable[0]]} "
f"corresponds to non-broadcastable dimensions.")
g_shape = [sizes.pop() if len(sizes) > 0 else 1 for sizes in g_shape]
g_masks = [[s > 1 for s in view_shape] for view_shape in view_shapes]
return g_shape, g_masks
def dim_strides(shape):
'''
Returns the dimension strides for a tensor shape
'''
strides = []
stride = 1
for size in shape[::-1]:
strides.append(stride)
stride = stride * size
return strides
def create_view(operand, *view_def):
'''
Create and materialize a view.
Parameters
----------
operand:
the base tensor operand
view_def:
include two lists which define the view's dimension sizes and strides
'''
assert False, f'Diagonal and trace not implemented yet.'
view_shape, view_strides = view_def
return operand.create_view(view_shape, view_strides)
def has_duplicated_labels(labels):
'''
Returns True if there is any duplicate label.
'''
labels = labels.replace('.', '')
return len(labels) > len(set(labels))
def diagonalize(labels, operand):
'''
Merges dimensions with duplicate labels.
For those dimensions with duplicate labels, merge them into one dimension
which represents the diagonal elements. That requires the duplicate labeled
dimensions equal sized. The order of dimensions is kept unchanged up to
the left-most appearance of each label.
Examples
--------
'ijj...i' would be merged into 'ij...'
'''
if not has_duplicated_labels(labels):
return labels, operand
strides = dim_strides(operand.shape)
shape = operand.shape
new_labels = []
new_shape = []
new_strides = []
for ax, l in enumerate(labels):
if l == '.' or l not in new_labels:
# not duplicate
new_labels.append(l)
new_strides.append(strides[ax])
new_shape.append(shape[ax])
else:
# duplicate label
diag_ax = new_labels.index(l)
new_strides[diag_ax] += strides[ax]
# Call framework API to build a new tensor
new_op = create_view(operand, new_shape, new_strides)
return new_labels, new_op
def prod(iter, default=1):
if len(iter):
res = 1
for s in iter:
res *= s
return res
return default
def plan_reduce(plan, op, reduce_dims, keepdim):
'''
Add reduce to the plan
'''
varname = f'op{op}'
f = lambda var, dims: paddle_sum(var, dims, keepdim=keepdim)
step = f, [varname], varname, reduce_dims
plan.add_step(step)
def plan_scalar_prod(plan, op1, op2):
varnames = [f'op{op1}', f'op{op2}']
f = lambda var1, var2: paddle_sum(var1) * var2
step = f, varnames, varnames[1]
plan.add_step(step)
def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K):
'''
plan matmul
'''
# Transpose and re-shape op1 and op2 in I, J1, K and I, J2, K
# Then apply matmul(x, y, transpose_x=False, tranpose_y=True)
var1, var2 = f'op{op1}', f'op{op2}'
op1_view, op2_view = [g_view[op] for op in (op1, op2)]
# Note, I may index into -1
I1_dims = [op1_view[ax] for ax in I if op1_view[ax] >= 0]
I2_dims = [op2_view[ax] for ax in I if op2_view[ax] >= 0]
J1_dims = [op1_view[ax] for ax in J1]
J2_dims = [op2_view[ax] for ax in J2]
K1_dims = [op1_view[ax] for ax in K]
K2_dims = [op2_view[ax] for ax in K]
op1_mask, op2_mask = [g_op_masks[op] for op in (op1, op2)]
op1_vshape = [s if m else 1 for s, m in zip(g_shape, op1_mask)]
op2_vshape = [s if m else 1 for s, m in zip(g_shape, op2_mask)]
I1_shape, J1_shape, K1_shape = [[op1_vshape[ax] for ax in axes]
for axes in (I, J1, K)]
I2_shape, J2_shape, K2_shape = [[op2_vshape[ax] for ax in axes]
for axes in (I, J2, K)]
K1_size, J1_size, J2_size = prod(K1_shape), prod(J1_shape), prod(J2_shape)
perm1 = I1_dims + J1_dims + K1_dims
perm2 = I2_dims + J2_dims + K2_dims
if any(i != dim for i, dim in enumerate(perm1)):
# print(f'perm1: {perm1}')
step = transpose, [var1], var1, perm1
plan.add_step(step)
if any(i != dim for i, dim in enumerate(perm2)):
# print(f'perm2: {perm2}')
step = transpose, [var2], var2, perm2
plan.add_step(step)
# In case of no K... dimensions, do a broadcast
if not K:
# unsqueeze operands include J1...J2... dimensions
if J2:
fill_start = len(I2_dims) + len(J1)
fill_end = fill_start + len(J2)
fill = list(range(fill_start, fill_end))
step = unsqueeze, [var1], var1, fill
plan.add_step(step)
if J1:
fill_start = len(I2_dims)
fill_end = fill_start + len(J1)
fill = list(range(fill_start, fill_end))
step = unsqueeze, [var2], var2, fill
plan.add_step(step)
# make broadcast
step = multiply, [var1, var2], var2
plan.add_step(step)
# K... are there, let's reason about I... and J...
# In case I... and J... are empty, do the vector-vector version of matmul
elif not I and not J1 and not J2:
# merge K dimensions
if len(K) > 1:
for var in var1, var2:
step = reshape, [var], var, [K1_size]
plan.add_step(step)
# Build vector-vector matmul
step = matmul, [var1, var2], var2
plan.add_step(step)
# General case, there are K... and some I... and J..., the actual operation will be
# matrix-vector or matrix-matrix multiplies, depending on the operands' shapes.
else:
# Merge J dims and K dims by reshaping
merged_shape1 = I1_shape + [J1_size] + [K1_size]
merged_shape2 = I2_shape + [J2_size] + [K1_size]
step = reshape, [var1], var1, merged_shape1
plan.add_step(step)
step = reshape, [var2], var2, merged_shape2
plan.add_step(step)
# Matmul
step = matmul, [var1, var2], var2, False, True
plan.add_step(step)
# The result shape is in I..., J1, J2. Let's reshape back to known dimensions
# Note, this is static deduction, not by reading the tensor shape at runtime
result_shape = [1] * len(I)
for i, ax in enumerate(I):
result_shape[i] = max(op1_vshape[ax], op2_vshape[ax])
if J1:
result_shape += J1_shape
if J2:
result_shape += J2_shape
# Need a scalar dimension somehow
if result_shape:
step = reshape, [var2], var2, result_shape
plan.add_step(step)
# Wrap up, updating auxiliary data
# Updating g_mask for I and J axes
for i, ax in enumerate(I + J1 + J2):
op2_mask[ax] = (result_shape[i] > 1)
for ax in K:
op2_mask[ax] = False
for ax in range(len(op2_view)):
op2_view[ax] = -1
dim = 0
for ax in I + J1 + J2:
op2_view[ax], dim = dim, dim + 1
def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count,
n_bcast):
'''
Plan various kinds of summation
'''
op1_view, op2_view = g_view[op1], g_view[op2]
op1_mask, op2_mask = g_op_masks[op1], g_op_masks[op2]
ndim = len(op1_view)
nout = ndim - len(g_count)
count = [0] * nout + g_count
I, K, J1, J2 = list(range(n_bcast)), [], [], []
for ax, dim1, dim2 in zip(
range(n_bcast, ndim), op1_view[n_bcast:], op2_view[n_bcast:]):
if (dim1 != -1) != (dim2 != -1):
if dim1 != -1:
J1.append(ax)
else:
J2.append(ax)
elif dim1 != -1:
fold = int(op1_mask[ax]) + int(op2_mask[ax])
if ax >= nout and fold == count[ax]:
# Ready to fold the dimensions
K.append(ax)
count[ax] -= fold
else:
I.append(ax)
count[ax] -= max(fold - 1, 0)
# Update g_count
g_count[:] = count[nout:]
# Now it's OK to merge the K dims as the same shape holds
# print(f'I: {I} J1: {J1} J2: {J2} K: {K}')
plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K)
def rearrange(axes):
perm, fill = [], []
for ax, dim in enumerate(axes):
if dim < 0:
fill.append(ax)
else:
perm.append(dim)
# Trivial permutation returns []
if all(i == dim for i, dim in enumerate(perm)):
perm = []
return perm, fill
def plan_broadcast(plan, operands, nop_axes):
'''
Plan broadcast across
'''
nop = len(operands)
varnames = [f'op{i}' for i in range(nop)]
for i, op_axes in zip(range(nop), nop_axes):
# Re-arrange the dimesions according to the global layout
perm, fill = rearrange(op_axes)
var = varnames[i]
if perm:
step = transpose, [var], var, perm
plan.add_step(step)
if fill:
step = unsqueeze, [var], var, fill
plan.add_step(step)
def f(*args):
expr = ' * '.join(varnames)
return eval(expr, dict(zip(varnames, args)))
step = f, varnames, None
plan.add_step(step)
class Plan:
def __init__(self):
self.env = {}
self.steps = []
def add_step(self, step):
self.steps.append(step)
def get_var(self, varname):
return self.env[varname] if varname in self.env else None
def set_var(self, varname, var):
self.env[varname] = var
def show(self):
res = None
for f, in_varnames, out_varname, *args in self.steps:
print(repr((out_varname, f, *in_varnames, *args)))
return res
def execute(self):
res = None
for f, in_varnames, out_varname, *args in self.steps:
res = f(*map(self.get_var, in_varnames), *args)
if out_varname:
self.set_var(out_varname, res)
return res
def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast):
'''
Plans the actual execution steps.
Results
-------
the execution plan
'''
nop = len(operands)
ndim = len(g_view[0])
nout = ndim - len(g_count)
# Initialize a plan with an environment
plan = Plan()
op_names = [f'op{i}' for i in range(nop)]
list(map(plan.set_var, op_names, operands))
# In case no dimensions to combine, do broadcast straight across
if not g_count:
plan_broadcast(plan, operands, g_view)
return plan
# Down count axis >= nout and degenerate dimensions (masked is not set)
for view, mask in zip(g_view, g_op_masks):
down_count = [
1 if (dim > -1 and not masked) else 0
for dim, masked in zip(view[nout:], mask[nout:])
]
for i, d in enumerate(down_count):
g_count[i] -= d
# Reduce any dimension for which g_mask is set and g_count == 1
for i, view, mask in zip(range(nop), g_view, g_op_masks):
to_reduce = []
for dim, masked, count in zip(view[nout:], mask[nout:], g_count):
to_reduce.append(dim if (masked and count == 1) else -1)
reduce_dims = list(filter(lambda x: x > -1, to_reduce))
if reduce_dims:
plan_reduce(plan, i, reduce_dims, keepdim=True)
# Unset mask and decrease g_count for the reduced dimensions
for i, d in enumerate(to_reduce):
ax = i + nout
mask[ax] = mask[ax] and (d == -1)
g_count[i] -= 0 if d == -1 else 1
# Plan the summations over the operand sequence
for i in range(nop):
# plan a single step
if i == 0:
continue
# We'd like to arrange the dimensions in the following way:
# [I... J... K...]
# [I... J... K...]
# where
# I... are aligned and not to be combined immediately
# J... are not aligned and not to be combined immediately
# K... are aligned and should be immediately combined
# At this point the non-trivial broadcast dimensinos in K are already reduced
# and removed. That means all K dimensions are aligned and their sizes are not 1.
# We then inspect the layout of I,J,K plus the above observation to make
# specializatoin decisions. The current strategy is set as follows:
# (1) if I... J... K... are all empty, it's multiplying a scalar
# (2) if K... are empty, better use a broadcast
# (3) if I... J... empty and K... not empty, a vector-vector multiply (or a dot)
# (4) Elsewise, either I... or J... not empty, and K... not empty, use a general matmul
# Resolve the summation kind: dot, matmul or *
if not any(g_op_masks[i - 1]):
# op1 is a scalar
plan_scalar_prod(plan, i - 1, i)
else:
plan_summation(plan, g_view, i - 1, i, g_op_masks, g_shape, g_count,
n_bcast)
# for ax, dim in enumerate(g_view[nop-1][:nout]):
# assert dim == ax
assert all(not masked for masked in g_op_masks[nop - 1][nout:])
view = g_view[-1]
if any(ax != dim for ax, dim in enumerate(view[:nout])):
perm = [dim for dim in view if dim >= 0]
varname = f'op{nop-1}'
step = transpose, [varname], varname, perm
plan.add_step(step)
dim = 0
for ax, d in enumerate(view):
if d != -1:
view[ax], dim = dim, dim + 1
squeeze_dims = [dim for dim in view[nout:] if dim != -1]
if squeeze_dims:
# plan_reduce(plan, nop-1, reduce_dims, keepdim=False)
varname = f'op{nop-1}'
step = squeeze, [varname], varname, squeeze_dims
plan.add_step(step)
return plan
@dygraph_only
def einsum(equation, *operands):
r"""
einsum(equation, *operands)
The current version of this API should be used in dygraph only mode.
Einsum offers a tensor operation API which allows using the Einstein summation
convention or Einstain notation. It takes as input one or multiple tensors and
produces as output one tensor.
Einsum is able to perform a variety of tensor operations. Following lists a few:
- for single operand
- trace
- diagonal
- transpose
- sum
- for double operands
- dot
- outer
- broadcasting and elementwise multiply
- matrix multiply
- batched matrix multiply
- for many operads
- broadcasting multiply
- chained matrix multiply
**The summation notation**
- The tensor dimensions are labeled using uncased English letters. E.g., `ijk`
relates to a three dimensional tensor whose dimensions are labeled i, j, and k.
- The equation is `,` separated into terms, each being a distinct input's
dimension label string.
- Ellipsis `...` enables broadcasting by automatically converting the unlabeled
dimensions into broadcasting dimensions.
- Singular labels are called free labels, duplicate are dummy labels. Dummy labeled
dimensions will be reduced and removed in the output.
- Output labels can be explicitly specified on the right hand side of `->` or omitted.
In the latter case, the output labels will be inferred from the input labels.
- Inference of output labels
- Broadcasting label `...`, if present, is put on the leftmost position.
- Free labels are reordered alphabetically and put after `...`.
- On explicit output labels
- If broadcasting is enabled, then `...` must be present.
- The output labels can be an empty, an indication to output as a scalar
the sum over the original output.
- Non-input labels are invalid.
- Duplicate labels are invalid.
- For any dummmy label which is present for the output, it's promoted to
a free label.
- For any free label which is not present for the output, it's lowered to
a dummy label.
- Examples
- '...ij, ...jk',where i and k are free labels, j is dummy. The output label
string is '...ik'
- 'ij -> i', where i is a free label and j is a dummy label.
- '...ij, ...jk -> ...ijk',where i, j and k are all free labels.
- '...ij, ...jk -> ij', an invalid equation since `...` is not present for
the output.
**The summation rule**
The summation procedure can be outlined as follows, although the actual steps taken
may vary significantly due to implementation specific optimization.
- Step 1: preparation for broadcasting, that is, transposing and unsqueezing
the input operands to have each resulting dimension identically labeled across
all the input operands.
- Step 2: broadcasting multiply all the resulting operands from step 1.
- Step 3: reducing dummy labeled dimensions.
- Step 4: transposing the result tensor to match the output labels.
**On trace and diagonal**
The trace and diagonal are planned yet unimplemented features.
Args:
equation (`str`):
The summation terms using the Einstein summation notation.
operands (`list|Tensor`):
The input tensors over which to compute the Einstein summation. The number of
operands should equal the number of input terms in the equation.
Returns:
result (`Tensor`): the result tensor.
Examples:
.. code-block:: python
import paddle
paddle.seed(102)
x = paddle.rand([4])
y = paddle.rand([5])
# sum
print(paddle.einsum('i->', x))
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# 1.95791852)
# dot
print(paddle.einsum('i,i->', x, x))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [1.45936954])
# outer
print(paddle.einsum("i,j->ij", x, y))
# Tensor(shape=[4, 5], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0.00079869, 0.00120950, 0.00136844, 0.00187187, 0.00192194],
# [0.23455200, 0.35519385, 0.40186870, 0.54970956, 0.56441545],
# [0.11773264, 0.17828843, 0.20171674, 0.27592498, 0.28330654],
# [0.32897076, 0.49817693, 0.56364071, 0.77099484, 0.79162055]])
A = paddle.rand([2, 3, 2])
B = paddle.rand([2, 2, 3])
# transpose
print(paddle.einsum('ijk->kji', A))
# Tensor(shape=[2, 3, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[[0.95649719, 0.49684682],
# [0.80071914, 0.46258664],
# [0.49814570, 0.33383518]],
#
# [[0.07637714, 0.29374704],
# [0.51470858, 0.51907635],
# [0.99066722, 0.55802226]]])
# batch matrix multiplication
print(paddle.einsum('ijk, ikl->ijl', A,B))
# Tensor(shape=[2, 3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[[0.32172769, 0.50617385, 0.41394392],
# [0.51736701, 0.49921003, 0.38730967],
# [0.69078457, 0.42282537, 0.30161136]],
#
# [[0.32043904, 0.18164253, 0.27810261],
# [0.50226176, 0.24512935, 0.39881429],
# [0.51476848, 0.23367381, 0.39229113]]])
# Ellipsis transpose
print(paddle.einsum('...jk->...kj', A))
# Tensor(shape=[2, 2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[[0.95649719, 0.80071914, 0.49814570],
# [0.07637714, 0.51470858, 0.99066722]],
#
# [[0.49684682, 0.46258664, 0.33383518],
# [0.29374704, 0.51907635, 0.55802226]]])
# Ellipsis batch matrix multiplication
print(paddle.einsum('...jk, ...kl->...jl', A,B))
# Tensor(shape=[2, 3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[[0.32172769, 0.50617385, 0.41394392],
# [0.51736701, 0.49921003, 0.38730967],
# [0.69078457, 0.42282537, 0.30161136]],
#
# [[0.32043904, 0.18164253, 0.27810261],
# [0.50226176, 0.24512935, 0.39881429],
# [0.51476848, 0.23367381, 0.39229113]]])
"""
nop = len(operands)
assert nop > 0, "At least one operand is expected."
# Part the equation to left hand side and right hand side
lhs, *rhs = equation.lower().replace(' ', '').split('->')
assert len(rhs) < 2, "Invalid equation: multiple `->` were found."
# Note, we distinguish between 'ij->' and 'ij' by setting rhs to '' and None
rhs = rhs[0] if rhs else None
# Parse labels for each operand and count the number of occurrences for each alphabet label
nop_labels = parse_labels(lhs, operands)
# Diagonalize the operands which have duplicate labels
nop_labels, operands = list(zip(*map(diagonalize, nop_labels, operands)))
# To handle broadcasting, we should first know how many dimensions are there
# We need to use that number to generate output labels
# e.g. 1 for ['ij', 'i.', '.k']
n_bcast_dims = max(map(lambda s: s.count('.'), nop_labels))
# Build the data structures for planning. It's helpful to think of all the operands
# broadcasting together from a global view. In this view, dimensions from multiple
# operands are mapped to the same position if they are labeled uniquely. Broadcasting
# dimensions are mapped to adjacent positions with the right bound fixed. Subject to
# each operand, the map is injective but for all operands the map is on-to.
# g_labels:
# The labels of the global view
# g_view:
# Includes a list of maps from each operand's dimensions to the global view's dimensions
# which we refer to as ax or axes in the code to distinguish from operand's dims
# g_shape:
# The shape of the global view. The size of each dimension is what the aligned dimensions
# should broadcast to
# g_nout:
# Number of output axes
# g_op_masks
# A list of masks that specify each operand's non-trivial dimensions
# g_count
# Counting how many non-trivial dimensions remain for each ax
g_labels, g_view, g_nout, g_count = build_global_view(nop_labels, rhs,
n_bcast_dims)
g_shape, g_op_masks = build_global_shape(g_view, g_labels,
[op.shape for op in operands])
# Now we're ready to build up an execution plan
args = operands, g_view, g_shape, g_op_masks, g_count, n_bcast_dims
plan = plan_einsum(*args)
result = plan.execute()
return result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册