From 3304e3454d7492ca87fae235a62ce3ae68441455 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <72954905+Asthestarsfalll@users.noreply.github.com> Date: Fri, 9 Sep 2022 16:50:50 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PaddlePaddle=20Hackathon=203=20No.20?= =?UTF-8?q?=E3=80=91=E4=B8=BA=20Paddle=20=E6=96=B0=E5=A2=9E=20vsplit=20API?= =?UTF-8?q?=20(#44853)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add paddle vsplit api * update unittest and fix a typo * update * add vsplit to __all__ * update unit test and description of x * fix typo --- python/paddle/__init__.py | 2 + .../fluid/tests/unittests/test_splits_api.py | 168 ++++++++++++++++++ python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/manipulation.py | 42 +++++ 4 files changed, 214 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_splits_api.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index dc55260f2c..4e983f4372 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -165,6 +165,7 @@ from .tensor.manipulation import shard_index # noqa: F401 from .tensor.manipulation import slice # noqa: F401 from .tensor.manipulation import crop # noqa: F401 from .tensor.manipulation import split # noqa: F401 +from .tensor.manipulation import vsplit # noqa: F401 from .tensor.manipulation import squeeze # noqa: F401 from .tensor.manipulation import squeeze_ # noqa: F401 from .tensor.manipulation import stack # noqa: F401 @@ -457,6 +458,7 @@ __all__ = [ # noqa 'searchsorted', 'bucketize', 'split', + 'vsplit', 'logical_and', 'full_like', 'less_than', diff --git a/python/paddle/fluid/tests/unittests/test_splits_api.py b/python/paddle/fluid/tests/unittests/test_splits_api.py new file mode 100644 index 0000000000..4b6254e266 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_splits_api.py @@ -0,0 +1,168 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + + +def func_ref(func, x, num_or_sections): + # Convert the num_or_sections in paddle to indices_or_sections in numpy + # Do not support -1 + if isinstance(num_or_sections, int): + indices_or_sections = num_or_sections + else: + indices_or_sections = np.cumsum(num_or_sections)[:-1] + return func(x, indices_or_sections) + + +# TODO: add other split API, such as dsplit、hsplit +test_list = [ + (paddle.vsplit, np.vsplit), +] + + +class TestSplitsAPI(unittest.TestCase): + + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.set_input() + + def set_input(self): + self.shape = [4, 5, 2] + self.num_or_sections = 2 + self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + paddle.enable_static() + for func, func_type in test_list: + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', self.x_np.shape, self.x_np.dtype) + out = func(x, self.num_or_sections) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = func_ref(func_type, self.x_np, self.num_or_sections) + for n, p in zip(out_ref, res): + np.testing.assert_allclose(n, p, rtol=self.rtol, atol=self.atol) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + for func, func_type in test_list: + out = func(x, self.num_or_sections) + out_ref = func_ref(func_type, self.x_np, self.num_or_sections) + for n, p in zip(out_ref, out): + np.testing.assert_allclose(n, + p.numpy(), + rtol=self.rtol, + atol=self.atol) + paddle.enable_static() + + +class TestSplitsSections(TestSplitsAPI): + """ + Test num_or_sections which is a list and date type is float64. + """ + + def set_input(self): + self.shape = [6, 2, 4] + self.num_or_sections = [2, 1, 3] + self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + +class TestSplitsFloat32(TestSplitsAPI): + """ + Test num_or_sections which is an integer and data type is float32. + """ + + def set_input(self): + self.shape = [2, 3, 4] + self.num_or_sections = 2 + self.x_np = np.random.uniform(-1, 1, self.shape).astype('float32') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + +class TestSplitsInt32(TestSplitsAPI): + """ + Test data type int32. + """ + + def set_input(self): + self.shape = [5, 1, 2] + self.num_or_sections = 5 + self.x_np = np.random.uniform(-1, 1, self.shape).astype('int32') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + +class TestSplitsInt64(TestSplitsAPI): + """ + Test data type int64. + """ + + def set_input(self): + self.shape = [4, 3, 2] + self.num_or_sections = 2 + self.x_np = np.random.uniform(-1, 1, self.shape).astype('int64') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + +class TestSplitsCPU(TestSplitsAPI): + """ + Test cpu place and num_or_sections which is a tuple. + """ + + def set_input(self): + self.shape = [8, 2, 3, 5] + self.num_or_sections = (2, 3, 3) + self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') + self.place = paddle.CPUPlace() + + +class TestSplitsError(unittest.TestCase): + """ + Test the situation that input shape less than 2. + """ + + def setUp(self): + self.num_or_sections = 1 + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_error(self): + paddle.enable_static() + for func, _ in test_list: + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', [5], 'float32') + self.assertRaises(ValueError, func, x, self.num_or_sections) + + def test_dygraph_error(self): + paddle.disable_static(self.place) + for func, _ in test_list: + x_np = np.random.randn(2) + x = paddle.to_tensor(x_np, dtype='float64') + self.assertRaises(ValueError, func, x, self.num_or_sections) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 58dfa26cfe..ba7dd5d0ce 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -106,6 +106,7 @@ from .manipulation import scatter_nd # noqa: F401 from .manipulation import shard_index # noqa: F401 from .manipulation import slice # noqa: F401 from .manipulation import split # noqa: F401 +from .manipulation import vsplit # noqa: F401 from .manipulation import squeeze # noqa: F401 from .manipulation import squeeze_ # noqa: F401 from .manipulation import stack # noqa: F401 @@ -428,6 +429,7 @@ tensor_method_func = [ # noqa 'shard_index', 'slice', 'split', + 'vsplit', 'chunk', 'tensordot', 'squeeze', diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 5e05a93e90..d3dcb60ec5 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1926,6 +1926,48 @@ def split(x, num_or_sections, axis=0, name=None): return outs +def vsplit(x, num_or_sections, name=None): + """ + Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.split`` with ``axis=0``. + + Args: + x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64. + num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections`` + indicates the number of equal sized sub-Tensors that the ``x`` will be divided into. + If ``num_or_sections`` is a list or tuple, the length of it indicates the number of + sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly. + The length of the list must not be larger than the ``x`` 's size of axis 0. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + Returns: + list[Tensor], The list of segmented Tensors. + + Example: + .. code-block:: python + + import paddle + + # x is a Tensor of shape [8, 6, 7] + x = paddle.rand([8, 6, 7]) + out0, out1, out2 = paddle.vsplit(x, num_or_sections=2) + print(out0.shape) # [4, 6, 7] + print(out1.shape) # [4, 6, 7] + out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4]) + print(out0.shape) # [1, 6, 7] + print(out1.shape) # [3, 6, 7] + print(out2.shape) # [4, 6, 7] + out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1]) + print(out0.shape) # [2, 6, 7] + print(out1.shape) # [3, 6, 7] + print(out2.shape) # [3, 6, 7] + """ + if x.ndim < 2: + raise ValueError( + "The input tensor's dimension must be greater than 1, but got {}". + format(x.ndim)) + return split(x, num_or_sections, axis=0, name=name) + + def squeeze(x, axis=None, name=None): """ Squeeze the dimension(s) of size 1 of input tensor x's shape. -- GitLab