From 187fcfa3f8a5d673f6328801fa4a346a320d64d7 Mon Sep 17 00:00:00 2001
From: Tongxin Bai
Date: Tue, 15 Mar 2022 17:18:40 +0800
Subject: [PATCH] [einsum] refactored and supporting unknown shapes in static
 mode (#40360)

* formatted.

* Remove dead code.

* Fix error message in the unit test.

* polish formats.

* [Einsum] fix bugs.
---
 .../fluid/tests/unittests/test_einsum.py |  57 +++-
 python/paddle/tensor/einsum.py           | 260 ++++++------------
 2 files changed, 142 insertions(+), 175 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_einsum.py b/python/paddle/fluid/tests/unittests/test_einsum.py
index 13e763bee63..43b5ce96a39 100644
--- a/python/paddle/fluid/tests/unittests/test_einsum.py
+++ b/python/paddle/fluid/tests/unittests/test_einsum.py
@@ -26,14 +26,14 @@ class TestErrors(unittest.TestCase):
     def test_diagonalize_errors(self):
         a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
         a = paddle.to_tensor(a)
-        with self.assertRaisesRegex(AssertionError, (
-            'Diagonal and trace not implemented yet.')):
+        with self.assertRaisesRegex(AssertionError,
+                                    ('Duplicate labels are not supported.')):
             paddle.einsum('...ii->...i', a)
-        with self.assertRaisesRegex(AssertionError, (
-            'Diagonal and trace not implemented yet.')):
+        with self.assertRaisesRegex(AssertionError,
+                                    ('Duplicate labels are not supported.')):
             paddle.einsum('i...i', a)
-        with self.assertRaisesRegex(AssertionError, (
-            'Diagonal and trace not implemented yet.')):
+        with self.assertRaisesRegex(AssertionError,
+                                    ('Duplicate labels are not supported.')):
             paddle.einsum('i...i->i...', a)
 
     def test_param_errors(self):
@@ -396,6 +396,51 @@ class TestNumpyTests(unittest.TestCase):
         self.check_output('a...b,b...c,c...a', a, a, a)
         self.check_output('...ab,...ba,...ab,...ab', a, a, a, a)
 
+    def test_static_graph(self):
+        paddle.enable_static()
+        fluid = paddle.fluid
+        if fluid.core.is_compiled_with_cuda():
+            self.place = fluid.CUDAPlace(0)
+        else:
+            self.place = fluid.CPUPlace()
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            a = paddle.static.data(
+                name='a', shape=[3, None, None, None], dtype='float')
+            b = paddle.static.data(
+                name='b', shape=[2, None, None, None], dtype='float')
+            c = paddle.static.data(
+                name='c', shape=[None, None, 2, None], dtype='float')
+            d = paddle.static.data(
+                name='d', shape=[None, None, 5], dtype='float')
+            e = paddle.static.data(
+                name='e', shape=[None, 2, None], dtype='float')
+
+            outs = []
+            outs.append(paddle.einsum("ibnd,jbnd->bnij", a, b))
+            outs.append(paddle.einsum('...ik, ...j', c, d))
+            outs.append(paddle.einsum('...kj, ...ik', d, e))
+            outs.append(paddle.einsum('ijk..., ikj', c, e))
+            outs.append(paddle.einsum('ijk..., ikj->...ij', c, e))
+        exe = fluid.Executor(self.place)
+        exe.run(startup)
+        a = np.arange(72).reshape(3, 2, 3, 4).astype('float')
+        b = np.arange(48).reshape(2, 2, 3, 4).astype('float')
+        c = np.arange(48).reshape(2, 3, 2, 4).astype('float')
+        d = np.arange(30).reshape(2, 3, 5).astype('float')
+        e = np.arange(12).reshape(2, 2, 3).astype('float')
+        feeds = {'a': a, 'b': b, 'c': c, 'd': d, 'e': e}
+        actual = exe.run(main, feed=feeds, fetch_list=[outs])
+        expect = []
+        expect.append(np.einsum("ibnd,jbnd->bnij", a, b))
+        expect.append(np.einsum('...ik, ...j', c, d))
+        expect.append(np.einsum('...kj, ...ik', d, e))
+        expect.append(np.einsum('ijk..., ikj', c, e))
+        expect.append(np.einsum('ijk..., ikj->...ij', c, e))
+        for a, e in zip(actual, expect):
+            self.check_output_equal(a, e)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py
index 040480c26fa..06c2a82fd69 100644
--- a/python/paddle/tensor/einsum.py
+++ b/python/paddle/tensor/einsum.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import itertools
+import numpy as np
 import re
 
-from .linalg import matmul, transpose
+from .linalg import dot, matmul, transpose
 from .manipulation import squeeze, unsqueeze, reshape
 from .math import multiply
 from .math import sum as paddle_sum
@@ -111,36 +112,6 @@ def validate_rhs(rhs, input_labels, n_bcast_dims):
         f"Invalid equation: duplicate output labels are found.")
 
 
-# '''
-# Tests if the two operands can perform a broadcast operation on the given ranges of dimensions.
-# We follow the Numpy broadcasting convention which states that, by lining up the shape arrays
-# starting from the right most dimension, all the aligned dimensions either have equal sizes or
-# one of them is sized one.
-# Parameters
-# ----------
-# args:
-#     *args unpacks into operand one's axes range, shape, operand two's axes range, shape
-# f:
-#     if available, is used as a callback for postprocessing the aligned operand dimensions.
-# '''
-# xran, xshape, yran, yshape = args
-#
-# xran_inv, yran_inv = xran[::-1], yran[::-1]
-#
-# for xi, yi in zip(xran_inv, yran_inv):
-#     xs, ys = xshape[xi], yshape[yi]
-#     cond = xs == ys or xs == 1 or ys == 1
-#     if not cond:
-#         return False
-#
-# if not f:
-#     return True
-#
-# # Apply the callback to each aligned dimension pair
-# for xi, yi in zip(xran_inv, yran_inv):
-#     f(xi, yi)
-
-
 def build_view(in_labels, out_labels):
     '''
     Build an inverse map of dimension indices. Three conditions must hold for
@@ -291,39 +262,12 @@ def build_global_shape(g_view, g_labels, op_shapes):
 
     g_shape = [sizes.pop() if len(sizes) > 0 else 1 for sizes in g_shape]
 
-    g_masks = [[s > 1 for s in view_shape] for view_shape in view_shapes]
+    g_masks = [[s > 1 or s == -1 for s in view_shape]
+               for view_shape in view_shapes]
 
     return g_shape, g_masks
 
 
-def dim_strides(shape):
-    '''
-    Returns the dimension strides for a tensor shape
-    '''
-    strides = []
-    stride = 1
-    for size in shape[::-1]:
-        strides.append(stride)
-        stride = stride * size
-    return strides
-
-
-def create_view(operand, *view_def):
-    '''
-    Create and materialize a view.
-
-    Parameters
-    ----------
-    operand:
-        the base tensor operand
-    view_def:
-        include two lists which define the view's dimension sizes and strides
-    '''
-    assert False, f'Diagonal and trace not implemented yet.'
-    view_shape, view_strides = view_def
-    return operand.create_view(view_shape, view_strides)
-
-
 def has_duplicated_labels(labels):
     '''
     Returns True if there is any duplicate label.
@@ -337,46 +281,17 @@ def diagonalize(labels, operand):
     Merges dimensions with duplicate labels.
 
     For those dimensions with duplicate labels, merge them into one dimension
-    which represents the diagonal elements. That requires the duplicate labeled
-    dimensions equal sized. The order of dimensions is kept unchanged up to
-    the left-most appearance of each label.
+    which represents the diagonal elements. This requires the dimensions with
+    duplicate labels are equal sized.
 
     Examples
     --------
     'ijj...i' would be merged into 'ij...'
     '''
-    if not has_duplicated_labels(labels):
-        return labels, operand
-
-    strides = dim_strides(operand.shape)
-    shape = operand.shape
-    new_labels = []
-    new_shape = []
-    new_strides = []
-
-    for ax, l in enumerate(labels):
-        if l == '.' or l not in new_labels:
-            # not duplicate
-            new_labels.append(l)
-            new_strides.append(strides[ax])
-            new_shape.append(shape[ax])
-        else:
-            # duplicate label
-            diag_ax = new_labels.index(l)
-            new_strides[diag_ax] += strides[ax]
+    assert not has_duplicated_labels(labels), (
+        f'Duplicate labels are not supported.')
 
-    # Call framework API to build a new tensor
-    new_op = create_view(operand, new_shape, new_strides)
-    return new_labels, new_op
-
-
-def prod(iter, default=1):
-    if len(iter):
-        res = 1
-        for s in iter:
-            res *= s
-        return res
-    return default
+    return labels, operand
 
 
 def plan_reduce(plan, op, reduce_dims, keepdim):
@@ -408,102 +323,108 @@ def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
     op1_view, op2_view = [g_view[op] for op in (op1, op2)]
 
-    # Note, I may index into -1
-    I1_dims = [op1_view[ax] for ax in I if op1_view[ax] >= 0]
-    I2_dims = [op2_view[ax] for ax in I if op2_view[ax] >= 0]
-    J1_dims = [op1_view[ax] for ax in J1]
-    J2_dims = [op2_view[ax] for ax in J2]
-    K1_dims = [op1_view[ax] for ax in K]
-    K2_dims = [op2_view[ax] for ax in K]
+    I1 = [idx for idx in I if op1_view[idx] >= 0]
+    I2 = [idx for idx in I if op2_view[idx] >= 0]
+    op1_view = np.array(op1_view)
+    op1_dims = op1_view[I1 + J1 + K]
 
-    op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)]
-    op1_vshape = [s if m else 1 for s, m in zip(g_shape, op1_mask)]
-    op2_vshape = [s if m else 1 for s, m in zip(g_shape, op2_mask)]
-
-    I1_shape, J1_shape, K1_shape = [[op1_vshape[ax] for ax in axes]
-                                    for axes in (I, J1, K)]
-    I2_shape, J2_shape, K2_shape = [[op2_vshape[ax] for ax in axes]
-                                    for axes in (I, J2, K)]
+    op2_view = np.array(op2_view)
+    op2_dims = op2_view[I2 + J2 + K]
 
-    K1_size, J1_size, J2_size = prod(K1_shape), prod(J1_shape), prod(J2_shape)
+    op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)]
+    op1_vshape = np.array([s if m else 1 for s, m in zip(g_shape, op1_mask)])
+    op2_vshape = np.array([s if m else 1 for s, m in zip(g_shape, op2_mask)])
+    vshape = np.maximum(op1_vshape, op2_vshape)
 
-    perm1 = I1_dims + J1_dims + K1_dims
-    perm2 = I2_dims + J2_dims + K2_dims
+    i1, i2, j1, j2, k = map(len, (I1, I2, J1, J2, K))
 
-    if any(i != dim for i, dim in enumerate(perm1)):
+    if any(op1_dims != np.arange(len(op1_dims))):
         # print(f'perm1: {perm1}')
-        step = transpose, [var1], var1, perm1
+        step = transpose, [var1], var1, list(op1_dims)
         plan.add_step(step)
 
-    if any(i != dim for i, dim in enumerate(perm2)):
+    if any(op2_dims != np.arange(len(op2_dims))):
         # print(f'perm2: {perm2}')
-        step = transpose, [var2], var2, perm2
+        step = transpose, [var2], var2, list(op2_dims)
         plan.add_step(step)
 
-    # In case of no K... dimensions, do a broadcast
-    if not K:
-        # unsqueeze operands include J1...J2... dimensions
-        if J2:
-            fill_start = len(I2_dims) + len(J1)
-            fill_end = fill_start + len(J2)
-            fill = list(range(fill_start, fill_end))
-            step = unsqueeze, [var1], var1, fill
-            plan.add_step(step)
-        if J1:
-            fill_start = len(I2_dims)
-            fill_end = fill_start + len(J1)
-            fill = list(range(fill_start, fill_end))
-            step = unsqueeze, [var2], var2, fill
-            plan.add_step(step)
-        # make broadcast
-        step = multiply, [var1, var2], var2
-        plan.add_step(step)
-    # K... are there, let's reason about I... and J...
-    # In case I... and J... are empty, do the vector-vector version of matmul
-    elif not I and not J1 and not J2:
-        # merge K dimensions
-        if len(K) > 1:
-            for var in var1, var2:
-                step = reshape, [var], var, [K1_size]
-                plan.add_step(step)
-        # Build vector-vector matmul
-        step = matmul, [var1, var2], var2
-        plan.add_step(step)
-    # General case, there are K... and some I... and J..., the actual operation will be
-    # matrix-vector or matrix-matrix multiplies, depending on the operands' shapes.
-    else:
-        # Merge J dims and K dims by reshaping
-        merged_shape1 = I1_shape + [J1_size] + [K1_size]
-        merged_shape2 = I2_shape + [J2_size] + [K1_size]
+    # Check if conditions hold for turnning the operation into a matmul
+    if j1 + j2 > 0 and k > 0 and -1 not in np.concatenate(
+            (op1_vshape, op2_vshape)):
+        op1_shape = list(op1_vshape[I]) + [np.prod(op1_vshape[J1])
+                                           ] + [np.prod(op1_vshape[K])]
+        op2_shape = list(op2_vshape[I]) + [np.prod(op2_vshape[J2])
+                                           ] + [np.prod(op2_vshape[K])]
 
-        step = reshape, [var1], var1, merged_shape1
+        # Merge J dims and K dims by reshaping
+        step = reshape, [var1], var1, op1_shape
         plan.add_step(step)
-        step = reshape, [var2], var2, merged_shape2
+        step = reshape, [var2], var2, op2_shape
         plan.add_step(step)
 
         # Matmul
         step = matmul, [var1, var2], var2, False, True
         plan.add_step(step)
 
-        # The result shape is in I..., J1, J2. Let's reshape back to known dimensions
-        # Note, this is static deduction, not by reading the tensor shape at runtime
-        result_shape = [1] * len(I)
-        for i, ax in enumerate(I):
-            result_shape[i] = max(op1_vshape[ax], op2_vshape[ax])
-        if J1:
-            result_shape += J1_shape
-        if J2:
-            result_shape += J2_shape
-
-        # Need a scalar dimension somehow
-        if result_shape:
-            step = reshape, [var2], var2, result_shape
+        # Reshape back
+        shape = list(vshape[I + J1 + J2])
+        step = reshape, [var2], var2, shape
         plan.add_step(step)
+    elif j1 == j2 == k == 1:
+        # Can still do matmul even unknown shapes are present
+        step = matmul, [var1, var2], var2, False, True
+        plan.add_step(step)
+
+    # In the rest cases we opt for ops other than matmul
+    else:
+        # unsqueeze operands include J1...J2... dimensions
+        if j2:
+            fill = list(range(i1 + j1, i1 + j1 + j2))
+            step = unsqueeze, [var1], var1, fill
+            plan.add_step(step)
+        if j1:
+            fill = list(range(i2, i2 + j1))
+            step = unsqueeze, [var2], var2, fill
+            plan.add_step(step)
+        # In case of no dimensions to contract, do an elementwise multiply
+        if k == 0:
+            # make broadcast
+            step = multiply, [var1, var2], var2
+            plan.add_step(step)
+        # Contract and no join, turn into a dot
+        elif j1 + j2 == 0 and k == 1:
+            step = unsqueeze, [var1], var1, [-2]
+            plan.add_step(step)
+            step = unsqueeze, [var2], var2, [-1]
+            plan.add_step(step)
+            step = matmul, [var1, var2], var2
+            plan.add_step(step)
+            step = squeeze, [var2], var2, [-1, -2]
+            plan.add_step(step)
+        elif j1 + j2 == 0 and not -1 in np.concatenate(
+                (op1_vshape[K], op2_vshape[K])):
+            assert all(op1_vshape[K] == op2_vshape[K])
+            step = reshape, [var1], var1, list(op1_vshape[
+                I]) + [1] + [np.prod(op1_vshape[K])]
+            plan.add_step(step)
+            step = reshape, [var2], var2, list(op2_vshape[
+                I]) + [1] + [np.prod(op2_vshape[K])]
+            plan.add_step(step)
+            step = matmul, [var1, var2], var2, False, True
+            plan.add_step(step)
+            step = squeeze, [var2], var2, [-1, -2]
+            plan.add_step(step)
+        else:
+            step = multiply, [var1, var2], var2
+            plan.add_step(step)
+            reduce_dims = list(range(-k, 0))
+            plan_reduce(plan, op2, reduce_dims, keepdim=False)
+
     # Wrap up, updating auxiliary data
     # Updating g_mask for I and J axes
-    for i, ax in enumerate(I + J1 + J2):
-        op2_mask[ax] = (result_shape[i] > 1)
+    for ax in I + J1 + J2:
+        op2_mask[ax] = vshape[ax] > 1 or vshape[ax] == -1
 
     for ax in K:
         op2_mask[ax] = False
@@ -514,6 +435,8 @@ def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
     for ax in I + J1 + J2:
         op2_view[ax], dim = dim, dim + 1
 
+    g_view[op2] = list(op2_view)
+
 
 def plan_summation(plan, g_view, op1, op2, g_supports, g_shape, g_count,
                    n_bcast):
@@ -737,7 +660,6 @@ def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
     return plan
 
 
-@dygraph_only
 def einsum(equation, *operands):
     r"""
     einsum(equation, *operands)
-- 
GitLab
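
With this patch applied, the static-graph path can be exercised end to end much as the new test_static_graph does. The snippet below is a condensed sketch of that test; the single equation, the 'float32' dtype and the x/y names are illustrative choices, not part of the patch.

import numpy as np
import paddle

paddle.enable_static()
main = paddle.fluid.Program()
startup = paddle.fluid.Program()
with paddle.fluid.program_guard(main, startup):
    # None dims mark sizes unknown until run time, exercising the new planner paths.
    x = paddle.static.data(name='x', shape=[3, None, None, None], dtype='float32')
    y = paddle.static.data(name='y', shape=[2, None, None, None], dtype='float32')
    out = paddle.einsum("ibnd,jbnd->bnij", x, y)

exe = paddle.fluid.Executor(paddle.fluid.CPUPlace())
exe.run(startup)
x_np = np.arange(72).reshape(3, 2, 3, 4).astype('float32')
y_np = np.arange(48).reshape(2, 2, 3, 4).astype('float32')
res, = exe.run(main, feed={'x': x_np, 'y': y_np}, fetch_list=[out])
np.testing.assert_allclose(res, np.einsum("ibnd,jbnd->bnij", x_np, y_np))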
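The core of plan_matmul is to split each operand's axes into batch dims I shared by both operands, join dims J1/J2 owned by one side, and contracted dims K, then lower the contraction to a transpose, a reshape to (I..., prod(J), prod(K)), and a batched matmul with the second operand transposed. The NumPy sketch below illustrates that shape-level idea on made-up labels and sizes; it is not the planner's code, just the algebra it relies on.

import numpy as np

# einsum('bijpq,bmpq->bijm'): I = (b,), J1 = (i, j), J2 = (m,), K = (p, q).
x = np.random.rand(2, 3, 4, 5, 6)   # b, i, j, p, q
y = np.random.rand(2, 7, 5, 6)      # b, m, p, q

xm = x.reshape(2, 3 * 4, 5 * 6)     # (I..., prod(J1), prod(K))
ym = y.reshape(2, 7, 5 * 6)         # (I..., prod(J2), prod(K))
# Batched matmul with the second operand transposed on its last two axes,
# then reshape the result back to (I..., J1..., J2...).
out = (xm @ ym.swapaxes(-1, -2)).reshape(2, 3, 4, 7)

assert np.allclose(out, np.einsum('bijpq,bmpq->bijm', x, y))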
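When a size needed for those reshapes is unknown (-1) while the program is being built, the patched planner instead unsqueezes the join dims each operand lacks, multiplies with broadcasting, and reduces over the K dims (the final else branch above), or takes the j1 == j2 == k == 1 matmul path. The following NumPy sketch, again with invented shapes, illustrates the broadcast-and-reduce equivalence that fallback relies on.

import numpy as np

# einsum('bik,bjk->bij') computed without a reshape-based matmul:
x = np.random.rand(2, 3, 5)                  # b, i, k
y = np.random.rand(2, 4, 5)                  # b, j, k

prod = x[:, :, None, :] * y[:, None, :, :]   # unsqueeze j on x, i on y, broadcast
out = prod.sum(axis=-1)                      # reduce over the contracted dim k

assert np.allclose(out, np.einsum('bik,bjk->bij', x, y))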