From 5237cc0537ddb8ffa92104bc06920d7e51526bbe Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Wed, 17 Nov 2021 10:55:16 +0800 Subject: [PATCH] [Einsum] correct output dimension errors. (#37222) * [Einsum] correct output dimension errors due to single element tensors. * [Einsum] format polish. --- .../fluid/tests/unittests/test_einsum.py | 7 +++ python/paddle/tensor/einsum.py | 61 +++++++++++-------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_einsum.py b/python/paddle/fluid/tests/unittests/test_einsum.py index 39bf9b926b..13e763bee6 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum.py +++ b/python/paddle/fluid/tests/unittests/test_einsum.py @@ -91,6 +91,8 @@ class TestEinsum(unittest.TestCase): np.random.seed(12345) cls.TEST_SAMPLES = { + "a": np.random.rand(1, 1), + "b": np.random.rand(1), "x": np.random.rand(5), "y": np.random.rand(7), "A": np.random.rand(4, 5), @@ -179,6 +181,11 @@ class TestEinsumMatrixEleMul(TestEinsum): self.sample = {"paradigm": "ij,ij->ij", "data": ["A", "A"]} +class TestEinsumDegenerateMatrixVecMul(TestEinsum): + def setUp(self): + self.sample = {"paradigm": "ij,j", "data": ["a", "b"]} + + class TestEinsumMatrixVecMul(TestEinsum): def setUp(self): self.sample = {"paradigm": "ij,j->i", "data": ["A", "x"]} diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index b6b0a9b1e7..e5d947294d 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -394,11 +394,12 @@ def plan_reduce(plan, op, reduce_dims, keepdim): def plan_scalar_prod(plan, op1, op2): varnames = [f'op{op1}', f'op{op2}'] f = lambda var1, var2: paddle_sum(var1) * var2 + # f = lambda var1, var2: var1 * var2 step = f, varnames, varnames[1] plan.add_step(step) -def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K): +def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K): ''' plan matmul ''' @@ -416,7 +417,7 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K): K1_dims = [op1_view[ax] for ax in K] K2_dims = [op2_view[ax] for ax in K] - op1_mask, op2_mask = [g_op_masks[op] for op in (op1, op2)] + op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)] op1_vshape = [s if m else 1 for s, m in zip(g_shape, op1_mask)] op2_vshape = [s if m else 1 for s, m in zip(g_shape, op2_mask)] @@ -515,13 +516,13 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K): op2_view[ax], dim = dim, dim + 1 -def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count, +def plan_summation(plan, g_view, op1, op2, g_supports, g_shape, g_count, n_bcast): ''' Plan various kinds of summation ''' op1_view, op2_view = g_view[op1], g_view[op2] - op1_mask, op2_mask = g_op_masks[op1], g_op_masks[op2] + op1_mask, op2_mask = g_supports[op1], g_supports[op2] ndim = len(op1_view) nout = ndim - len(g_count) @@ -553,7 +554,7 @@ def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count, # Now it's OK to merge the K dims as the same shape holds # print(f'I: {I} J1: {J1} J2: {J2} K: {K}') - plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K) + plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K) def rearrange(axes): @@ -625,7 +626,7 @@ class Plan: return res -def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast): +def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast): ''' Plans the actual execution steps. Results @@ -646,17 +647,18 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast): plan_broadcast(plan, operands, g_view) return plan - # Down count axis >= nout and degenerate dimensions (masked is not set) - for view, mask in zip(g_view, g_op_masks): + # Down count degenerate contraction dimensions. + for view, support in zip(g_view, g_supports): + # To collect the down count number, we use a type casting trick down_count = [ - 1 if (dim > -1 and not masked) else 0 - for dim, masked in zip(view[nout:], mask[nout:]) + int((d + 1) and (not s)) + for d, s in zip(view[nout:], support[nout:]) ] - for i, d in enumerate(down_count): - g_count[i] -= d + for i, count in enumerate(down_count): + g_count[i] -= count - # Reduce any dimension for which g_mask is set and g_count == 1 - for i, view, mask in zip(range(nop), g_view, g_op_masks): + # Reduce any dimension for which g_support is set and g_count == 1 + for i, view, mask in zip(range(nop), g_view, g_supports): to_reduce = [] for dim, masked, count in zip(view[nout:], mask[nout:], g_count): to_reduce.append(dim if (masked and count == 1) else -1) @@ -695,27 +697,36 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast): # (4) Elsewise, either I... or J... not empty, and K... not empty, use a general matmul # Resolve the summation kind: dot, matmul or * - if not any(g_op_masks[i - 1]): - # op1 is a scalar + if not any(g_supports[i - 1]): + # op1 is a one element tensor. plan_scalar_prod(plan, i - 1, i) else: - plan_summation(plan, g_view, i - 1, i, g_op_masks, g_shape, g_count, + plan_summation(plan, g_view, i - 1, i, g_supports, g_shape, g_count, n_bcast) # for ax, dim in enumerate(g_view[nop-1][:nout]): # assert dim == ax - assert all(not masked for masked in g_op_masks[nop - 1][nout:]) + assert all(not masked for masked in g_supports[nop - 1][nout:]) view = g_view[-1] if any(ax != dim for ax, dim in enumerate(view[:nout])): perm = [dim for dim in view if dim >= 0] - varname = f'op{nop-1}' - step = transpose, [varname], varname, perm - plan.add_step(step) + if sorted(perm) != perm: + varname = f'op{nop-1}' + step = transpose, [varname], varname, perm + plan.add_step(step) dim = 0 + unsqueeze_dims = [] for ax, d in enumerate(view): if d != -1: view[ax], dim = dim, dim + 1 + for ax, d in enumerate(view[:nout]): + if d == -1: + unsqueeze_dims.append(ax) + if unsqueeze_dims: + varname = f'op{nop-1}' + step = unsqueeze, [varname], varname, unsqueeze_dims + plan.add_step(step) squeeze_dims = [dim for dim in view[nout:] if dim != -1] if squeeze_dims: @@ -922,18 +933,18 @@ def einsum(equation, *operands): # should broadcast to # g_nout: # Number of output axes - # g_op_masks - # A list of masks that specify each operand's non-trivial dimensions + # g_supports + # Booleans indicating each operand's non-trivial dimensions # g_count # Counting how many non-trivial dimensions remain for each ax g_labels, g_view, g_nout, g_count = build_global_view(nop_labels, rhs, n_bcast_dims) - g_shape, g_op_masks = build_global_shape(g_view, g_labels, + g_shape, g_supports = build_global_shape(g_view, g_labels, [op.shape for op in operands]) # Now we're ready to build up an execution plan - args = operands, g_view, g_shape, g_op_masks, g_count, n_bcast_dims + args = operands, g_view, g_shape, g_supports, g_count, n_bcast_dims plan = plan_einsum(*args) result = plan.execute() -- GitLab