未验证 提交 3304e345 编写于 作者: A Asthestarsfalll 提交者: GitHub

【PaddlePaddle Hackathon 3 No.20】为 Paddle 新增 vsplit API (#44853)

* add paddle vsplit api

* update unittest and fix a typo

* update

* add vsplit to __all__

* update unit test and description of x

* fix typo
上级 417ce3c1
...@@ -165,6 +165,7 @@ from .tensor.manipulation import shard_index # noqa: F401 ...@@ -165,6 +165,7 @@ from .tensor.manipulation import shard_index # noqa: F401
from .tensor.manipulation import slice # noqa: F401 from .tensor.manipulation import slice # noqa: F401
from .tensor.manipulation import crop # noqa: F401 from .tensor.manipulation import crop # noqa: F401
from .tensor.manipulation import split # noqa: F401 from .tensor.manipulation import split # noqa: F401
from .tensor.manipulation import vsplit # noqa: F401
from .tensor.manipulation import squeeze # noqa: F401 from .tensor.manipulation import squeeze # noqa: F401
from .tensor.manipulation import squeeze_ # noqa: F401 from .tensor.manipulation import squeeze_ # noqa: F401
from .tensor.manipulation import stack # noqa: F401 from .tensor.manipulation import stack # noqa: F401
...@@ -457,6 +458,7 @@ __all__ = [ # noqa ...@@ -457,6 +458,7 @@ __all__ = [ # noqa
'searchsorted', 'searchsorted',
'bucketize', 'bucketize',
'split', 'split',
'vsplit',
'logical_and', 'logical_and',
'full_like', 'full_like',
'less_than', 'less_than',
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
def func_ref(func, x, num_or_sections):
    """Call the NumPy reference ``func`` with Paddle-style split arguments.

    Paddle describes a split by section sizes, while NumPy expects the
    indices at which to cut. A sequence of sizes is therefore converted to
    its cumulative sums with the last entry (the total) dropped. An integer
    is forwarded unchanged. A ``-1`` placeholder size is not supported.
    """
    if not isinstance(num_or_sections, int):
        split_points = np.cumsum(num_or_sections)[:-1]
        return func(x, split_points)
    return func(x, num_or_sections)
# Each entry pairs a Paddle API with the NumPy function used as its reference.
# TODO: add other split APIs, such as dsplit and hsplit
test_list = [(paddle.vsplit, np.vsplit)]
class TestSplitsAPI(unittest.TestCase):
    """
    Compare every (paddle_api, numpy_api) pair in ``test_list`` in both
    static and dygraph mode. Subclasses override ``set_input`` to vary the
    shape, dtype, ``num_or_sections`` and place.
    """

    def setUp(self):
        self.rtol = 1e-5
        self.atol = 1e-8
        self.set_input()

    def set_input(self):
        self.shape = [4, 5, 2]
        self.num_or_sections = 2
        self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64')
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def test_static_api(self):
        paddle.enable_static()
        for func, func_type in test_list:
            with paddle.static.program_guard(paddle.static.Program()):
                x = paddle.fluid.data('X', self.x_np.shape, self.x_np.dtype)
                out = func(x, self.num_or_sections)
                exe = paddle.static.Executor(self.place)
                res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
            out_ref = func_ref(func_type, self.x_np, self.num_or_sections)
            # zip() stops at the shorter sequence, so check the number of
            # outputs explicitly before comparing them pairwise.
            self.assertEqual(len(out_ref), len(res))
            for n, p in zip(out_ref, res):
                np.testing.assert_allclose(n, p, rtol=self.rtol, atol=self.atol)

    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        try:
            x = paddle.to_tensor(self.x_np)
            for func, func_type in test_list:
                out = func(x, self.num_or_sections)
                out_ref = func_ref(func_type, self.x_np, self.num_or_sections)
                # Same reason as in the static test: guard against zip()
                # hiding a wrong number of outputs.
                self.assertEqual(len(out_ref), len(out))
                for n, p in zip(out_ref, out):
                    np.testing.assert_allclose(n,
                                               p.numpy(),
                                               rtol=self.rtol,
                                               atol=self.atol)
        finally:
            # Restore static mode even when an assertion above fails.
            paddle.enable_static()
class TestSplitsSections(TestSplitsAPI):
    """
    Test num_or_sections which is a list and data type is float64.
    """

    def set_input(self):
        self.shape = [6, 2, 4]
        self.num_or_sections = [2, 1, 3]
        samples = np.random.uniform(-1, 1, self.shape)
        self.x_np = samples.astype('float64')
        if core.is_compiled_with_cuda():
            self.place = paddle.CUDAPlace(0)
        else:
            self.place = paddle.CPUPlace()
class TestSplitsFloat32(TestSplitsAPI):
    """
    Test an integer num_or_sections with float32 input data.
    """

    def set_input(self):
        self.shape = [2, 3, 4]
        self.num_or_sections = 2
        samples = np.random.uniform(-1, 1, self.shape)
        self.x_np = samples.astype('float32')
        if core.is_compiled_with_cuda():
            self.place = paddle.CUDAPlace(0)
        else:
            self.place = paddle.CPUPlace()
class TestSplitsInt32(TestSplitsAPI):
    """
    Test data type int32.
    """

    def set_input(self):
        self.shape = [5, 1, 2]
        self.num_or_sections = 5
        # Draw real integers: casting uniform(-1, 1) samples to an integer
        # dtype truncates them all to 0, which makes the comparison trivial.
        self.x_np = np.random.randint(-10, 10, tuple(self.shape)).astype('int32')
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()
class TestSplitsInt64(TestSplitsAPI):
    """
    Test data type int64.
    """

    def set_input(self):
        self.shape = [4, 3, 2]
        self.num_or_sections = 2
        # Draw real integers: casting uniform(-1, 1) samples to an integer
        # dtype truncates them all to 0, which makes the comparison trivial.
        self.x_np = np.random.randint(-10, 10, tuple(self.shape)).astype('int64')
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()
class TestSplitsCPU(TestSplitsAPI):
    """
    Test on the CPU place with a tuple num_or_sections.
    """

    def set_input(self):
        self.shape = [8, 2, 3, 5]
        self.num_or_sections = (2, 3, 3)
        samples = np.random.uniform(-1, 1, self.shape)
        self.x_np = samples.astype('float64')
        self.place = paddle.CPUPlace()
class TestSplitsError(unittest.TestCase):
    """
    Test that a ValueError is raised when the input has fewer than 2
    dimensions.
    """

    def setUp(self):
        self.num_or_sections = 1
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def test_static_error(self):
        paddle.enable_static()
        for func, _ in test_list:
            with paddle.static.program_guard(paddle.static.Program()):
                x = paddle.fluid.data('X', [5], 'float32')
                self.assertRaises(ValueError, func, x, self.num_or_sections)

    def test_dygraph_error(self):
        paddle.disable_static(self.place)
        try:
            for func, _ in test_list:
                x_np = np.random.randn(2)
                x = paddle.to_tensor(x_np, dtype='float64')
                self.assertRaises(ValueError, func, x, self.num_or_sections)
        finally:
            # Leave the process in static mode, matching the other dygraph
            # test in this file.
            paddle.enable_static()
if __name__ == '__main__':
unittest.main()
...@@ -106,6 +106,7 @@ from .manipulation import scatter_nd # noqa: F401 ...@@ -106,6 +106,7 @@ from .manipulation import scatter_nd # noqa: F401
from .manipulation import shard_index # noqa: F401 from .manipulation import shard_index # noqa: F401
from .manipulation import slice # noqa: F401 from .manipulation import slice # noqa: F401
from .manipulation import split # noqa: F401 from .manipulation import split # noqa: F401
from .manipulation import vsplit # noqa: F401
from .manipulation import squeeze # noqa: F401 from .manipulation import squeeze # noqa: F401
from .manipulation import squeeze_ # noqa: F401 from .manipulation import squeeze_ # noqa: F401
from .manipulation import stack # noqa: F401 from .manipulation import stack # noqa: F401
...@@ -428,6 +429,7 @@ tensor_method_func = [ # noqa ...@@ -428,6 +429,7 @@ tensor_method_func = [ # noqa
'shard_index', 'shard_index',
'slice', 'slice',
'split', 'split',
'vsplit',
'chunk', 'chunk',
'tensordot', 'tensordot',
'squeeze', 'squeeze',
......
...@@ -1926,6 +1926,48 @@ def split(x, num_or_sections, axis=0, name=None): ...@@ -1926,6 +1926,48 @@ def split(x, num_or_sections, axis=0, name=None):
return outs return outs
def vsplit(x, num_or_sections, name=None):
    """
    Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.split`` with ``axis=0``.

    Args:
        x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64.
        num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
            indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
            If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
            sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly.
            The length of the list must not be larger than the ``x`` 's size of axis 0.
        name (str, optional): The default value is None. Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name` .

    Returns:
        list[Tensor], The list of segmented Tensors.

    Raises:
        ValueError: If ``x`` has fewer than 2 dimensions.

    Example:
        .. code-block:: python

            import paddle

            # x is a Tensor of shape [8, 6, 7]
            x = paddle.rand([8, 6, 7])
            out0, out1 = paddle.vsplit(x, num_or_sections=2)
            print(out0.shape)  # [4, 6, 7]
            print(out1.shape)  # [4, 6, 7]
            out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4])
            print(out0.shape)  # [1, 6, 7]
            print(out1.shape)  # [3, 6, 7]
            print(out2.shape)  # [4, 6, 7]
            out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1])
            print(out0.shape)  # [2, 6, 7]
            print(out1.shape)  # [3, 6, 7]
            print(out2.shape)  # [3, 6, 7]
    """
    # Reject inputs with fewer than 2 dimensions before delegating to split.
    if x.ndim < 2:
        raise ValueError(
            "The input tensor's dimension must be greater than 1, but got {}".
            format(x.ndim))
    return split(x, num_or_sections, axis=0, name=name)
def squeeze(x, axis=None, name=None): def squeeze(x, axis=None, name=None):
""" """
Squeeze the dimension(s) of size 1 of input tensor x's shape. Squeeze the dimension(s) of size 1 of input tensor x's shape.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册