未验证 提交 5237cc05 编写于 作者: T Tongxin Bai 提交者: GitHub

[Einsum] correct output dimension errors. (#37222)

* [Einsum] correct output dimension errors due to single element tensors.

* [Einsum] format polish.
上级 d943459b
......@@ -91,6 +91,8 @@ class TestEinsum(unittest.TestCase):
np.random.seed(12345)
cls.TEST_SAMPLES = {
"a": np.random.rand(1, 1),
"b": np.random.rand(1),
"x": np.random.rand(5),
"y": np.random.rand(7),
"A": np.random.rand(4, 5),
......@@ -179,6 +181,11 @@ class TestEinsumMatrixEleMul(TestEinsum):
self.sample = {"paradigm": "ij,ij->ij", "data": ["A", "A"]}
class TestEinsumDegenerateMatrixVecMul(TestEinsum):
    # Matrix-vector product where both operands are degenerate
    # (single-element) tensors: sample "a" is shape (1, 1) and
    # sample "b" is shape (1,).  No output subscript is given, so
    # the result dimensions are inferred.
    def setUp(self):
        self.sample = dict(paradigm="ij,j", data=["a", "b"])
class TestEinsumMatrixVecMul(TestEinsum):
    # Ordinary matrix-vector product: sample "A" is shape (4, 5)
    # and sample "x" is shape (5,), with an explicit output "i".
    def setUp(self):
        self.sample = dict(paradigm="ij,j->i", data=["A", "x"])
......
......@@ -394,11 +394,12 @@ def plan_reduce(plan, op, reduce_dims, keepdim):
def plan_scalar_prod(plan, op1, op2):
    '''
    Append a step computing (sum of operand op1) * operand op2.

    Chosen by the planner when op1 has no non-trivial dimensions
    (a one-element tensor): reducing it with paddle_sum yields a
    scalar that broadcast-multiplies into op2.  The result is
    stored back under op2's variable name.

    Parameters
    ----------
    plan:
        the execution Plan; the new step is appended via add_step.
    op1, op2:
        integer operand indices; variable names are f'op{index}'.
    '''
    varnames = [f'op{op1}', f'op{op2}']
    f = lambda var1, var2: paddle_sum(var1) * var2
    # A step is (callable, input variable names, output variable name).
    step = f, varnames, varnames[1]
    plan.add_step(step)
def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K):
def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
'''
plan matmul
'''
......@@ -416,7 +417,7 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K):
K1_dims = [op1_view[ax] for ax in K]
K2_dims = [op2_view[ax] for ax in K]
op1_mask, op2_mask = [g_op_masks[op] for op in (op1, op2)]
op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)]
op1_vshape = [s if m else 1 for s, m in zip(g_shape, op1_mask)]
op2_vshape = [s if m else 1 for s, m in zip(g_shape, op2_mask)]
......@@ -515,13 +516,13 @@ def plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K):
op2_view[ax], dim = dim, dim + 1
def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count,
def plan_summation(plan, g_view, op1, op2, g_supports, g_shape, g_count,
n_bcast):
'''
Plan various kinds of summation
'''
op1_view, op2_view = g_view[op1], g_view[op2]
op1_mask, op2_mask = g_op_masks[op1], g_op_masks[op2]
op1_mask, op2_mask = g_supports[op1], g_supports[op2]
ndim = len(op1_view)
nout = ndim - len(g_count)
......@@ -553,7 +554,7 @@ def plan_summation(plan, g_view, op1, op2, g_op_masks, g_shape, g_count,
# Now it's OK to merge the K dims as the same shape holds
# print(f'I: {I} J1: {J1} J2: {J2} K: {K}')
plan_matmul(plan, g_view, op1, op2, g_op_masks, g_shape, I, J1, J2, K)
plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K)
def rearrange(axes):
......@@ -625,7 +626,7 @@ class Plan:
return res
def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast):
def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
'''
Plans the actual execution steps.
Results
......@@ -646,17 +647,18 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast):
plan_broadcast(plan, operands, g_view)
return plan
# Down count axis >= nout and degenerate dimensions (masked is not set)
for view, mask in zip(g_view, g_op_masks):
# Down count degenerate contraction dimensions.
for view, support in zip(g_view, g_supports):
# To collect the down count number, we use a type casting trick
down_count = [
1 if (dim > -1 and not masked) else 0
for dim, masked in zip(view[nout:], mask[nout:])
int((d + 1) and (not s))
for d, s in zip(view[nout:], support[nout:])
]
for i, d in enumerate(down_count):
g_count[i] -= d
for i, count in enumerate(down_count):
g_count[i] -= count
# Reduce any dimension for which g_mask is set and g_count == 1
for i, view, mask in zip(range(nop), g_view, g_op_masks):
# Reduce any dimension for which g_support is set and g_count == 1
for i, view, mask in zip(range(nop), g_view, g_supports):
to_reduce = []
for dim, masked, count in zip(view[nout:], mask[nout:], g_count):
to_reduce.append(dim if (masked and count == 1) else -1)
......@@ -695,27 +697,36 @@ def plan_einsum(operands, g_view, g_shape, g_op_masks, g_count, n_bcast):
# (4) Elsewise, either I... or J... not empty, and K... not empty, use a general matmul
# Resolve the summation kind: dot, matmul or *
if not any(g_op_masks[i - 1]):
# op1 is a scalar
if not any(g_supports[i - 1]):
# op1 is a one element tensor.
plan_scalar_prod(plan, i - 1, i)
else:
plan_summation(plan, g_view, i - 1, i, g_op_masks, g_shape, g_count,
plan_summation(plan, g_view, i - 1, i, g_supports, g_shape, g_count,
n_bcast)
# for ax, dim in enumerate(g_view[nop-1][:nout]):
# assert dim == ax
assert all(not masked for masked in g_op_masks[nop - 1][nout:])
assert all(not masked for masked in g_supports[nop - 1][nout:])
view = g_view[-1]
if any(ax != dim for ax, dim in enumerate(view[:nout])):
perm = [dim for dim in view if dim >= 0]
varname = f'op{nop-1}'
step = transpose, [varname], varname, perm
plan.add_step(step)
if sorted(perm) != perm:
varname = f'op{nop-1}'
step = transpose, [varname], varname, perm
plan.add_step(step)
dim = 0
unsqueeze_dims = []
for ax, d in enumerate(view):
if d != -1:
view[ax], dim = dim, dim + 1
for ax, d in enumerate(view[:nout]):
if d == -1:
unsqueeze_dims.append(ax)
if unsqueeze_dims:
varname = f'op{nop-1}'
step = unsqueeze, [varname], varname, unsqueeze_dims
plan.add_step(step)
squeeze_dims = [dim for dim in view[nout:] if dim != -1]
if squeeze_dims:
......@@ -922,18 +933,18 @@ def einsum(equation, *operands):
# should broadcast to
# g_nout:
# Number of output axes
# g_op_masks
# A list of masks that specify each operand's non-trivial dimensions
# g_supports
# Booleans indicating each operand's non-trivial dimensions
# g_count
# Counting how many non-trivial dimensions remain for each ax
g_labels, g_view, g_nout, g_count = build_global_view(nop_labels, rhs,
n_bcast_dims)
g_shape, g_op_masks = build_global_shape(g_view, g_labels,
g_shape, g_supports = build_global_shape(g_view, g_labels,
[op.shape for op in operands])
# Now we're ready to build up an execution plan
args = operands, g_view, g_shape, g_op_masks, g_count, n_bcast_dims
args = operands, g_view, g_shape, g_supports, g_count, n_bcast_dims
plan = plan_einsum(*args)
result = plan.execute()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册