Unverified commit 187fcfa3, authored by Tongxin Bai, committed by GitHub

[einsum] refactored and supporting unknown shapes in static mode (#40360)

* formatted.

* Remove dead code.

* Fix error message in the unit test.

* polish formats.

* [Einsum] fix bugs.
Parent 69dd43d1
@@ -26,14 +26,14 @@ class TestErrors(unittest.TestCase):
def test_diagonalize_errors(self):
a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
a = paddle.to_tensor(a)
-        with self.assertRaisesRegex(AssertionError, (
-                'Diagonal and trace not implemented yet.')):
+        with self.assertRaisesRegex(AssertionError,
+                                    ('Duplicate labels are not supported.')):
            paddle.einsum('...ii->...i', a)
-        with self.assertRaisesRegex(AssertionError, (
-                'Diagonal and trace not implemented yet.')):
+        with self.assertRaisesRegex(AssertionError,
+                                    ('Duplicate labels are not supported.')):
            paddle.einsum('i...i', a)
-        with self.assertRaisesRegex(AssertionError, (
-                'Diagonal and trace not implemented yet.')):
+        with self.assertRaisesRegex(AssertionError,
+                                    ('Duplicate labels are not supported.')):
            paddle.einsum('i...i->i...', a)

def test_param_errors(self):
@@ -396,6 +396,51 @@ class TestNumpyTests(unittest.TestCase):
self.check_output('a...b,b...c,c...a', a, a, a)
self.check_output('...ab,...ba,...ab,...ab', a, a, a, a)
+    def test_static_graph(self):
+        paddle.enable_static()
+        fluid = paddle.fluid
+        if fluid.core.is_compiled_with_cuda():
+            self.place = fluid.CUDAPlace(0)
+        else:
+            self.place = fluid.CPUPlace()
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            a = paddle.static.data(
+                name='a', shape=[3, None, None, None], dtype='float')
+            b = paddle.static.data(
+                name='b', shape=[2, None, None, None], dtype='float')
+            c = paddle.static.data(
+                name='c', shape=[None, None, 2, None], dtype='float')
+            d = paddle.static.data(
+                name='d', shape=[None, None, 5], dtype='float')
+            e = paddle.static.data(
+                name='e', shape=[None, 2, None], dtype='float')
+            outs = []
+            outs.append(paddle.einsum("ibnd,jbnd->bnij", a, b))
+            outs.append(paddle.einsum('...ik, ...j', c, d))
+            outs.append(paddle.einsum('...kj, ...ik', d, e))
+            outs.append(paddle.einsum('ijk..., ikj', c, e))
+            outs.append(paddle.einsum('ijk..., ikj->...ij', c, e))
+        exe = fluid.Executor(self.place)
+        exe.run(startup)
+        a = np.arange(72).reshape(3, 2, 3, 4).astype('float')
+        b = np.arange(48).reshape(2, 2, 3, 4).astype('float')
+        c = np.arange(48).reshape(2, 3, 2, 4).astype('float')
+        d = np.arange(30).reshape(2, 3, 5).astype('float')
+        e = np.arange(12).reshape(2, 2, 3).astype('float')
+        feeds = {'a': a, 'b': b, 'c': c, 'd': d, 'e': e}
+        actual = exe.run(main, feed=feeds, fetch_list=[outs])
+        expect = []
+        expect.append(np.einsum("ibnd,jbnd->bnij", a, b))
+        expect.append(np.einsum('...ik, ...j', c, d))
+        expect.append(np.einsum('...kj, ...ik', d, e))
+        expect.append(np.einsum('ijk..., ikj', c, e))
+        expect.append(np.einsum('ijk..., ikj->...ij', c, e))
+        for a, e in zip(actual, expect):
+            self.check_output_equal(a, e)

if __name__ == "__main__":
unittest.main()
@@ -13,9 +13,10 @@
# limitations under the License.
import itertools
+import numpy as np
import re
-from .linalg import matmul, transpose
+from .linalg import dot, matmul, transpose
from .manipulation import squeeze, unsqueeze, reshape
from .math import multiply
from .math import sum as paddle_sum
@@ -111,36 +112,6 @@ def validate_rhs(rhs, input_labels, n_bcast_dims):
f"Invalid equation: duplicate output labels are found.")
-# '''
-# Tests if the two operands can perform a broadcast operation on the given ranges of dimensions.
-# We follow the Numpy broadcasting convention which states that, by lining up the shape arrays
-# starting from the right most dimension, all the aligned dimensions either have equal sizes or
-# one of them is sized one.
-# Parameters
-# ----------
-# args:
-#     *args unpacks into operand one's axes range, shape, operand two's axes range, shape
-# f:
-#     if available, is used as a callback for postprocessing the aligned operand dimensions.
-# '''
-# xran, xshape, yran, yshape = args
-#
-# xran_inv, yran_inv = xran[::-1], yran[::-1]
-#
-# for xi, yi in zip(xran_inv, yran_inv):
-#     xs, ys = xshape[xi], yshape[yi]
-#     cond = xs == ys or xs == 1 or ys == 1
-#     if not cond:
-#         return False
-#
-# if not f:
-#     return True
-#
-# # Apply the callback to each aligned dimension pair
-# for xi, yi in zip(xran_inv, yran_inv):
-#     f(xi, yi)
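The comment block removed above describes the NumPy broadcasting rule. As a minimal standalone sketch of such a check (illustrative only; the `broadcastable` helper below is my own name, not part of the codebase, which deletes this path as dead code):

# Align shapes from the rightmost dimension; every aligned pair of sizes
# must be equal or contain a 1 (the NumPy broadcasting convention).
def broadcastable(xshape, yshape):
    for xs, ys in zip(reversed(xshape), reversed(yshape)):
        if xs != ys and xs != 1 and ys != 1:
            return False
    return True

assert broadcastable((3, 1, 4), (5, 4))   # aligned pairs (4, 4) and (1, 5) both pass
assert not broadcastable((2, 3), (3, 2))  # rightmost pair (3, 2) already fails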
def build_view(in_labels, out_labels):
'''
Build an inverse map of dimension indices. Three conditions must hold for
@@ -291,39 +262,12 @@ def build_global_shape(g_view, g_labels, op_shapes):
g_shape = [sizes.pop() if len(sizes) > 0 else 1 for sizes in g_shape]
-    g_masks = [[s > 1 for s in view_shape] for view_shape in view_shapes]
+    g_masks = [[s > 1 or s == -1 for s in view_shape]
+               for view_shape in view_shapes]
return g_shape, g_masks
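The `s == -1` addition is what makes unknown static-mode shapes work here: in Paddle's static graph mode, a dimension declared as `None` is reported as `-1`, so the mask must treat it as a possibly-broadcasting dimension rather than size one. A tiny illustration (the variable name `x` is hypothetical):

import paddle

paddle.enable_static()
# A dimension declared as None surfaces as -1 in the static tensor's shape.
x = paddle.static.data(name='x', shape=[None, 3], dtype='float32')
print(x.shape)  # (-1, 3)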
-def dim_strides(shape):
-    '''
-    Returns the dimension strides for a tensor shape
-    '''
-    strides = []
-    stride = 1
-    for size in shape[::-1]:
-        strides.append(stride)
-        stride = stride * size
-    return strides
-
-def create_view(operand, *view_def):
-    '''
-    Create and materialize a view.
-
-    Parameters
-    ----------
-    operand:
-        the base tensor operand
-    view_def:
-        include two lists which define the view's dimension sizes and strides
-    '''
-    assert False, f'Diagonal and trace not implemented yet.'
-    view_shape, view_strides = view_def
-    return operand.create_view(view_shape, view_strides)
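For reference, a quick worked example of what the removed `dim_strides` helper computed: element strides of a contiguous tensor, listed with the rightmost (fastest-varying) dimension first.

def dim_strides(shape):
    # Walk the shape right to left, accumulating the running product.
    strides, stride = [], 1
    for size in shape[::-1]:
        strides.append(stride)
        stride *= size
    return strides

assert dim_strides((2, 3, 4)) == [1, 4, 12]  # rightmost dim varies fastest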
def has_duplicated_labels(labels):
'''
Returns True if there is any duplicate label.
@@ -337,46 +281,17 @@ def diagonalize(labels, operand):
Merges dimensions with duplicate labels.
For those dimensions with duplicate labels, merge them into one dimension
-    which represents the diagonal elements. That requires the duplicate labeled
-    dimensions equal sized. The order of dimensions is kept unchanged up to
-    the left-most appearance of each label.
+    which represents the diagonal elements. This requires that the dimensions
+    with duplicate labels be equal sized.
Examples
--------
'ijj...i' would be merged into 'ij...'
'''
-    if not has_duplicated_labels(labels):
-        return labels, operand
-
-    strides = dim_strides(operand.shape)
-    shape = operand.shape
-    new_labels = []
-    new_shape = []
-    new_strides = []
-    for ax, l in enumerate(labels):
-        if l == '.' or l not in new_labels:
-            # not duplicate
-            new_labels.append(l)
-            new_strides.append(strides[ax])
-            new_shape.append(shape[ax])
-        else:
-            # duplicate label
-            diag_ax = new_labels.index(l)
-            new_strides[diag_ax] += strides[ax]
+    assert not has_duplicated_labels(labels), (
+        f'Duplicate labels are not supported.')
-    # Call framework API to build a new tensor
-    new_op = create_view(operand, new_shape, new_strides)
-    return new_labels, new_op
-
-def prod(iter, default=1):
-    if len(iter):
-        res = 1
-        for s in iter:
-            res *= s
-        return res
-    return default
+    return labels, operand
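For context, the diagonal semantics the removed code was aiming for (and which `paddle.einsum` now rejects with the assertion above) match NumPy's duplicate-label behavior:

import numpy as np

a = np.arange(9).reshape(3, 3)
# Dimensions sharing a label collapse into one dimension holding the diagonal.
print(np.einsum('ii->i', a))  # [0 4 8]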
def plan_reduce(plan, op, reduce_dims, keepdim):
@@ -408,102 +323,108 @@ def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
op1_view, op2_view = [g_view[op] for op in (op1, op2)]
# Note, I may index into -1
-    I1_dims = [op1_view[ax] for ax in I if op1_view[ax] >= 0]
-    I2_dims = [op2_view[ax] for ax in I if op2_view[ax] >= 0]
-    J1_dims = [op1_view[ax] for ax in J1]
-    J2_dims = [op2_view[ax] for ax in J2]
-    K1_dims = [op1_view[ax] for ax in K]
-    K2_dims = [op2_view[ax] for ax in K]
+    I1 = [idx for idx in I if op1_view[idx] >= 0]
+    I2 = [idx for idx in I if op2_view[idx] >= 0]
+    op1_view = np.array(op1_view)
+    op1_dims = op1_view[I1 + J1 + K]
+    op2_view = np.array(op2_view)
+    op2_dims = op2_view[I2 + J2 + K]

-    op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)]
-    op1_vshape = [s if m else 1 for s, m in zip(g_shape, op1_mask)]
-    op2_vshape = [s if m else 1 for s, m in zip(g_shape, op2_mask)]
-    I1_shape, J1_shape, K1_shape = [[op1_vshape[ax] for ax in axes]
-                                    for axes in (I, J1, K)]
-    I2_shape, J2_shape, K2_shape = [[op2_vshape[ax] for ax in axes]
-                                    for axes in (I, J2, K)]
-    K1_size, J1_size, J2_size = prod(K1_shape), prod(J1_shape), prod(J2_shape)
+    op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)]
+    op1_vshape = np.array([s if m else 1 for s, m in zip(g_shape, op1_mask)])
+    op2_vshape = np.array([s if m else 1 for s, m in zip(g_shape, op2_mask)])
+    vshape = np.maximum(op1_vshape, op2_vshape)

-    perm1 = I1_dims + J1_dims + K1_dims
-    perm2 = I2_dims + J2_dims + K2_dims
+    i1, i2, j1, j2, k = map(len, (I1, I2, J1, J2, K))

-    if any(i != dim for i, dim in enumerate(perm1)):
+    if any(op1_dims != np.arange(len(op1_dims))):
        # print(f'perm1: {perm1}')
-        step = transpose, [var1], var1, perm1
+        step = transpose, [var1], var1, list(op1_dims)
        plan.add_step(step)

-    if any(i != dim for i, dim in enumerate(perm2)):
+    if any(op2_dims != np.arange(len(op2_dims))):
        # print(f'perm2: {perm2}')
-        step = transpose, [var2], var2, perm2
+        step = transpose, [var2], var2, list(op2_dims)
        plan.add_step(step)

-    # In case of no K... dimensions, do a broadcast
-    if not K:
+    # Check if conditions hold for turning the operation into a matmul
+    if j1 + j2 > 0 and k > 0 and -1 not in np.concatenate(
+            (op1_vshape, op2_vshape)):
+        op1_shape = list(op1_vshape[I]) + [np.prod(op1_vshape[J1])
+                                           ] + [np.prod(op1_vshape[K])]
+        op2_shape = list(op2_vshape[I]) + [np.prod(op2_vshape[J2])
+                                           ] + [np.prod(op2_vshape[K])]
+        # Merge J dims and K dims by reshaping
+        step = reshape, [var1], var1, op1_shape
+        plan.add_step(step)
+        step = reshape, [var2], var2, op2_shape
+        plan.add_step(step)
+        # Matmul
+        step = matmul, [var1, var2], var2, False, True
+        plan.add_step(step)
+        # Reshape back
+        shape = list(vshape[I + J1 + J2])
+        step = reshape, [var2], var2, shape
+        plan.add_step(step)
+    elif j1 == j2 == k == 1:
+        # Can still do matmul even unknown shapes are present
+        step = matmul, [var1, var2], var2, False, True
+        plan.add_step(step)
+    # In the rest cases we opt for ops other than matmul
    else:
        # unsqueeze operands include J1...J2... dimensions
-        if J2:
-            fill_start = len(I2_dims) + len(J1)
-            fill_end = fill_start + len(J2)
-            fill = list(range(fill_start, fill_end))
+        if j2:
+            fill = list(range(i1 + j1, i1 + j1 + j2))
            step = unsqueeze, [var1], var1, fill
            plan.add_step(step)
-        if J1:
-            fill_start = len(I2_dims)
-            fill_end = fill_start + len(J1)
-            fill = list(range(fill_start, fill_end))
+        if j1:
+            fill = list(range(i2, i2 + j1))
            step = unsqueeze, [var2], var2, fill
            plan.add_step(step)
+        # In case of no dimensions to contract, do an elementwise multiply
+        if k == 0:
            # make broadcast
            step = multiply, [var1, var2], var2
            plan.add_step(step)
-    # K... are there, let's reason about I... and J...
-    # In case I... and J... are empty, do the vector-vector version of matmul
-    elif not I and not J1 and not J2:
-        # merge K dimensions
-        if len(K) > 1:
-            for var in var1, var2:
-                step = reshape, [var], var, [K1_size]
-                plan.add_step(step)
-        # Build vector-vector matmul
+        # Contract and no join, turn into a dot
+        elif j1 + j2 == 0 and k == 1:
+            step = unsqueeze, [var1], var1, [-2]
+            plan.add_step(step)
+            step = unsqueeze, [var2], var2, [-1]
+            plan.add_step(step)
            step = matmul, [var1, var2], var2
            plan.add_step(step)
+            step = squeeze, [var2], var2, [-1, -2]
+            plan.add_step(step)
-    # General case, there are K... and some I... and J..., the actual operation will be
-    # matrix-vector or matrix-matrix multiplies, depending on the operands' shapes.
-    else:
-        # Merge J dims and K dims by reshaping
-        merged_shape1 = I1_shape + [J1_size] + [K1_size]
-        merged_shape2 = I2_shape + [J2_size] + [K1_size]
-        step = reshape, [var1], var1, merged_shape1
-        plan.add_step(step)
-        step = reshape, [var2], var2, merged_shape2
-        plan.add_step(step)
+        elif j1 + j2 == 0 and not -1 in np.concatenate(
+                (op1_vshape[K], op2_vshape[K])):
+            assert all(op1_vshape[K] == op2_vshape[K])
+            step = reshape, [var1], var1, list(op1_vshape[
+                I]) + [1] + [np.prod(op1_vshape[K])]
+            plan.add_step(step)
+            step = reshape, [var2], var2, list(op2_vshape[
+                I]) + [1] + [np.prod(op2_vshape[K])]
+            plan.add_step(step)
            # Matmul
            step = matmul, [var1, var2], var2, False, True
            plan.add_step(step)
-        # The result shape is in I..., J1, J2. Let's reshape back to known dimensions
-        # Note, this is static deduction, not by reading the tensor shape at runtime
-        result_shape = [1] * len(I)
-        for i, ax in enumerate(I):
-            result_shape[i] = max(op1_vshape[ax], op2_vshape[ax])
-        if J1:
-            result_shape += J1_shape
-        if J2:
-            result_shape += J2_shape
-        # Need a scalar dimension somehow
-        if result_shape:
-            step = reshape, [var2], var2, result_shape
-            plan.add_step(step)
+            step = squeeze, [var2], var2, [-1, -2]
+            plan.add_step(step)
+        else:
+            step = multiply, [var1, var2], var2
+            plan.add_step(step)
+            reduce_dims = list(range(-k, 0))
+            plan_reduce(plan, op2, reduce_dims, keepdim=False)

    # Wrap up, updating auxiliary data
    # Updating g_mask for I and J axes
-    for i, ax in enumerate(I + J1 + J2):
-        op2_mask[ax] = (result_shape[i] > 1)
+    for ax in I + J1 + J2:
+        op2_mask[ax] = vshape[ax] > 1 or vshape[ax] == -1
    for ax in K:
        op2_mask[ax] = False
@@ -514,6 +435,8 @@ def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
for ax in I + J1 + J2:
op2_view[ax], dim = dim, dim + 1
g_view[op2] = list(op2_view)
def plan_summation(plan, g_view, op1, op2, g_supports, g_shape, g_count,
n_bcast):
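As a rough NumPy illustration of the planning logic above (the shapes and labels are mine, not the library's): once each operand is laid out as I... x J x K, contracting over K reduces to a single `matmul` with the second operand transposed, mirroring the `matmul, [var1, var2], var2, False, True` steps in `plan_matmul`.

import numpy as np

# x laid out as I x J1 x K, y as I x J2 x K
# (I = batch dims, J = free dims, K = contracted dims).
x = np.random.rand(2, 3, 4)
y = np.random.rand(2, 5, 4)

# matmul over K with y transposed yields I x J1 x J2; no reshape is needed
# here because J and K are already single dimensions.
out = np.matmul(x, np.swapaxes(y, -1, -2))
assert np.allclose(out, np.einsum('bik,bjk->bij', x, y))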
@@ -737,7 +660,6 @@ def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
return plan
-@dygraph_only
def einsum(equation, *operands):
r"""
einsum(equation, *operands)
......