未验证 提交 187fcfa3 编写于 作者: T Tongxin Bai 提交者: GitHub

[einsum] refactored and supporting unknown shapes in static mode (#40360)

* formatted.

* Remove dead code.

* Fix error message in the unit test.

* polish formats.

* [Einsum] fix bugs.
上级 69dd43d1
...@@ -26,14 +26,14 @@ class TestErrors(unittest.TestCase): ...@@ -26,14 +26,14 @@ class TestErrors(unittest.TestCase):
def test_diagonalize_errors(self): def test_diagonalize_errors(self):
a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float') a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
a = paddle.to_tensor(a) a = paddle.to_tensor(a)
with self.assertRaisesRegex(AssertionError, ( with self.assertRaisesRegex(AssertionError,
'Diagonal and trace not implemented yet.')): ('Duplicate labels are not supported.')):
paddle.einsum('...ii->...i', a) paddle.einsum('...ii->...i', a)
with self.assertRaisesRegex(AssertionError, ( with self.assertRaisesRegex(AssertionError,
'Diagonal and trace not implemented yet.')): ('Duplicate labels are not supported.')):
paddle.einsum('i...i', a) paddle.einsum('i...i', a)
with self.assertRaisesRegex(AssertionError, ( with self.assertRaisesRegex(AssertionError,
'Diagonal and trace not implemented yet.')): ('Duplicate labels are not supported.')):
paddle.einsum('i...i->i...', a) paddle.einsum('i...i->i...', a)
def test_param_errors(self): def test_param_errors(self):
...@@ -396,6 +396,51 @@ class TestNumpyTests(unittest.TestCase): ...@@ -396,6 +396,51 @@ class TestNumpyTests(unittest.TestCase):
self.check_output('a...b,b...c,c...a', a, a, a) self.check_output('a...b,b...c,c...a', a, a, a)
self.check_output('...ab,...ba,...ab,...ab', a, a, a, a) self.check_output('...ab,...ba,...ab,...ab', a, a, a, a)
def test_static_graph(self):
paddle.enable_static()
fluid = paddle.fluid
if fluid.core.is_compiled_with_cuda():
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CPUPlace()
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
a = paddle.static.data(
name='a', shape=[3, None, None, None], dtype='float')
b = paddle.static.data(
name='b', shape=[2, None, None, None], dtype='float')
c = paddle.static.data(
name='c', shape=[None, None, 2, None], dtype='float')
d = paddle.static.data(
name='d', shape=[None, None, 5], dtype='float')
e = paddle.static.data(
name='e', shape=[None, 2, None], dtype='float')
outs = []
outs.append(paddle.einsum("ibnd,jbnd->bnij", a, b))
outs.append(paddle.einsum('...ik, ...j', c, d))
outs.append(paddle.einsum('...kj, ...ik', d, e))
outs.append(paddle.einsum('ijk..., ikj', c, e))
outs.append(paddle.einsum('ijk..., ikj->...ij', c, e))
exe = fluid.Executor(self.place)
exe.run(startup)
a = np.arange(72).reshape(3, 2, 3, 4).astype('float')
b = np.arange(48).reshape(2, 2, 3, 4).astype('float')
c = np.arange(48).reshape(2, 3, 2, 4).astype('float')
d = np.arange(30).reshape(2, 3, 5).astype('float')
e = np.arange(12).reshape(2, 2, 3).astype('float')
feeds = {'a': a, 'b': b, 'c': c, 'd': d, 'e': e}
actual = exe.run(main, feed=feeds, fetch_list=[outs])
expect = []
expect.append(np.einsum("ibnd,jbnd->bnij", a, b))
expect.append(np.einsum('...ik, ...j', c, d))
expect.append(np.einsum('...kj, ...ik', d, e))
expect.append(np.einsum('ijk..., ikj', c, e))
expect.append(np.einsum('ijk..., ikj->...ij', c, e))
for a, e in zip(actual, expect):
self.check_output_equal(a, e)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -13,9 +13,10 @@ ...@@ -13,9 +13,10 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
import numpy as np
import re import re
from .linalg import matmul, transpose from .linalg import dot, matmul, transpose
from .manipulation import squeeze, unsqueeze, reshape from .manipulation import squeeze, unsqueeze, reshape
from .math import multiply from .math import multiply
from .math import sum as paddle_sum from .math import sum as paddle_sum
...@@ -111,36 +112,6 @@ def validate_rhs(rhs, input_labels, n_bcast_dims): ...@@ -111,36 +112,6 @@ def validate_rhs(rhs, input_labels, n_bcast_dims):
f"Invalid equation: duplicate output labels are found.") f"Invalid equation: duplicate output labels are found.")
# '''
# Tests if the two operands can perform a broadcast operation on the given ranges of dimensions.
# We follow the Numpy broadcasting convention which states that, by lining up the shape arrays
# starting from the right most dimension, all the aligned dimensions either have equal sizes or
# one of them is sized one.
# Parameters
# ----------
# args:
# *args unpacks into operand one's axes range, shape, operand two's axes range, shape
# f:
# if available, is used as a callback for postprocessing the aligned operand dimensions.
# '''
# xran, xshape, yran, yshape = args
#
# xran_inv, yran_inv = xran[::-1], yran[::-1]
#
# for xi, yi in zip(xran_inv, yran_inv):
# xs, ys = xshape[xi], yshape[yi]
# cond = xs == ys or xs == 1 or ys == 1
# if not cond:
# return False
#
# if not f:
# return True
#
# # Apply the callback to each aligned dimension pair
# for xi, yi in zip(xran_inv, yran_inv):
# f(xi, yi)
def build_view(in_labels, out_labels): def build_view(in_labels, out_labels):
''' '''
Build an inverse map of dimension indices. Three conditions must hold for Build an inverse map of dimension indices. Three conditions must hold for
...@@ -291,39 +262,12 @@ def build_global_shape(g_view, g_labels, op_shapes): ...@@ -291,39 +262,12 @@ def build_global_shape(g_view, g_labels, op_shapes):
g_shape = [sizes.pop() if len(sizes) > 0 else 1 for sizes in g_shape] g_shape = [sizes.pop() if len(sizes) > 0 else 1 for sizes in g_shape]
g_masks = [[s > 1 for s in view_shape] for view_shape in view_shapes] g_masks = [[s > 1 or s == -1 for s in view_shape]
for view_shape in view_shapes]
return g_shape, g_masks return g_shape, g_masks
def dim_strides(shape):
'''
Returns the dimension strides for a tensor shape
'''
strides = []
stride = 1
for size in shape[::-1]:
strides.append(stride)
stride = stride * size
return strides
def create_view(operand, *view_def):
'''
Create and materialize a view.
Parameters
----------
operand:
the base tensor operand
view_def:
include two lists which define the view's dimension sizes and strides
'''
assert False, f'Diagonal and trace not implemented yet.'
view_shape, view_strides = view_def
return operand.create_view(view_shape, view_strides)
def has_duplicated_labels(labels): def has_duplicated_labels(labels):
''' '''
Returns True if there is any duplicate label. Returns True if there is any duplicate label.
...@@ -337,46 +281,17 @@ def diagonalize(labels, operand): ...@@ -337,46 +281,17 @@ def diagonalize(labels, operand):
Merges dimensions with duplicate labels. Merges dimensions with duplicate labels.
For those dimensions with duplicate labels, merge them into one dimension For those dimensions with duplicate labels, merge them into one dimension
which represents the diagonal elements. That requires the duplicate labeled which represents the diagonal elements. This requires the dimensions with
dimensions equal sized. The order of dimensions is kept unchanged up to duplicate labels are equal sized.
the left-most appearance of each label.
Examples Examples
-------- --------
'ijj...i' would be merged into 'ij...' 'ijj...i' would be merged into 'ij...'
''' '''
if not has_duplicated_labels(labels): assert not has_duplicated_labels(labels), (
return labels, operand f'Duplicate labels are not supported.')
strides = dim_strides(operand.shape)
shape = operand.shape
new_labels = []
new_shape = []
new_strides = []
for ax, l in enumerate(labels):
if l == '.' or l not in new_labels:
# not duplicate
new_labels.append(l)
new_strides.append(strides[ax])
new_shape.append(shape[ax])
else:
# duplicate label
diag_ax = new_labels.index(l)
new_strides[diag_ax] += strides[ax]
# Call framework API to build a new tensor return labels, operand
new_op = create_view(operand, new_shape, new_strides)
return new_labels, new_op
def prod(iter, default=1):
if len(iter):
res = 1
for s in iter:
res *= s
return res
return default
def plan_reduce(plan, op, reduce_dims, keepdim): def plan_reduce(plan, op, reduce_dims, keepdim):
...@@ -408,102 +323,108 @@ def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K): ...@@ -408,102 +323,108 @@ def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
op1_view, op2_view = [g_view[op] for op in (op1, op2)] op1_view, op2_view = [g_view[op] for op in (op1, op2)]
# Note, I may index into -1 I1 = [idx for idx in I if op1_view[idx] >= 0]
I1_dims = [op1_view[ax] for ax in I if op1_view[ax] >= 0] I2 = [idx for idx in I if op2_view[idx] >= 0]
I2_dims = [op2_view[ax] for ax in I if op2_view[ax] >= 0] op1_view = np.array(op1_view)
J1_dims = [op1_view[ax] for ax in J1] op1_dims = op1_view[I1 + J1 + K]
J2_dims = [op2_view[ax] for ax in J2]
K1_dims = [op1_view[ax] for ax in K]
K2_dims = [op2_view[ax] for ax in K]
op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)] op2_view = np.array(op2_view)
op1_vshape = [s if m else 1 for s, m in zip(g_shape, op1_mask)] op2_dims = op2_view[I2 + J2 + K]
op2_vshape = [s if m else 1 for s, m in zip(g_shape, op2_mask)]
I1_shape, J1_shape, K1_shape = [[op1_vshape[ax] for ax in axes]
for axes in (I, J1, K)]
I2_shape, J2_shape, K2_shape = [[op2_vshape[ax] for ax in axes]
for axes in (I, J2, K)]
K1_size, J1_size, J2_size = prod(K1_shape), prod(J1_shape), prod(J2_shape) op1_mask, op2_mask = [g_supports[op] for op in (op1, op2)]
op1_vshape = np.array([s if m else 1 for s, m in zip(g_shape, op1_mask)])
op2_vshape = np.array([s if m else 1 for s, m in zip(g_shape, op2_mask)])
vshape = np.maximum(op1_vshape, op2_vshape)
perm1 = I1_dims + J1_dims + K1_dims i1, i2, j1, j2, k = map(len, (I1, I2, J1, J2, K))
perm2 = I2_dims + J2_dims + K2_dims
if any(i != dim for i, dim in enumerate(perm1)): if any(op1_dims != np.arange(len(op1_dims))):
# print(f'perm1: {perm1}') # print(f'perm1: {perm1}')
step = transpose, [var1], var1, perm1 step = transpose, [var1], var1, list(op1_dims)
plan.add_step(step) plan.add_step(step)
if any(i != dim for i, dim in enumerate(perm2)): if any(op2_dims != np.arange(len(op2_dims))):
# print(f'perm2: {perm2}') # print(f'perm2: {perm2}')
step = transpose, [var2], var2, perm2 step = transpose, [var2], var2, list(op2_dims)
plan.add_step(step) plan.add_step(step)
# In case of no K... dimensions, do a broadcast # Check if conditions hold for turnning the operation into a matmul
if not K: if j1 + j2 > 0 and k > 0 and -1 not in np.concatenate(
(op1_vshape, op2_vshape)):
op1_shape = list(op1_vshape[I]) + [np.prod(op1_vshape[J1])
] + [np.prod(op1_vshape[K])]
op2_shape = list(op2_vshape[I]) + [np.prod(op2_vshape[J2])
] + [np.prod(op2_vshape[K])]
# Merge J dims and K dims by reshaping
step = reshape, [var1], var1, op1_shape
plan.add_step(step)
step = reshape, [var2], var2, op2_shape
plan.add_step(step)
# Matmul
step = matmul, [var1, var2], var2, False, True
plan.add_step(step)
# Reshape back
shape = list(vshape[I + J1 + J2])
step = reshape, [var2], var2, shape
plan.add_step(step)
elif j1 == j2 == k == 1:
# Can still do matmul even unknown shapes are present
step = matmul, [var1, var2], var2, False, True
plan.add_step(step)
# In the rest cases we opt for ops other than matmul
else:
# unsqueeze operands include J1...J2... dimensions # unsqueeze operands include J1...J2... dimensions
if J2: if j2:
fill_start = len(I2_dims) + len(J1) fill = list(range(i1 + j1, i1 + j1 + j2))
fill_end = fill_start + len(J2)
fill = list(range(fill_start, fill_end))
step = unsqueeze, [var1], var1, fill step = unsqueeze, [var1], var1, fill
plan.add_step(step) plan.add_step(step)
if J1: if j1:
fill_start = len(I2_dims) fill = list(range(i2, i2 + j1))
fill_end = fill_start + len(J1)
fill = list(range(fill_start, fill_end))
step = unsqueeze, [var2], var2, fill step = unsqueeze, [var2], var2, fill
plan.add_step(step) plan.add_step(step)
# In case of no dimensions to contract, do an elementwise multiply
if k == 0:
# make broadcast # make broadcast
step = multiply, [var1, var2], var2 step = multiply, [var1, var2], var2
plan.add_step(step) plan.add_step(step)
# K... are there, let's reason about I... and J... # Contract and no join, turn into a dot
# In case I... and J... are empty, do the vector-vector version of matmul elif j1 + j2 == 0 and k == 1:
elif not I and not J1 and not J2: step = unsqueeze, [var1], var1, [-2]
# merge K dimensions plan.add_step(step)
if len(K) > 1: step = unsqueeze, [var2], var2, [-1]
for var in var1, var2:
step = reshape, [var], var, [K1_size]
plan.add_step(step) plan.add_step(step)
# Build vector-vector matmul
step = matmul, [var1, var2], var2 step = matmul, [var1, var2], var2
plan.add_step(step) plan.add_step(step)
# General case, there are K... and some I... and J..., the actual operation will be step = squeeze, [var2], var2, [-1, -2]
# matrix-vector or matrix-matrix multiplies, depending on the operands' shapes.
else:
# Merge J dims and K dims by reshaping
merged_shape1 = I1_shape + [J1_size] + [K1_size]
merged_shape2 = I2_shape + [J2_size] + [K1_size]
step = reshape, [var1], var1, merged_shape1
plan.add_step(step) plan.add_step(step)
step = reshape, [var2], var2, merged_shape2 elif j1 + j2 == 0 and not-1 in np.concatenate(
(op1_vshape[K], op2_vshape[K])):
assert all(op1_vshape[K] == op2_vshape[K])
step = reshape, [var1], var1, list(op1_vshape[
I]) + [1] + [np.prod(op1_vshape[K])]
plan.add_step(step)
step = reshape, [var2], var2, list(op2_vshape[
I]) + [1] + [np.prod(op2_vshape[K])]
plan.add_step(step) plan.add_step(step)
# Matmul
step = matmul, [var1, var2], var2, False, True step = matmul, [var1, var2], var2, False, True
plan.add_step(step) plan.add_step(step)
step = squeeze, [var2], var2, [-1, -2]
# The result shape is in I..., J1, J2. Let's reshape back to known dimensions
# Note, this is static deduction, not by reading the tensor shape at runtime
result_shape = [1] * len(I)
for i, ax in enumerate(I):
result_shape[i] = max(op1_vshape[ax], op2_vshape[ax])
if J1:
result_shape += J1_shape
if J2:
result_shape += J2_shape
# Need a scalar dimension somehow
if result_shape:
step = reshape, [var2], var2, result_shape
plan.add_step(step) plan.add_step(step)
else:
step = multiply, [var1, var2], var2
plan.add_step(step)
reduce_dims = list(range(-k, 0))
plan_reduce(plan, op2, reduce_dims, keepdim=False)
# Wrap up, updating auxiliary data # Wrap up, updating auxiliary data
# Updating g_mask for I and J axes # Updating g_mask for I and J axes
for i, ax in enumerate(I + J1 + J2): for ax in I + J1 + J2:
op2_mask[ax] = (result_shape[i] > 1) op2_mask[ax] = vshape[ax] > 1 or vshape[ax] == -1
for ax in K: for ax in K:
op2_mask[ax] = False op2_mask[ax] = False
...@@ -514,6 +435,8 @@ def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K): ...@@ -514,6 +435,8 @@ def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
for ax in I + J1 + J2: for ax in I + J1 + J2:
op2_view[ax], dim = dim, dim + 1 op2_view[ax], dim = dim, dim + 1
g_view[op2] = list(op2_view)
def plan_summation(plan, g_view, op1, op2, g_supports, g_shape, g_count, def plan_summation(plan, g_view, op1, op2, g_supports, g_shape, g_count,
n_bcast): n_bcast):
...@@ -737,7 +660,6 @@ def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast): ...@@ -737,7 +660,6 @@ def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
return plan return plan
@dygraph_only
def einsum(equation, *operands): def einsum(equation, *operands):
r""" r"""
einsum(equation, *operands) einsum(equation, *operands)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册