未验证 提交 2bfee7d3 编写于 作者: F From00 提交者: GitHub

[cherry-pick] Add new API 'tensordot' (#36273) (#36454)

* Add new API tensordot
cherry-pick #36273
上级 8c0bacd4
......@@ -152,6 +152,7 @@ from .tensor.manipulation import unbind # noqa: F401
from .tensor.manipulation import roll # noqa: F401
from .tensor.manipulation import chunk # noqa: F401
from .tensor.manipulation import tolist # noqa: F401
from .tensor.manipulation import tensordot # noqa: F401
from .tensor.math import abs # noqa: F401
from .tensor.math import acos # noqa: F401
from .tensor.math import asin # noqa: F401
......@@ -470,6 +471,7 @@ __all__ = [ # noqa
'bmm',
'chunk',
'tolist',
'tensordot',
'greater_than',
'shard_index',
'argsort',
......
......@@ -1038,3 +1038,4 @@ if(WITH_GPU OR WITH_ROCM)
endif()
set_tests_properties(test_inplace_addto_strategy PROPERTIES TIMEOUT 120)
set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400)
set_tests_properties(test_tensordot PROPERTIES TIMEOUT 1000)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import paddle.fluid.core as core
import numpy as np
import itertools as it
np.set_printoptions(threshold=np.inf)
def tensordot_np(x, y, axes):
if isinstance(axes, paddle.fluid.framework.Variable):
axes = axes.tolist()
# np.tensordot does not support empty axes
if not axes:
axes = 0
if (isinstance(axes, (tuple, list))):
if all(np.issubdtype(type(i), np.integer) for i in axes):
axes = [axes, axes]
else:
axes_x = axes[0]
if len(axes) > 1:
axes_y = axes[1]
else:
axes_y = axes_x
len_axes_x, len_axes_y = len(axes_x), len(axes_y)
if len_axes_x < len_axes_y:
axes_x = axes_x + axes_y[len_axes_x:]
elif len_axes_y < len_axes_x:
axes_y = axes_y + axes_x[len_axes_y:]
axes = [axes_x, axes_y]
# np.tensordot does not support broadcast
if (isinstance(axes, (tuple, list))):
axes_x, axes_y = axes
else:
axes_x = list(range(x.ndim - axes, x.ndim))
axes_y = list(range(axes))
shape_x, shape_y = list(np.shape(x)), list(np.shape(y))
for i in range(len(axes_x)):
dim_x, dim_y = axes_x[i], axes_y[i]
sx, sy = shape_x[dim_x], shape_y[dim_y]
if sx == 1:
shape_y[dim_y] = 1
y = np.sum(y, dim_y)
y = np.reshape(y, shape_y)
elif sy == 1:
shape_x[dim_x] = 1
x = np.sum(x, dim_x)
x = np.reshape(x, shape_x)
return np.tensordot(x, y, axes)
class TestTensordotAPI(unittest.TestCase):
def setUp(self):
self.set_dtype()
self.set_input_shape()
self.set_input_data()
def set_dtype(self):
self.dtype = np.float32
def set_input_shape(self):
self.x_shape = [5, 5, 5, 5]
self.y_shape = [5, 5, 5, 5]
def set_input_data(self):
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.y = np.random.random(self.y_shape).astype(self.dtype)
self.all_axes = [2]
def run_dygraph(self, place):
paddle.disable_static()
x = paddle.to_tensor(self.x, place=place)
y = paddle.to_tensor(self.y, place=place)
paddle_res = paddle.tensordot(x, y, self.axes)
np_res = tensordot_np(self.x, self.y, self.axes)
np.testing.assert_allclose(paddle_res, np_res, rtol=1e-6)
def run_static(self, place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.static.data(
name='x', shape=self.x_shape, dtype=self.dtype)
y = paddle.static.data(
name='y', shape=self.y_shape, dtype=self.dtype)
z = paddle.tensordot(x, y, self.axes)
exe = paddle.static.Executor(place)
paddle_res = exe.run(feed={'x': self.x,
'y': self.y},
fetch_list=[z])
np_res = tensordot_np(self.x, self.y, self.axes)
np.testing.assert_allclose(paddle_res[0], np_res, rtol=1e-6)
def test_cases(self):
self.all_axes = []
axial_index = range(4)
all_permutations = list(it.permutations(axial_index, 0)) + list(
it.permutations(axial_index, 1)) + list(
it.permutations(axial_index, 2)) + list(
it.permutations(axial_index, 3)) + list(
it.permutations(axial_index, 4))
self.all_axes.extend(list(i) for i in all_permutations)
for axes_x in all_permutations:
for axes_y in all_permutations:
if len(axes_x) < len(axes_y):
supplementary_axes_x = axes_x + axes_y[len(axes_x):]
if any(
supplementary_axes_x.count(i) > 1
for i in supplementary_axes_x):
continue
elif len(axes_y) < len(axes_x):
supplementary_axes_y = axes_y + axes_x[len(axes_y):]
if any(
supplementary_axes_y.count(i) > 1
for i in supplementary_axes_y):
continue
self.all_axes.append([list(axes_x), list(axes_y)])
self.all_axes.extend(range(5))
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for axes in self.all_axes:
self.axes = axes
for place in places:
self.run_dygraph(place)
self.run_static(place)
class TestTensordotAPIFloat64(TestTensordotAPI):
def set_dtype(self):
self.dtype = np.float64
class TestTensordotAPIAxesType(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [3, 4, 4]
self.y_shape = [4, 4, 5]
def test_cases(self):
self.all_axes = [
0, 1, 2, (1, ), [1], ((1, ), ), ([1], ), ((2, 1), (0, )), (
(1, 2), (0, 1)), ([1, 2], [0, 1]), ([1, 2], [0, 1]),
[[1, 2], [0, 1]]
]
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for axes in self.all_axes:
self.axes = axes
for place in places:
self.run_dygraph(place)
self.run_static(place)
# The 'axes' with type 'Tensor' in tensordot is not available in static mode
paddle.disable_static()
for place in places:
self.all_axes = [
paddle.to_tensor([1]), (paddle.to_tensor([1])),
(paddle.to_tensor([1, 2]), paddle.to_tensor([0, 1])),
[paddle.to_tensor([1, 2]), paddle.to_tensor([0, 1])],
paddle.to_tensor([[1, 2], [0, 1]])
]
for axes in self.all_axes:
self.axes = axes
for place in places:
self.run_dygraph(place)
def test_error(self):
self.all_axes = [[[[0], [1]]], 0.1, -1, 100, [[1, 2], [0, 0]],
[[1, 2], [0, -1]], [0, 1, 2, 3]]
paddle.disable_static()
x = paddle.to_tensor(self.x)
y = paddle.to_tensor(self.y)
for axes in self.all_axes:
with self.assertRaises(BaseException):
paddle.tensordot(x, y, axes)
class TestTensordotAPIAxesTypeFloat64(TestTensordotAPIAxesType):
def set_dtype(self):
self.dtype = np.float64
class TestTensordotAPIBroadcastCase1(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [1, 1, 1, 5]
self.y_shape = [1, 5, 1, 1]
class TestTensordotAPIBroadcastCase2(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [1, 5, 5, 5]
self.y_shape = [1, 1, 1, 5]
class TestTensordotAPIBroadcastCase3(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [5, 5, 5, 1]
self.y_shape = [5, 5, 1, 5]
class TestTensordotAPIBroadcastCase4(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [5, 5, 5, 1]
self.y_shape = [1, 1, 1, 1]
class TestTensordotAPIBroadcastCase5(TestTensordotAPI):
def set_input_shape(self):
self.x_shape = [1, 1, 5, 5]
self.y_shape = [5, 5, 1, 5]
if __name__ == "__main__":
unittest.main()
......@@ -105,6 +105,7 @@ from .manipulation import flip # noqa: F401
from .manipulation import unbind # noqa: F401
from .manipulation import roll # noqa: F401
from .manipulation import chunk # noqa: F401
from .manipulation import tensordot # noqa: F401
from .math import abs # noqa: F401
from .math import acos # noqa: F401
from .math import asin # noqa: F401
......@@ -346,6 +347,7 @@ tensor_method_func = [ #noqa
'slice',
'split',
'chunk',
'tensordot',
'squeeze',
'squeeze_',
'stack',
......
......@@ -2173,3 +2173,211 @@ def strided_slice(x, axes, starts, ends, strides, name=None):
return paddle.fluid.layers.strided_slice(
input=x, axes=axes, starts=starts, ends=ends, strides=strides)
def tensordot(x, y, axes=2, name=None):
r"""
This function computes a contraction, which sum the product of elements from two tensors along the given axes.
Args:
x (Tensor): The left tensor for contraction with data type ``float32`` or ``float64``.
y (Tensor): The right tensor for contraction with the same data type as ``x``.
axes (int|tuple|list|Tensor, optional): The axes to contract for ``x`` and ``y``, defaulted to integer ``2``.
1. It could be a non-negative integer ``n``,
in which the function will sum over the last ``n`` axes of ``x`` and the first ``n`` axes of ``y`` in order.
2. It could be a 1-d tuple or list with data type ``int``, in which ``x`` and ``y`` will be contracted along the same given axes.
For example, ``axes`` =[0, 1] applies contraction along the first two axes for ``x`` and the first two axes for ``y``.
3. It could be a tuple or list containing one or two 1-d tuple|list|Tensor with data type ``int``.
When containing one tuple|list|Tensor, the data in tuple|list|Tensor specified the same axes for ``x`` and ``y`` to contract.
When containing two tuple|list|Tensor, the first will be applied to ``x`` and the second to ``y``.
When containing more than two tuple|list|Tensor, only the first two axis sequences will be used while the others will be ignored.
4. It could be a tensor, in which the ``axes`` tensor will be translated to a python list
and applied the same rules described above to determine the contraction axes.
Note that the ``axes`` with Tensor type is ONLY available in Dygraph mode.
name(str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Return:
Output (Tensor): The contraction result with the same data type as ``x`` and ``y``.
In general, :math:`output.ndim = x.ndim + y.ndim - 2 \times n_{axes}`, where :math:`n_{axes}` denotes the number of axes to be contracted.
NOTES:
1. This function supports tensor broadcast,
the size in the corresponding dimensions of ``x`` and ``y`` should be equal, or applies to the broadcast rules.
2. This function also supports axes expansion,
when the two given axis sequences for ``x`` and ``y`` are of different lengths,
the shorter sequence will expand the same axes as the longer one at the end.
For example, if ``axes`` =[[0, 1, 2, 3], [1, 0]],
the axis sequence for ``x`` is [0, 1, 2, 3],
while the corresponding axis sequences for ``y`` will be expanded from [1, 0] to [1, 0, 2, 3].
Examples:
.. code-block:: python
import paddle
data_type = 'float64'
# For two 2-d tensor x and y, the case axes=0 is equivalent to outer product.
# Note that tensordot supports empty axis sequence, so all the axes=0, axes=[], axes=[[]], and axes=[[],[]] are equivalent cases.
x = paddle.arange(4, dtype=data_type).reshape([2, 2])
y = paddle.arange(4, dtype=data_type).reshape([2, 2])
z = paddle.tensordot(x, y, axes=0)
# z = [[[[0., 0.],
# [0., 0.]],
#
# [[0., 1.],
# [2., 3.]]],
#
#
# [[[0., 2.],
# [4., 6.]],
#
# [[0., 3.],
# [6., 9.]]]]
# For two 1-d tensor x and y, the case axes=1 is equivalent to inner product.
x = paddle.arange(10, dtype=data_type)
y = paddle.arange(10, dtype=data_type)
z1 = paddle.tensordot(x, y, axes=1)
z2 = paddle.dot(x, y)
# z1 = z2 = [285.]
# For two 2-d tensor x and y, the case axes=1 is equivalent to matrix multiplication.
x = paddle.arange(6, dtype=data_type).reshape([2, 3])
y = paddle.arange(12, dtype=data_type).reshape([3, 4])
z1 = paddle.tensordot(x, y, axes=1)
z2 = paddle.matmul(x, y)
# z1 = z2 = [[20., 23., 26., 29.],
# [56., 68., 80., 92.]]
# When axes is a 1-d int list, x and y will be contracted along the same given axes.
# Note that axes=[1, 2] is equivalent to axes=[[1, 2]], axes=[[1, 2], []], axes=[[1, 2], [1]], and axes=[[1, 2], [1, 2]].
x = paddle.arange(24, dtype=data_type).reshape([2, 3, 4])
y = paddle.arange(36, dtype=data_type).reshape([3, 3, 4])
z = paddle.tensordot(x, y, axes=[1, 2])
# z = [[506. , 1298., 2090.],
# [1298., 3818., 6338.]]
# When axes is a list containing two 1-d int list, the first will be applied to x and the second to y.
x = paddle.arange(60, dtype=data_type).reshape([3, 4, 5])
y = paddle.arange(24, dtype=data_type).reshape([4, 3, 2])
z = paddle.tensordot(x, y, axes=([1, 0], [0, 1]))
# z = [[4400., 4730.],
# [4532., 4874.],
# [4664., 5018.],
# [4796., 5162.],
# [4928., 5306.]]
# Thanks to the support of axes expansion, axes=[[0, 1, 3, 4], [1, 0, 3, 4]] can be abbreviated as axes= [[0, 1, 3, 4], [1, 0]].
x = paddle.arange(720, dtype=data_type).reshape([2, 3, 4, 5, 6])
y = paddle.arange(720, dtype=data_type).reshape([3, 2, 4, 5, 6])
z = paddle.tensordot(x, y, axes=[[0, 1, 3, 4], [1, 0]])
# z = [[23217330., 24915630., 26613930., 28312230.],
# [24915630., 26775930., 28636230., 30496530.],
# [26613930., 28636230., 30658530., 32680830.],
# [28312230., 30496530., 32680830., 34865130.]]
"""
op_type = 'tensordot'
input_dtype = ['float32', 'float64']
check_variable_and_dtype(x, 'x', input_dtype, op_type)
check_variable_and_dtype(y, 'y', input_dtype, op_type)
check_type(axes, 'axes', (int, tuple, list, Variable), op_type)
def _var_to_list(var):
if in_dygraph_mode():
return tolist(var)
raise TypeError(
"The 'axes' with type 'Tensor' in " + op_type +
" is not available in static graph mode, "
"please convert its type to int|Tuple|List, or use dynamic graph mode."
)
axes_x = []
axes_y = []
if np.issubdtype(type(axes), np.integer):
assert axes >= 0, (
"The 'axes' in " + op_type +
f" should not be negative, but received axes={axes}.")
axes_x = range(x.ndim - axes, x.ndim)
axes_y = range(axes)
else:
if isinstance(axes, Variable):
axes = _var_to_list(axes)
if not axes or np.issubdtype(type(axes[0]), np.integer):
axes_x = axes
else:
axes_x = axes[0]
if len(axes) > 1:
axes_y = axes[1]
if isinstance(axes_x, Variable):
axes_x = _var_to_list(axes_x)
if isinstance(axes_y, Variable):
axes_y = _var_to_list(axes_y)
axes_x, axes_y = list(axes_x), list(axes_y)
len_axes_x, len_axes_y = len(axes_x), len(axes_y)
if len_axes_x < len_axes_y:
axes_x.extend(axes_y[len_axes_x:])
elif len_axes_y < len_axes_x:
axes_y.extend(axes_x[len_axes_y:])
shape_x, shape_y = list(x.shape), list(y.shape)
need_contracted_dim_x = np.zeros((x.ndim), dtype=bool)
need_contracted_dim_y = np.zeros((y.ndim), dtype=bool)
contraction_size = 1
for i in range(len(axes_x)):
dim_x, dim_y = axes_x[i], axes_y[i]
sx, sy = shape_x[dim_x], shape_y[dim_y]
if sx == 1:
shape_y[dim_y] = 1
y = y.sum(dim_y).reshape(shape_y)
elif sy == 1:
shape_x[dim_x] = 1
x = x.sum(dim_x).reshape(shape_x)
else:
assert sx == sy, "The dimensional size for 'x' and 'y' in " + op_type + f" should match each other, but 'x' has size {sx} in dim {dim_x} while 'y' has size {sy} in dim {dim_y}."
need_contracted_dim_x[dim_x] = True
need_contracted_dim_y[dim_y] = True
contraction_size *= shape_x[dim_x]
perm_x = []
perm_y = []
shape_out = []
not_contraction_size_x = 1
not_contraction_size_y = 1
for i in range(x.ndim):
if not need_contracted_dim_x[i]:
perm_x.append(i)
shape_out.append(shape_x[i])
not_contraction_size_x *= shape_x[i]
perm_x.extend(axes_x)
perm_y.extend(axes_y)
for i in range(y.ndim):
if not need_contracted_dim_y[i]:
perm_y.append(i)
shape_out.append(shape_y[i])
not_contraction_size_y *= shape_y[i]
if not shape_out:
shape_out = [1]
x = x.transpose(perm=perm_x).reshape(
[not_contraction_size_x, contraction_size])
y = y.transpose(perm=perm_y).reshape(
[contraction_size, not_contraction_size_y])
out = x.matmul(y).reshape(shape_out)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册