Unverified commit 5237cc05, authored by: T Tongxin Bai, committed by: GitHub

[Einsum] correct output dimension errors. (#37222)

* [Einsum] correct output dimension errors due to single element tensors.

* [Einsum] format polish.
Parent d943459b
...@@ -91,6 +91,8 @@ class TestEinsum(unittest.TestCase): ...@@ -91,6 +91,8 @@ class TestEinsum(unittest.TestCase):
np.random.seed(12345) np.random.seed(12345)
cls.TEST_SAMPLES = { cls.TEST_SAMPLES = {
"a": np.random.rand(1, 1),
"b": np.random.rand(1),
"x": np.random.rand(5), "x": np.random.rand(5),
"y": np.random.rand(7), "y": np.random.rand(7),
"A": np.random.rand(4, 5), "A": np.random.rand(4, 5),
...@@ -179,6 +181,11 @@ class TestEinsumMatrixEleMul(TestEinsum): ...@@ -179,6 +181,11 @@ class TestEinsumMatrixEleMul(TestEinsum):
self.sample = {"paradigm": "ij,ij->ij", "data": ["A", "A"]} self.sample = {"paradigm": "ij,ij->ij", "data": ["A", "A"]}
class TestEinsumDegenerateMatrixVecMul(TestEinsum):
    # Matrix-vector multiply where every dimension is degenerate (size 1):
    # operand "a" is np.random.rand(1, 1) and "b" is np.random.rand(1)
    # (see the TEST_SAMPLES additions in this commit). Exercises the
    # single-element-tensor output-dimension fix this commit introduces.
    def setUp(self):
        # NOTE(review): "ij,j" has no explicit "->" output, so the output
        # spec is implied — presumably the base TestEinsum harness runs the
        # paradigm through einsum and compares against numpy; confirm there.
        self.sample = {"paradigm": "ij,j", "data": ["a", "b"]}
class TestEinsumMatrixVecMul(TestEinsum): class TestEinsumMatrixVecMul(TestEinsum):
def setUp(self): def setUp(self):
self.sample = {"paradigm": "ij,j->i", "data": ["A", "x"]} self.sample = {"paradigm": "ij,j->i", "data": ["A", "x"]}
......
...@@ -394,11 +394,12 @@ def plan_reduce(plan, op, reduce_dims, keepdim): ...@@ -394,11 +394,12 @@ def plan_reduce(plan, op, reduce_dims, keepdim):
def plan_scalar_prod(plan, op1, op2): def plan_scalar_prod(plan, op1, op2):
varnames = [f'op{op1}', f'op{op2}'] varnames = [f'op{op1}', f'op{op2}']
f = lambda var1, var2: paddle_sum(var1) * var2 f = lambda var1, var2: paddle_sum(var1) * var2
# f = lambda var1, var2: var1 * var2
step = f, varnames, varnames[1] step = f, varnames, varnames[1]
plan.add_step(step) plan.add_step(step)
def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K): def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
''' '''
plan matmul plan matmul
''' '''
...@@ -416,7 +417,7 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K): ...@@ -416,7 +417,7 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K):
K1_dims = [op1_view[ax] for ax in K] K1_dims = [op1_view[ax] for ax in K]
K2_dims = [op2_view[ax] for ax in K] K2_dims = [op2_view[ax] for ax in K]
op1_mask, op2_mask = [g_op_masks[op] for op in (op1, op2)] op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)]
op1_vshape = [s if m else 1 for s, m in zip(g_shape, op1_mask)] op1_vshape = [s if m else 1 for s, m in zip(g_shape, op1_mask)]
op2_vshape = [s if m else 1 for s, m in zip(g_shape, op2_mask)] op2_vshape = [s if m else 1 for s, m in zip(g_shape, op2_mask)]
...@@ -515,13 +516,13 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K): ...@@ -515,13 +516,13 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K):
op2_view[ax], dim = dim, dim + 1 op2_view[ax], dim = dim, dim + 1
def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count, def plan_summation(plan, g_view, op1, op2, g_supports, g_shape, g_count,
n_bcast): n_bcast):
''' '''
Plan various kinds of summation Plan various kinds of summation
''' '''
op1_view, op2_view = g_view[op1], g_view[op2] op1_view, op2_view = g_view[op1], g_view[op2]
op1_mask, op2_mask = g_op_masks[op1], g_op_masks[op2] op1_mask, op2_mask = g_supports[op1], g_supports[op2]
ndim = len(op1_view) ndim = len(op1_view)
nout = ndim - len(g_count) nout = ndim - len(g_count)
...@@ -553,7 +554,7 @@ def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count, ...@@ -553,7 +554,7 @@ def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count,
# Now it's OK to merge the K dims as the same shape holds # Now it's OK to merge the K dims as the same shape holds
# print(f'I: {I} J1: {J1} J2: {J2} K: {K}') # print(f'I: {I} J1: {J1} J2: {J2} K: {K}')
plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K) plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K)
def rearrange(axes): def rearrange(axes):
...@@ -625,7 +626,7 @@ class Plan: ...@@ -625,7 +626,7 @@ class Plan:
return res return res
def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast): def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
''' '''
Plans the actual execution steps. Plans the actual execution steps.
Results Results
...@@ -646,17 +647,18 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast): ...@@ -646,17 +647,18 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast):
plan_broadcast(plan, operands, g_view) plan_broadcast(plan, operands, g_view)
return plan return plan
# Down count axis >= nout and degenerate dimensions (masked is not set) # Down count degenerate contraction dimensions.
for view, mask in zip(g_view, g_op_masks): for view, support in zip(g_view, g_supports):
# To collect the down count number, we use a type casting trick
down_count = [ down_count = [
1 if (dim > -1 and not masked) else 0 int((d + 1) and (not s))
for dim, masked in zip(view[nout:], mask[nout:]) for d, s in zip(view[nout:], support[nout:])
] ]
for i, d in enumerate(down_count): for i, count in enumerate(down_count):
g_count[i] -= d g_count[i] -= count
# Reduce any dimension for which g_mask is set and g_count == 1 # Reduce any dimension for which g_support is set and g_count == 1
for i, view, mask in zip(range(nop), g_view, g_op_masks): for i, view, mask in zip(range(nop), g_view, g_supports):
to_reduce = [] to_reduce = []
for dim, masked, count in zip(view[nout:], mask[nout:], g_count): for dim, masked, count in zip(view[nout:], mask[nout:], g_count):
to_reduce.append(dim if (masked and count == 1) else -1) to_reduce.append(dim if (masked and count == 1) else -1)
...@@ -695,27 +697,36 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast): ...@@ -695,27 +697,36 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast):
# (4) Elsewise, either I... or J... not empty, and K... not empty, use a general matmul # (4) Elsewise, either I... or J... not empty, and K... not empty, use a general matmul
# Resolve the summation kind: dot, matmul or * # Resolve the summation kind: dot, matmul or *
if not any(g_op_masks[i - 1]): if not any(g_supports[i - 1]):
# op1 is a scalar # op1 is a one element tensor.
plan_scalar_prod(plan, i - 1, i) plan_scalar_prod(plan, i - 1, i)
else: else:
plan_summation(plan, g_view, i - 1, i, g_op_masks, g_shape, g_count, plan_summation(plan, g_view, i - 1, i, g_supports, g_shape, g_count,
n_bcast) n_bcast)
# for ax, dim in enumerate(g_view[nop-1][:nout]): # for ax, dim in enumerate(g_view[nop-1][:nout]):
# assert dim == ax # assert dim == ax
assert all(not masked for masked in g_op_masks[nop - 1][nout:]) assert all(not masked for masked in g_supports[nop - 1][nout:])
view = g_view[-1] view = g_view[-1]
if any(ax != dim for ax, dim in enumerate(view[:nout])): if any(ax != dim for ax, dim in enumerate(view[:nout])):
perm = [dim for dim in view if dim >= 0] perm = [dim for dim in view if dim >= 0]
varname = f'op{nop-1}' if sorted(perm) != perm:
step = transpose, [varname], varname, perm varname = f'op{nop-1}'
plan.add_step(step) step = transpose, [varname], varname, perm
plan.add_step(step)
dim = 0 dim = 0
unsqueeze_dims = []
for ax, d in enumerate(view): for ax, d in enumerate(view):
if d != -1: if d != -1:
view[ax], dim = dim, dim + 1 view[ax], dim = dim, dim + 1
for ax, d in enumerate(view[:nout]):
if d == -1:
unsqueeze_dims.append(ax)
if unsqueeze_dims:
varname = f'op{nop-1}'
step = unsqueeze, [varname], varname, unsqueeze_dims
plan.add_step(step)
squeeze_dims = [dim for dim in view[nout:] if dim != -1] squeeze_dims = [dim for dim in view[nout:] if dim != -1]
if squeeze_dims: if squeeze_dims:
...@@ -922,18 +933,18 @@ def einsum(equation, *operands): ...@@ -922,18 +933,18 @@ def einsum(equation, *operands):
# should broadcast to # should broadcast to
# g_nout: # g_nout:
# Number of output axes # Number of output axes
# g_op_masks # g_supports
# A list of masks that specify each operand's non-trivial dimensions # Booleans indicating each operand's non-trivial dimensions
# g_count # g_count
# Counting how many non-trivial dimensions remain for each ax # Counting how many non-trivial dimensions remain for each ax
g_labels, g_view, g_nout, g_count = build_global_view(nop_labels, rhs, g_labels, g_view, g_nout, g_count = build_global_view(nop_labels, rhs,
n_bcast_dims) n_bcast_dims)
g_shape, g_op_masks = build_global_shape(g_view, g_labels, g_shape, g_supports = build_global_shape(g_view, g_labels,
[op.shape for op in operands]) [op.shape for op in operands])
# Now we're ready to build up an execution plan # Now we're ready to build up an execution plan
args = operands, g_view, g_shape, g_op_masks, g_count, n_bcast_dims args = operands, g_view, g_shape, g_supports, g_count, n_bcast_dims
plan = plan_einsum(*args) plan = plan_einsum(*args)
result = plan.execute() result = plan.execute()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register