Unverified commit 9cd0a5bc, authored by Ainavo, committed by GitHub

[Hackthon 4th No.7] add paddle.unflatten API (#53055)

Parent: bafc3469
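In short, this commit adds the `paddle.unflatten` API and the `paddle.nn.Unflatten` layer, which split a single dimension of a tensor into several dimensions. A minimal dygraph sketch of the new API (illustrative only; the call signature is taken from the diff below):

import paddle

x = paddle.arange(24).reshape([2, 12])         # shape: [2, 12]
y = paddle.unflatten(x, axis=1, shape=[3, 4])  # split axis 1 into (3, 4)
print(y.shape)                                 # [2, 3, 4]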
@@ -198,6 +198,7 @@ from .tensor.manipulation import moveaxis  # noqa: F401
from .tensor.manipulation import repeat_interleave  # noqa: F401
from .tensor.manipulation import index_add  # noqa: F401
from .tensor.manipulation import index_add_  # noqa: F401
from .tensor.manipulation import unflatten  # noqa: F401
from .tensor.math import abs  # noqa: F401
from .tensor.math import acos  # noqa: F401
from .tensor.math import asin  # noqa: F401
@@ -691,5 +692,6 @@ __all__ = [  # noqa
    'cumulative_trapezoid',
    'polar',
    'vander',
    'unflatten',
    'nextafter',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
def numpy_unflatten(x, axis, shape):
    if isinstance(shape, (list, tuple)):
        if len(shape) == 0:
            raise ValueError("The input for shape cannot be empty.")
        if np.min(shape) < -1:
            raise ValueError(f"invalid shape dimension {np.min(shape)}.")
        if shape.count(-1) > 1:
            raise ValueError("The shape can contain only one -1.")
        elif shape.count(-1) == 1:
            # Resolve the single -1 entry from the size of the unflattened axis.
            shape = list(shape)
            shape[shape.index(-1)] = x.shape[axis] // abs(np.prod(shape))
        else:
            sizes = np.prod(shape)
            if sizes != x.shape[axis]:
                raise ValueError(
                    "The product of the elements in shape {} is not equal to {}.".format(
                        shape, x.shape[axis]
                    )
                )
    else:
        raise TypeError(
            "The data type of shape should be one of ['List', 'Tuple'], but got {}".format(
                type(shape)
            )
        )
length = len(x.shape)
if axis < 0:
axis = axis + length
new_shape = x.shape[:axis] + tuple(shape) + x.shape[axis + 1 :]
x = x.reshape(new_shape)
return x
class TestUnflattenAPI(unittest.TestCase):
def set_args(self):
self.x = np.random.rand(4, 6, 16)
self.axis = 0
self.shape = (2, 2)
self.shape_is_tensor = False
def get_output(self):
self.output = self.ref_api(self.x, self.axis, self.shape)
def set_api(self):
self.ref_api = numpy_unflatten
self.paddle_api = paddle.unflatten
def setUp(self):
self.set_api()
self.set_args()
self.get_output()
self.places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
def func_dygraph(self):
for place in self.places:
paddle.disable_static()
x = paddle.to_tensor(self.x, place=place)
if self.shape_is_tensor:
shape = paddle.to_tensor(self.shape)
else:
shape = self.shape
out = self.paddle_api(x=x, axis=self.axis, shape=shape)
np.testing.assert_allclose(out, self.output, rtol=1e-05)
def test_dygraph(self):
self.setUp()
self.func_dygraph()
def test_static(self):
paddle.enable_static()
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.x.dtype
)
if self.shape_is_tensor:
shape = np.array(self.shape)
shape = paddle.static.data(
name='shape', shape=shape.shape, dtype=shape.dtype
)
else:
shape = self.shape
exe = paddle.static.Executor(place)
out = self.paddle_api(x=x, axis=self.axis, shape=shape)
fetches = exe.run(
paddle.static.default_main_program(),
feed={
"x": self.x,
"axis": self.axis,
"shape": self.shape,
},
fetch_list=[out],
)
np.testing.assert_allclose(fetches[0], self.output, rtol=1e-05)
# check the data type of the input x
class TestUnflattenInputInt16(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('int16')
self.axis = 0
self.shape = (2, 2)
self.shape_is_tensor = False
class TestUnflattenInputInt32(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('int32')
self.axis = 0
self.shape = (2, 2)
self.shape_is_tensor = False
class TestUnflattenInputInt64(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('int64')
self.axis = 0
self.shape = (2, 2)
self.shape_is_tensor = False
class TestUnflattenInputFloat16(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float16')
self.axis = 0
self.shape = (2, 2)
self.shape_is_tensor = False
class TestUnflattenInputFloat32(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = 0
self.shape = (2, 2)
self.shape_is_tensor = False
class TestUnflattenInputFloat64(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float64')
self.axis = 0
self.shape = (2, 2)
self.shape_is_tensor = False
class TestUnflattenInputbool(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('bool')
self.axis = 0
self.shape = (2, 2)
self.shape_is_tensor = False
# check the data type and edge cases of shape
class TestUnflattenShapeList1(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = 0
self.shape = [2, 2]
self.shape_is_tensor = False
class TestUnflattenShapeList2(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = -1
self.shape = [-1, 2]
self.shape_is_tensor = False
class TestUnflattenShapeList3(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = 0
self.shape = [-1]
self.shape_is_tensor = False
class TestUnflattenTupleShape1(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = 0
self.shape = (2, 2)
self.shape_is_tensor = False
class TestUnflattenTupleShape2(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = 0
self.shape = (-1, 2)
self.shape_is_tensor = False
class TestUnflattenTupleShape3(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = 0
self.shape = (-1,)
self.shape_is_tensor = False
class TestUnflattenShapeTensorInt32(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = 0
self.shape = tuple(np.array((-1, 4)).astype('int32'))
self.shape_is_tensor = True
# check the value of axis
class TestUnflattenAxis1(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = 1
self.shape = (2, 3)
self.shape_is_tensor = False
class TestUnflattenAxis2(TestUnflattenAPI):
def set_args(self):
self.x = np.random.rand(4, 6, 16).astype('float32')
self.axis = -1
self.shape = (2, 8)
self.shape_is_tensor = False
class TestLayer(unittest.TestCase):
def set_args(self):
self.x = np.random.randn(3, 4, 4, 5).astype('float32')
self.axis = 1
self.shape = [2, 2]
def setUp(self):
self.set_args()
self.places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
def test_layer(self):
paddle.enable_static()
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(
name="x", dtype=self.x.dtype, shape=self.x.shape
)
exe = paddle.static.Executor(place)
unflatten = paddle.nn.Unflatten(self.axis, self.shape)
out = unflatten(x)
static_ret = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x, "axis": self.axis, "shape": self.shape},
fetch_list=[out],
)[0]
for place in self.places:
paddle.disable_static()
x = paddle.to_tensor(self.x, dtype='float32', place=place)
unflatten = paddle.nn.Unflatten(self.axis, self.shape)
            dy_ret_value = unflatten(x)
np.testing.assert_array_equal(static_ret, dy_ret_value)
class TestLayerName(unittest.TestCase):
def test_name(self):
self.x = np.random.randn(3, 4, 4, 5).astype('float32')
self.axis = 1
self.shape = [2, 2]
self.name = 'unflatten'
        unflatten = paddle.nn.Unflatten(self.axis, self.shape, self.name)
        _name = unflatten.extra_repr()
        # The configured name should appear in the layer's extra_repr() output.
        self.assertIn(self.name, _name)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
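For quick reference, the NumPy helper defined at the top of this test file can be exercised on its own; a minimal sketch using the default arguments from TestUnflattenAPI (illustrative only, not part of the test suite):

import numpy as np

x = np.random.rand(4, 6, 16)
out = numpy_unflatten(x, axis=0, shape=(2, 2))  # requires the helper defined above
print(out.shape)  # (2, 2, 6, 16): axis 0 of size 4 is split into (2, 2)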
@@ -70,6 +70,7 @@ from .layer.common import Dropout3D  # noqa: F401
from .layer.common import AlphaDropout  # noqa: F401
from .layer.common import Unfold  # noqa: F401
from .layer.common import Fold  # noqa: F401
from .layer.common import Unflatten  # noqa: F401
from .layer.pooling import AvgPool1D  # noqa: F401
from .layer.pooling import AvgPool2D  # noqa: F401
@@ -338,4 +339,5 @@ __all__ = [  # noqa
    'TripletMarginLoss',
    'SoftMarginLoss',
    'GaussianNLLLoss',
    'Unflatten',
]
@@ -45,7 +45,9 @@ from .common import Dropout3D  # noqa: F401
from .common import AlphaDropout  # noqa: F401
from .common import UpsamplingBilinear2D  # noqa: F401
from .common import UpsamplingNearest2D  # noqa: F401
from .common import Fold  # noqa: F401
from .common import Unflatten  # noqa: F401
from .pooling import AvgPool1D  # noqa: F401
from .pooling import AvgPool2D  # noqa: F401
from .pooling import AvgPool3D  # noqa: F401
...
@@ -1741,3 +1741,53 @@ class Flatten(Layer):
            input, start_axis=self.start_axis, stop_axis=self.stop_axis
        )
        return out
class Unflatten(Layer):
"""
This interface is used to construct a callable object of the ``Unflatten`` class.
For more details, refer to code examples.
    It expands a certain dimension of the input ``x`` Tensor into a desired shape.
Parameters:
axis (int): :attr:`axis` to be unflattened, specified as an index into `x.shape`.
shape (list|tuple|Tensor): Unflatten :attr:`shape` on the specified :attr:`axis`. At most one dimension of the target :attr:`shape` can be -1.
            If the input :attr:`shape` does not contain -1, the product of all elements in ``shape`` should be equal to ``x.shape[axis]``.
            The data type is `int`. If :attr:`shape` is a list or tuple, its elements should be integers or Tensors with shape [].
            If :attr:`shape` is a Tensor, it should be a 1-D Tensor.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
None
Examples:
.. code-block:: python
import paddle
x = paddle.randn(shape=[4, 6, 8])
shape = [2, 3]
axis = 1
unflatten = paddle.nn.Unflatten(axis, shape)
res = unflatten(x)
print(res.shape)
# [4, 2, 3, 8]
"""
def __init__(self, axis, shape, name=None):
super().__init__()
self.axis = axis
self.shape = shape
self.name = name
def forward(self, input):
out = paddle.unflatten(
input, axis=self.axis, shape=self.shape, name=self.name
)
return out
def extra_repr(self):
name_str = f', name={self.name}' if self.name else ''
return f'axis={self.axis}, shape={self.shape}{name_str}'
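The layer form is convenient inside model definitions. A minimal sketch (illustrative only) of paddle.nn.Unflatten with a -1 entry, which is inferred from the size of the unflattened axis:

import paddle

unflatten = paddle.nn.Unflatten(axis=-1, shape=[-1, 4])
x = paddle.randn([3, 16])
y = unflatten(x)
print(y.shape)  # [3, 4, 4]: the -1 resolves to 16 // 4 = 4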
@@ -135,6 +135,7 @@ from .manipulation import moveaxis  # noqa: F401
from .manipulation import repeat_interleave  # noqa: F401
from .manipulation import index_add  # noqa: F401
from .manipulation import index_add_  # noqa: F401
from .manipulation import unflatten  # noqa: F401
from .math import abs  # noqa: F401
from .math import acos  # noqa: F401
from .math import asin  # noqa: F401
@@ -544,6 +545,7 @@ tensor_method_func = [  # noqa
    'sigmoid_',
    'vander',
    'nextafter',
    'unflatten',
]
# this list used in math_op_patch.py for magic_method bind
...
@@ -4795,6 +4795,75 @@ def index_add_(x, index, axis, value, name=None):
    return _C_ops.index_add_(x, index, value, axis)
def unflatten(x, axis, shape, name=None):
"""
Expand a certain dimension of the input x Tensor into a desired shape.
Args:
x (Tensor) : An N-D Tensor. The data type is float16, float32, float64, int16, int32, int64, bool, uint16.
axis (int): :attr:`axis` to be unflattened, specified as an index into `x.shape`.
shape (list|tuple|Tensor): Unflatten :attr:`shape` on the specified :attr:`axis`. At most one dimension of the target :attr:`shape` can be -1.
            If the input :attr:`shape` does not contain -1, the product of all elements in ``shape`` should be equal to ``x.shape[axis]``.
            The data type is `int`. If :attr:`shape` is a list or tuple, its elements should be integers or Tensors with shape [].
            If :attr:`shape` is a Tensor, it should be a 1-D Tensor.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
Tensor, return the unflatten tensor of :attr:`x`.
Examples:
.. code-block:: python
import paddle
x = paddle.randn(shape=[4, 6, 8])
shape = [2, 3]
axis = 1
res = paddle.unflatten(x, axis, shape)
print(res.shape)
# [4, 2, 3, 8]
x = paddle.randn(shape=[4, 6, 8])
shape = (-1, 2)
axis = -1
res = paddle.unflatten(x, axis, shape)
print(res.shape)
# [4, 6, 4, 2]
x = paddle.randn(shape=[4, 6, 8])
shape = paddle.to_tensor([2, 2])
axis = 0
res = paddle.unflatten(x, axis, shape)
print(res.shape)
# [2, 2, 6, 8]
"""
# determine whether the input axis is valid.
axis = non_negative_axis(x, axis)
if isinstance(shape, (list, tuple)):
new_shape = (
list(x.shape[:axis]) + list(shape) + list(x.shape[axis + 1 :])
)
elif isinstance(shape, Variable):
# The data type returned by `paddle.shape` is only 'int32'.
new_shape = paddle.concat(
[
paddle.shape(x)[:axis],
paddle.cast(shape, 'int32'),
paddle.shape(x)[axis + 1 :],
]
)
else:
raise TypeError(
"The data type of x should be one of ['List', 'Tuple', 'Tensor'], but got {}".format(
type(shape)
)
)
x = x.reshape(new_shape)
return x
# TODO(dev): We need avoid implementing it by this way.
__METHODS = {
    'fill_': fill_,
...
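For a list or tuple shape, the implementation above reduces to a single reshape. A minimal equivalence sketch (illustrative only, dygraph mode assumed):

import paddle

x = paddle.randn([4, 6, 8])
a = paddle.unflatten(x, axis=1, shape=[2, 3])
b = paddle.reshape(x, [4, 2, 3, 8])   # the same new shape spelled out by hand
print(a.shape == b.shape)             # True
print(bool(paddle.all(a == b)))       # True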