From 2bfee7d3fa94eacbd026c34fdc9980b6ae156a50 Mon Sep 17 00:00:00 2001 From: From00 Date: Mon, 25 Oct 2021 10:39:30 +0800 Subject: [PATCH] [cherry-pick] Add new API 'tensordot' (#36273) (#36454) * Add new API tensordot cherry-pick #36273 --- python/paddle/__init__.py | 2 + .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/test_tensordot.py | 238 ++++++++++++++++++ python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/manipulation.py | 208 +++++++++++++++ 5 files changed, 451 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_tensordot.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ad8640f6f55..9146c6224bf 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -152,6 +152,7 @@ from .tensor.manipulation import unbind # noqa: F401 from .tensor.manipulation import roll # noqa: F401 from .tensor.manipulation import chunk # noqa: F401 from .tensor.manipulation import tolist # noqa: F401 +from .tensor.manipulation import tensordot # noqa: F401 from .tensor.math import abs # noqa: F401 from .tensor.math import acos # noqa: F401 from .tensor.math import asin # noqa: F401 @@ -470,6 +471,7 @@ __all__ = [ # noqa 'bmm', 'chunk', 'tolist', + 'tensordot', 'greater_than', 'shard_index', 'argsort', diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index acc47d883fd..2bc62c28fd2 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1038,3 +1038,4 @@ if(WITH_GPU OR WITH_ROCM) endif() set_tests_properties(test_inplace_addto_strategy PROPERTIES TIMEOUT 120) set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400) +set_tests_properties(test_tensordot PROPERTIES TIMEOUT 1000) diff --git a/python/paddle/fluid/tests/unittests/test_tensordot.py b/python/paddle/fluid/tests/unittests/test_tensordot.py new file mode 100644 index 00000000000..29f3308988f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensordot.py @@ -0,0 +1,238 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import paddle.fluid.core as core +import numpy as np +import itertools as it + +np.set_printoptions(threshold=np.inf) + + +def tensordot_np(x, y, axes): + if isinstance(axes, paddle.fluid.framework.Variable): + axes = axes.tolist() + + # np.tensordot does not support empty axes + if not axes: + axes = 0 + if (isinstance(axes, (tuple, list))): + if all(np.issubdtype(type(i), np.integer) for i in axes): + axes = [axes, axes] + else: + axes_x = axes[0] + if len(axes) > 1: + axes_y = axes[1] + else: + axes_y = axes_x + len_axes_x, len_axes_y = len(axes_x), len(axes_y) + if len_axes_x < len_axes_y: + axes_x = axes_x + axes_y[len_axes_x:] + elif len_axes_y < len_axes_x: + axes_y = axes_y + axes_x[len_axes_y:] + axes = [axes_x, axes_y] + + # np.tensordot does not support broadcast + if (isinstance(axes, (tuple, list))): + axes_x, axes_y = axes + else: + axes_x = list(range(x.ndim - axes, x.ndim)) + axes_y = list(range(axes)) + shape_x, shape_y = list(np.shape(x)), list(np.shape(y)) + for i in range(len(axes_x)): + dim_x, dim_y = axes_x[i], axes_y[i] + sx, sy = shape_x[dim_x], shape_y[dim_y] + if sx == 1: + shape_y[dim_y] = 1 + y = np.sum(y, dim_y) + y = np.reshape(y, shape_y) + elif sy == 1: + shape_x[dim_x] = 1 + x = np.sum(x, dim_x) + x = np.reshape(x, shape_x) + + return np.tensordot(x, y, axes) + + +class TestTensordotAPI(unittest.TestCase): + def setUp(self): + self.set_dtype() + self.set_input_shape() + self.set_input_data() + + def set_dtype(self): + self.dtype = np.float32 + + def set_input_shape(self): + self.x_shape = [5, 5, 5, 5] + self.y_shape = [5, 5, 5, 5] + + def set_input_data(self): + self.x = np.random.random(self.x_shape).astype(self.dtype) + self.y = np.random.random(self.y_shape).astype(self.dtype) + self.all_axes = [2] + + def run_dygraph(self, place): + paddle.disable_static() + x = paddle.to_tensor(self.x, place=place) + y = paddle.to_tensor(self.y, place=place) + paddle_res = paddle.tensordot(x, y, self.axes) + np_res = tensordot_np(self.x, self.y, self.axes) + np.testing.assert_allclose(paddle_res, np_res, rtol=1e-6) + + def run_static(self, place): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + x = paddle.static.data( + name='x', shape=self.x_shape, dtype=self.dtype) + y = paddle.static.data( + name='y', shape=self.y_shape, dtype=self.dtype) + z = paddle.tensordot(x, y, self.axes) + exe = paddle.static.Executor(place) + paddle_res = exe.run(feed={'x': self.x, + 'y': self.y}, + fetch_list=[z]) + np_res = tensordot_np(self.x, self.y, self.axes) + np.testing.assert_allclose(paddle_res[0], np_res, rtol=1e-6) + + def test_cases(self): + self.all_axes = [] + axial_index = range(4) + all_permutations = list(it.permutations(axial_index, 0)) + list( + it.permutations(axial_index, 1)) + list( + it.permutations(axial_index, 2)) + list( + it.permutations(axial_index, 3)) + list( + it.permutations(axial_index, 4)) + self.all_axes.extend(list(i) for i in all_permutations) + + for axes_x in all_permutations: + for axes_y in all_permutations: + if len(axes_x) < len(axes_y): + supplementary_axes_x = axes_x + axes_y[len(axes_x):] + if any( + supplementary_axes_x.count(i) > 1 + for i in supplementary_axes_x): + continue + elif len(axes_y) < len(axes_x): + supplementary_axes_y = axes_y + axes_x[len(axes_y):] + if any( + supplementary_axes_y.count(i) > 1 + for i in supplementary_axes_y): + continue + self.all_axes.append([list(axes_x), list(axes_y)]) + + self.all_axes.extend(range(5)) + + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + + for axes in self.all_axes: + self.axes = axes + for place in places: + self.run_dygraph(place) + self.run_static(place) + + +class TestTensordotAPIFloat64(TestTensordotAPI): + def set_dtype(self): + self.dtype = np.float64 + + +class TestTensordotAPIAxesType(TestTensordotAPI): + def set_input_shape(self): + self.x_shape = [3, 4, 4] + self.y_shape = [4, 4, 5] + + def test_cases(self): + self.all_axes = [ + 0, 1, 2, (1, ), [1], ((1, ), ), ([1], ), ((2, 1), (0, )), ( + (1, 2), (0, 1)), ([1, 2], [0, 1]), ([1, 2], [0, 1]), + [[1, 2], [0, 1]] + ] + + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + + for axes in self.all_axes: + self.axes = axes + for place in places: + self.run_dygraph(place) + self.run_static(place) + + # The 'axes' with type 'Tensor' in tensordot is not available in static mode + paddle.disable_static() + for place in places: + self.all_axes = [ + paddle.to_tensor([1]), (paddle.to_tensor([1])), + (paddle.to_tensor([1, 2]), paddle.to_tensor([0, 1])), + [paddle.to_tensor([1, 2]), paddle.to_tensor([0, 1])], + paddle.to_tensor([[1, 2], [0, 1]]) + ] + for axes in self.all_axes: + self.axes = axes + for place in places: + self.run_dygraph(place) + + def test_error(self): + self.all_axes = [[[[0], [1]]], 0.1, -1, 100, [[1, 2], [0, 0]], + [[1, 2], [0, -1]], [0, 1, 2, 3]] + paddle.disable_static() + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) + for axes in self.all_axes: + with self.assertRaises(BaseException): + paddle.tensordot(x, y, axes) + + +class TestTensordotAPIAxesTypeFloat64(TestTensordotAPIAxesType): + def set_dtype(self): + self.dtype = np.float64 + + +class TestTensordotAPIBroadcastCase1(TestTensordotAPI): + def set_input_shape(self): + self.x_shape = [1, 1, 1, 5] + self.y_shape = [1, 5, 1, 1] + + +class TestTensordotAPIBroadcastCase2(TestTensordotAPI): + def set_input_shape(self): + self.x_shape = [1, 5, 5, 5] + self.y_shape = [1, 1, 1, 5] + + +class TestTensordotAPIBroadcastCase3(TestTensordotAPI): + def set_input_shape(self): + self.x_shape = [5, 5, 5, 1] + self.y_shape = [5, 5, 1, 5] + + +class TestTensordotAPIBroadcastCase4(TestTensordotAPI): + def set_input_shape(self): + self.x_shape = [5, 5, 5, 1] + self.y_shape = [1, 1, 1, 1] + + +class TestTensordotAPIBroadcastCase5(TestTensordotAPI): + def set_input_shape(self): + self.x_shape = [1, 1, 5, 5] + self.y_shape = [5, 5, 1, 5] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b5d79b60393..c8f897c2164 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -105,6 +105,7 @@ from .manipulation import flip # noqa: F401 from .manipulation import unbind # noqa: F401 from .manipulation import roll # noqa: F401 from .manipulation import chunk # noqa: F401 +from .manipulation import tensordot # noqa: F401 from .math import abs # noqa: F401 from .math import acos # noqa: F401 from .math import asin # noqa: F401 @@ -346,6 +347,7 @@ tensor_method_func = [ #noqa 'slice', 'split', 'chunk', + 'tensordot', 'squeeze', 'squeeze_', 'stack', diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 4129a1060da..5f7588cb2a9 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2173,3 +2173,211 @@ def strided_slice(x, axes, starts, ends, strides, name=None): return paddle.fluid.layers.strided_slice( input=x, axes=axes, starts=starts, ends=ends, strides=strides) + + +def tensordot(x, y, axes=2, name=None): + r""" + This function computes a contraction, which sum the product of elements from two tensors along the given axes. + + Args: + x (Tensor): The left tensor for contraction with data type ``float32`` or ``float64``. + y (Tensor): The right tensor for contraction with the same data type as ``x``. + axes (int|tuple|list|Tensor, optional): The axes to contract for ``x`` and ``y``, defaulted to integer ``2``. + + 1. It could be a non-negative integer ``n``, + in which the function will sum over the last ``n`` axes of ``x`` and the first ``n`` axes of ``y`` in order. + + 2. It could be a 1-d tuple or list with data type ``int``, in which ``x`` and ``y`` will be contracted along the same given axes. + For example, ``axes`` =[0, 1] applies contraction along the first two axes for ``x`` and the first two axes for ``y``. + + 3. It could be a tuple or list containing one or two 1-d tuple|list|Tensor with data type ``int``. + When containing one tuple|list|Tensor, the data in tuple|list|Tensor specified the same axes for ``x`` and ``y`` to contract. + When containing two tuple|list|Tensor, the first will be applied to ``x`` and the second to ``y``. + When containing more than two tuple|list|Tensor, only the first two axis sequences will be used while the others will be ignored. + + 4. It could be a tensor, in which the ``axes`` tensor will be translated to a python list + and applied the same rules described above to determine the contraction axes. + Note that the ``axes`` with Tensor type is ONLY available in Dygraph mode. + name(str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + + Return: + Output (Tensor): The contraction result with the same data type as ``x`` and ``y``. + In general, :math:`output.ndim = x.ndim + y.ndim - 2 \times n_{axes}`, where :math:`n_{axes}` denotes the number of axes to be contracted. + + NOTES: + 1. This function supports tensor broadcast, + the size in the corresponding dimensions of ``x`` and ``y`` should be equal, or applies to the broadcast rules. + 2. This function also supports axes expansion, + when the two given axis sequences for ``x`` and ``y`` are of different lengths, + the shorter sequence will expand the same axes as the longer one at the end. + For example, if ``axes`` =[[0, 1, 2, 3], [1, 0]], + the axis sequence for ``x`` is [0, 1, 2, 3], + while the corresponding axis sequences for ``y`` will be expanded from [1, 0] to [1, 0, 2, 3]. + + Examples: + .. code-block:: python + + import paddle + + data_type = 'float64' + + # For two 2-d tensor x and y, the case axes=0 is equivalent to outer product. + # Note that tensordot supports empty axis sequence, so all the axes=0, axes=[], axes=[[]], and axes=[[],[]] are equivalent cases. + x = paddle.arange(4, dtype=data_type).reshape([2, 2]) + y = paddle.arange(4, dtype=data_type).reshape([2, 2]) + z = paddle.tensordot(x, y, axes=0) + # z = [[[[0., 0.], + # [0., 0.]], + # + # [[0., 1.], + # [2., 3.]]], + # + # + # [[[0., 2.], + # [4., 6.]], + # + # [[0., 3.], + # [6., 9.]]]] + + + # For two 1-d tensor x and y, the case axes=1 is equivalent to inner product. + x = paddle.arange(10, dtype=data_type) + y = paddle.arange(10, dtype=data_type) + z1 = paddle.tensordot(x, y, axes=1) + z2 = paddle.dot(x, y) + # z1 = z2 = [285.] + + + # For two 2-d tensor x and y, the case axes=1 is equivalent to matrix multiplication. + x = paddle.arange(6, dtype=data_type).reshape([2, 3]) + y = paddle.arange(12, dtype=data_type).reshape([3, 4]) + z1 = paddle.tensordot(x, y, axes=1) + z2 = paddle.matmul(x, y) + # z1 = z2 = [[20., 23., 26., 29.], + # [56., 68., 80., 92.]] + + + # When axes is a 1-d int list, x and y will be contracted along the same given axes. + # Note that axes=[1, 2] is equivalent to axes=[[1, 2]], axes=[[1, 2], []], axes=[[1, 2], [1]], and axes=[[1, 2], [1, 2]]. + x = paddle.arange(24, dtype=data_type).reshape([2, 3, 4]) + y = paddle.arange(36, dtype=data_type).reshape([3, 3, 4]) + z = paddle.tensordot(x, y, axes=[1, 2]) + # z = [[506. , 1298., 2090.], + # [1298., 3818., 6338.]] + + + # When axes is a list containing two 1-d int list, the first will be applied to x and the second to y. + x = paddle.arange(60, dtype=data_type).reshape([3, 4, 5]) + y = paddle.arange(24, dtype=data_type).reshape([4, 3, 2]) + z = paddle.tensordot(x, y, axes=([1, 0], [0, 1])) + # z = [[4400., 4730.], + # [4532., 4874.], + # [4664., 5018.], + # [4796., 5162.], + # [4928., 5306.]] + + + # Thanks to the support of axes expansion, axes=[[0, 1, 3, 4], [1, 0, 3, 4]] can be abbreviated as axes= [[0, 1, 3, 4], [1, 0]]. + x = paddle.arange(720, dtype=data_type).reshape([2, 3, 4, 5, 6]) + y = paddle.arange(720, dtype=data_type).reshape([3, 2, 4, 5, 6]) + z = paddle.tensordot(x, y, axes=[[0, 1, 3, 4], [1, 0]]) + # z = [[23217330., 24915630., 26613930., 28312230.], + # [24915630., 26775930., 28636230., 30496530.], + # [26613930., 28636230., 30658530., 32680830.], + # [28312230., 30496530., 32680830., 34865130.]] + """ + op_type = 'tensordot' + input_dtype = ['float32', 'float64'] + + check_variable_and_dtype(x, 'x', input_dtype, op_type) + check_variable_and_dtype(y, 'y', input_dtype, op_type) + check_type(axes, 'axes', (int, tuple, list, Variable), op_type) + + def _var_to_list(var): + if in_dygraph_mode(): + return tolist(var) + raise TypeError( + "The 'axes' with type 'Tensor' in " + op_type + + " is not available in static graph mode, " + "please convert its type to int|Tuple|List, or use dynamic graph mode." + ) + + axes_x = [] + axes_y = [] + if np.issubdtype(type(axes), np.integer): + assert axes >= 0, ( + "The 'axes' in " + op_type + + f" should not be negative, but received axes={axes}.") + axes_x = range(x.ndim - axes, x.ndim) + axes_y = range(axes) + else: + if isinstance(axes, Variable): + axes = _var_to_list(axes) + + if not axes or np.issubdtype(type(axes[0]), np.integer): + axes_x = axes + else: + axes_x = axes[0] + if len(axes) > 1: + axes_y = axes[1] + + if isinstance(axes_x, Variable): + axes_x = _var_to_list(axes_x) + if isinstance(axes_y, Variable): + axes_y = _var_to_list(axes_y) + + axes_x, axes_y = list(axes_x), list(axes_y) + len_axes_x, len_axes_y = len(axes_x), len(axes_y) + if len_axes_x < len_axes_y: + axes_x.extend(axes_y[len_axes_x:]) + elif len_axes_y < len_axes_x: + axes_y.extend(axes_x[len_axes_y:]) + + shape_x, shape_y = list(x.shape), list(y.shape) + need_contracted_dim_x = np.zeros((x.ndim), dtype=bool) + need_contracted_dim_y = np.zeros((y.ndim), dtype=bool) + contraction_size = 1 + for i in range(len(axes_x)): + dim_x, dim_y = axes_x[i], axes_y[i] + sx, sy = shape_x[dim_x], shape_y[dim_y] + if sx == 1: + shape_y[dim_y] = 1 + y = y.sum(dim_y).reshape(shape_y) + elif sy == 1: + shape_x[dim_x] = 1 + x = x.sum(dim_x).reshape(shape_x) + else: + assert sx == sy, "The dimensional size for 'x' and 'y' in " + op_type + f" should match each other, but 'x' has size {sx} in dim {dim_x} while 'y' has size {sy} in dim {dim_y}." + + need_contracted_dim_x[dim_x] = True + need_contracted_dim_y[dim_y] = True + contraction_size *= shape_x[dim_x] + + perm_x = [] + perm_y = [] + shape_out = [] + not_contraction_size_x = 1 + not_contraction_size_y = 1 + for i in range(x.ndim): + if not need_contracted_dim_x[i]: + perm_x.append(i) + shape_out.append(shape_x[i]) + not_contraction_size_x *= shape_x[i] + perm_x.extend(axes_x) + perm_y.extend(axes_y) + for i in range(y.ndim): + if not need_contracted_dim_y[i]: + perm_y.append(i) + shape_out.append(shape_y[i]) + not_contraction_size_y *= shape_y[i] + + if not shape_out: + shape_out = [1] + + x = x.transpose(perm=perm_x).reshape( + [not_contraction_size_x, contraction_size]) + y = y.transpose(perm=perm_y).reshape( + [contraction_size, not_contraction_size_y]) + out = x.matmul(y).reshape(shape_out) + return out -- GitLab