未验证 提交 5258d53d 编写于 作者: L Leo Chen 提交者: GitHub

refine unsqueeze, test=develop (#25470)

* refine unsqueeze, test=develop

* update unsqueeze, test=develop

* refine unsqueeze, test=develop

* refine unsqueeze, test=develop

* update

* remove None, test=develop

* follow comments

* support bool

* update doc

* follow comments

* merge develop
上级 0dc485e6
......@@ -304,6 +304,7 @@ REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp,
REGISTER_OP_CPU_KERNEL(
squeeze, ops::SqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -311,12 +312,14 @@ REGISTER_OP_CPU_KERNEL(
squeeze_grad,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, bool>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -324,6 +327,7 @@ REGISTER_OP_CPU_KERNEL(
squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
squeeze, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -29,6 +30,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -36,6 +38,7 @@ REGISTER_OP_CUDA_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, bool>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -44,6 +47,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -6200,7 +6200,7 @@ def squeeze(input, axes, name=None):
Out.shape = [1,3,5]
Args:
input (Variable): The input Tensor. Support data type: float16, float32, float64, int8, int32, int64.
input (Variable): The input Tensor. Supported data type: float32, float64, bool, int8, int32, int64.
axes (list): One integer or List of integers, indicating the dimensions to be squeezed.
Axes range is :math:`[-rank(input), rank(input))`.
If axes is negative, :math:`axes=axes+rank(input)`.
......@@ -6226,8 +6226,9 @@ def squeeze(input, axes, name=None):
helper = LayerHelper("squeeze", **locals())
check_variable_and_dtype(
input, 'input',
['float16', 'float32', 'float64', 'int8', 'int32', 'int64'], 'squeeze')
check_type(axes, 'axes', (list, tuple), 'squeeze')
['float16', 'float32', 'float64', 'bool', 'int8', 'int32', 'int64'],
'squeeze')
check_type(axes, 'axis/axes', (list, tuple), 'squeeze')
out = helper.create_variable_for_type_inference(dtype=input.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
......@@ -6254,12 +6255,12 @@ def unsqueeze(input, axes, name=None):
then Unsqueezed tensor with axes=[0, 4] has shape [1, 3, 4, 5, 1].
Args:
input (Variable): The input Tensor to be unsqueezed. It is a N-D Tensor of data types float32, float64, int32.
input (Variable): The input Tensor to be unsqueezed. Supported data type: float32, float64, bool, int8, int32, int64.
axes (int|list|tuple|Variable): Indicates the dimensions to be inserted. The data type is ``int32`` . If ``axes`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If ``axes`` is an Variable, it should be an 1-D Tensor .
name (str|None): Name for this layer.
Returns:
Variable: Output unsqueezed Tensor, with data type being float32, float64, int32, int64.
Variable: Unsqueezed Tensor, with the same data type as input.
Examples:
.. code-block:: python
......@@ -6269,10 +6270,15 @@ def unsqueeze(input, axes, name=None):
y = fluid.layers.unsqueeze(input=x, axes=[1])
"""
if not isinstance(axes, (int, list, tuple, Variable)):
raise TypeError(
"The type of 'axes' in unsqueeze must be int, list, tuple or Variable, but "
"received %s." % (type(axes)))
if in_dygraph_mode():
out, _ = core.ops.unsqueeze2(input, 'axes', axes)
return out
check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze')
check_variable_and_dtype(
input, 'input',
['float16', 'float32', 'float64', 'bool', 'int8', 'int32', 'int64'],
'unsqueeze')
helper = LayerHelper("unsqueeze2", **locals())
inputs = {"X": input}
attrs = {}
......@@ -9966,7 +9972,7 @@ def stack(x, axis=0, name=None):
must be the same. Supposing input is N dims
Tensors :math:`[d_0, d_1, ..., d_{n-1}]`, the output is N+1 dims
Tensor :math:`[d_0, d_1, d_{axis-1}, len(x), d_{axis}, ..., d_{n-1}]`.
Support data types: float32, float64, int32, int64.
Supported data types: float32, float64, int32, int64.
axis (int, optional): The axis along which all inputs are stacked. ``axis`` range is :math:`[-(R+1), R+1)`.
R is the first tensor of inputs. If ``axis`` < 0, :math:`axis=axis+rank(x[0])+1`.
The default value of axis is 0.
......
......@@ -81,7 +81,7 @@ class API_TestUnsqueeze(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.layers.data('data1', shape=[-1, 10], dtype='float64')
result_squeeze = paddle.unsqueeze(data1, axes=[1])
result_squeeze = paddle.unsqueeze(data1, axis=[1])
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input1 = np.random.random([5, 1, 10]).astype('float64')
......@@ -98,7 +98,7 @@ class TestUnsqueezeOpError(unittest.TestCase):
def test_axes_type():
x6 = fluid.layers.data(
shape=[-1, 10], dtype='float16', name='x3')
paddle.unsqueeze(x6, axes=3.2)
paddle.unsqueeze(x6, axis=3.2)
self.assertRaises(TypeError, test_axes_type)
......@@ -108,7 +108,7 @@ class API_TestUnsqueeze2(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.data('data1', shape=[-1, 10], dtype='float64')
data2 = fluid.data('data2', shape=[1], dtype='int32')
result_squeeze = paddle.unsqueeze(data1, axes=data2)
result_squeeze = paddle.unsqueeze(data1, axis=data2)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input1 = np.random.random([5, 1, 10]).astype('float64')
......@@ -125,7 +125,7 @@ class API_TestUnsqueeze3(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.data('data1', shape=[-1, 10], dtype='float64')
data2 = fluid.data('data2', shape=[1], dtype='int32')
result_squeeze = paddle.unsqueeze(data1, axes=[data2, 3])
result_squeeze = paddle.unsqueeze(data1, axis=[data2, 3])
place = fluid.CPUPlace()
exe = fluid.Executor(place)
input1 = np.random.random([5, 1, 10, 1]).astype('float64')
......@@ -143,7 +143,7 @@ class API_TestDyUnsqueeze(unittest.TestCase):
input_1 = np.random.random([5, 1, 10]).astype("int32")
input1 = np.squeeze(input_1, axis=1)
input = fluid.dygraph.to_variable(input_1)
output = paddle.unsqueeze(input, axes=[1])
output = paddle.unsqueeze(input, axis=[1])
out_np = output.numpy()
self.assertTrue(np.allclose(input1, out_np))
......@@ -154,7 +154,7 @@ class API_TestDyUnsqueeze2(unittest.TestCase):
input_1 = np.random.random([5, 1, 10]).astype("int32")
input1 = np.squeeze(input_1, axis=1)
input = fluid.dygraph.to_variable(input_1)
output = paddle.unsqueeze(input, axes=1)
output = paddle.unsqueeze(input, axis=1)
out_np = output.numpy()
self.assertTrue(np.allclose(input1, out_np))
......
......@@ -42,11 +42,32 @@ from ..fluid import layers
import paddle
__all__ = [
'cast', 'concat', 'expand', 'expand_as', 'flatten', 'gather', 'gather_nd',
'reshape', 'reverse', 'scatter', 'scatter_nd_add', 'scatter_nd',
'shard_index', 'slice', 'split', 'squeeze', 'stack', 'strided_slice',
'transpose', 'unique', 'unique_with_counts', 'unsqueeze', 'unstack', 'flip',
'unbind', 'roll'
'cast',
'concat',
'expand',
'expand_as',
'flatten',
'gather',
'gather_nd',
'reshape',
'reverse',
'scatter',
'scatter_nd_add',
'scatter_nd',
'shard_index',
'slice',
'split',
'squeeze',
'stack',
'strided_slice',
'transpose',
'unique',
'unique_with_counts',
'unsqueeze',
'unstack',
'flip',
'unbind',
'roll',
]
......@@ -417,7 +438,7 @@ def stack(x, axis=0, name=None):
Args:
x (Tensor|list[Tensor]): Input ``x`` can be a single tensor, or a ``list`` of tensors.
If ``x`` is a ``list``, the Tensors in ``x``
must be of the same shape and dtype. Support data types: float32, float64, int32, int64.
must be of the same shape and dtype. Supported data types: float32, float64, int32, int64.
axis (int, optional): The axis along which all inputs are stacked. ``axis`` range is ``[-(R+1), R+1)``,
where ``R`` is the number of dimensions of the first input tensor ``x[0]``.
If ``axis < 0``, ``axis = axis+R+1``. The default value of axis is 0.
......@@ -559,18 +580,19 @@ def squeeze(x, axis=None, name=None):
out.shape = [1, 3, 5]
Args:
input (Tensor): The input Tensor. Support data type: float32, float64, int8, int32, int64.
x (Tensor): The input Tensor. Supported data type: float32, float64, bool, int8, int32, int64.
axis (int|list|tuple, optional): An integer or list of integers, indicating the dimensions to be squeezed. Default is None.
The range of axis is :math:`[-ndim(input), ndim(input))`.
If axis is negative, :math:`axis = axis + ndim(input)`.
If axis is None, all the dimensions of input of size 1 will be removed.
The range of axis is :math:`[-ndim(x), ndim(x))`.
If axis is negative, :math:`axis = axis + ndim(x)`.
If axis is None, all the dimensions of x of size 1 will be removed.
name (str, optional): Please refer to :ref:`api_guide_Name`, Default None.
Returns:
Tensor: Output squeezed Tensor. Data type is same as input Tensor.
Tensor: Squeezed Tensor with the same data type as input Tensor.
Examples:
.. code-block:: python
import paddle
paddle.enable_imperative()
......@@ -590,87 +612,50 @@ def squeeze(x, axis=None, name=None):
return layers.squeeze(x, axis, name)
def unsqueeze(input, axes, out=None, name=None):
def unsqueeze(x, axis, name=None):
"""
:alias_main: paddle.unsqueeze
:alias: paddle.unsqueeze,paddle.tensor.unsqueeze,paddle.tensor.manipulation.unsqueeze
Insert single-dimensional entries to the shape of a Tensor. Takes one
required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
For example:
.. code-block:: text
:alias: paddle.unsqueeze, paddle.tensor.unsqueeze, paddle.tensor.manipulation.unsqueeze
Given a tensor such that tensor with shape [3, 4, 5],
then Unsqueezed tensor with axes=[0, 4] has shape [1, 3, 4, 5, 1].
Insert single-dimensional entries to the shape of input Tensor ``x``. Takes one
required argument axis, a dimension or list of dimensions that will be inserted.
Dimension indices in axis are as seen in the output tensor.
Args:
input (Variable): The input Tensor to be unsqueezed. It is a N-D Tensor of data types float32, float64, int32.
axes (int|list|tuple|Variable): Indicates the dimensions to be inserted. The data type is ``int32`` . If ``axes`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If ``axes`` is an Variable, it should be an 1-D Tensor .
name (str|None): Name for this layer.
x (Tensor): The input Tensor to be unsqueezed. Supported data type: float32, float64, bool, int8, int32, int64.
axis (int|list|tuple|Tensor): Indicates the dimensions to be inserted. The data type is ``int32`` .
If ``axis`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``axis`` is a Tensor, it should be an 1-D Tensor .
If ``axis`` is negative, ``axis = axis + ndim(x) + 1``.
name (str|None): Name for this layer. Please refer to :ref:`api_guide_Name`, Default None.
Returns:
Variable: Output unsqueezed Tensor, with data type being float32, float64, int32, int64.
Tensor: Unsqueezed Tensor with the same data type as input Tensor.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
with fluid.dygraph.guard():
input_1 = np.random.random([5, 10]).astype("int32")
# input is a variable which shape is [5, 10]
input = fluid.dygraph.to_variable(input_1)
paddle.enable_imperative()
x = paddle.rand([5, 10])
print(x.shape) # [5, 10]
out1 = paddle.unsqueeze(x, axis=0)
print(out1.shape) # [1, 5, 10]
out2 = paddle.unsqueeze(x, axis=[0, 2])
print(out2.shape) # [1, 5, 1, 10]
output = paddle.unsqueeze(input, axes=[1])
# output.shape [5, 1, 10]
axis = paddle.fluid.dygraph.to_variable([0, 1, 2])
out3 = paddle.unsqueeze(x, axis=axis)
print(out3.shape) # [1, 1, 1, 5, 10]
"""
if not isinstance(axes, (int, list, tuple, Variable)):
raise TypeError(
"The type of 'axes' in unsqueeze must be int, list, tuple or Variable, but "
"received %s." % (type(axes)))
helper = LayerHelper("unsqueeze2", **locals())
inputs = {"X": input}
attrs = {}
def _to_Variable_list(one_list):
Variable_list = []
for ele in one_list:
if isinstance(ele, Variable):
ele.stop_gradient = True
Variable_list.append(ele)
else:
assert (isinstance(ele, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', ele, force_cpu=True, out=temp_out)
Variable_list.append(temp_out)
return Variable_list
if isinstance(axes, int):
axes = [axes]
if isinstance(axes, Variable):
axes.stop_gradient = True
inputs["AxesTensor"] = axes
elif isinstance(axes, (list, tuple)):
contain_var = not all(not isinstance(ele, Variable) for ele in axes)
if contain_var:
inputs["AxesTensorList"] = _to_Variable_list(axes)
else:
attrs["axes"] = axes
out = helper.create_variable_for_type_inference(dtype=input.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type="unsqueeze2",
inputs=inputs,
attrs=attrs,
outputs={"Out": out,
"XShape": x_shape})
if isinstance(axis, int):
axis = [axis]
return out
return layers.unsqueeze(x, axis, name)
def gather(input, index, overwrite=True):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册