未验证 提交 66c084b7 编写于 作者: Y Yuan Shuai 提交者: GitHub

Refine error message of transpose. fix en doc about unsqueeze, unstack,...

Refine error message of transpose. fix en doc about unsqueeze, unstack, transpose, multiplex. test=develop, test=document_preview, test=document_fix (#20594)
上级 b4cd5a11
......@@ -177,14 +177,14 @@ paddle.fluid.layers.matmul (ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y'
paddle.fluid.layers.topk (ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e645c2f6c24cf076260d380df929e243'))
paddle.fluid.layers.warpctc (ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'input_length', 'label_length'], varargs=None, keywords=None, defaults=(0, False, None, None)), ('document', '79aaea078ddea57a82ed7906d71dedc7'))
paddle.fluid.layers.sequence_reshape (ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None), ('document', 'eeb1591cfc854c6ffdac77b376313c44'))
paddle.fluid.layers.transpose (ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '8e72db173d4c082e27cb11f31d8c9bfa'))
paddle.fluid.layers.transpose (ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ae5c346abc8a7d85fc3ebe2e1ba0f428'))
paddle.fluid.layers.im2sequence (ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)), ('document', 'fe352915a543cec434f74e9b32ac49da'))
paddle.fluid.layers.nce (ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False)), ('document', '38297567127888e01542857839058d52'))
paddle.fluid.layers.sampled_softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'num_samples', 'num_true', 'remove_accidental_hits', 'use_customized_samples', 'customized_samples', 'customized_probabilities', 'seed'], varargs=None, keywords=None, defaults=(1, True, False, None, None, 0)), ('document', 'd4435a63d34203339831ee6a86ef9242'))
paddle.fluid.layers.hsigmoid (ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)), ('document', '247de339879885526e7f4d271967088f'))
paddle.fluid.layers.beam_search (ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False)), ('document', '2b505ddaa309fd7b9be5445e41ca76d5'))
paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'a6477957b44907787b3c74157400b80c'))
paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '2c4d1ae83da6ed35e3b36ba1b3b51d23'))
paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '8dba76e9b1521b4ab62e38608b6aa3f6'))
paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', '678de6d6d0c93da74189990b039daae8'))
paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '87dd4b818f102bc1a780e1804c28bd38'))
paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '7b3d14d6707d878923847ec617d7d521'))
......@@ -194,7 +194,7 @@ paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth', 'allow_out_of_range
paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', 'd016c137beb9a4528b7378b437d00151'))
paddle.fluid.layers.reshape (ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', 'd7a6d59e464a7ef1184eb6caefeb49f1'))
paddle.fluid.layers.squeeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '61d8be0c5af7b9313b0bdb8697c7d4de'))
paddle.fluid.layers.unsqueeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b9bd3129d36a70e7c4385df51ff71c62'))
paddle.fluid.layers.unsqueeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd1f4c0d1284315066210ff0b33adf747'))
paddle.fluid.layers.lod_reset (ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'f1f04ae9bdcf8f3adc0658db6904aa0e'))
paddle.fluid.layers.lod_append (ArgSpec(args=['x', 'level'], varargs=None, keywords=None, defaults=None), ('document', '37663c7c179e920838a250ea0e28d909'))
paddle.fluid.layers.lrn (ArgSpec(args=['input', 'n', 'k', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(5, 1.0, 0.0001, 0.75, None)), ('document', 'fa565b65fb98d3ca82361c79f41b06b2'))
......@@ -238,7 +238,7 @@ paddle.fluid.layers.flatten (ArgSpec(args=['x', 'axis', 'name'], varargs=None, k
paddle.fluid.layers.sequence_mask (ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None)), ('document', '6c3f916921b24edaad220f1fcbf039de'))
paddle.fluid.layers.stack (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', '666d995b36e9f287d77f09189370fb3a'))
paddle.fluid.layers.pad2d (ArgSpec(args=['input', 'paddings', 'mode', 'pad_value', 'data_format', 'name'], varargs=None, keywords=None, defaults=([0, 0, 0, 0], 'constant', 0.0, 'NCHW', None)), ('document', '4e277f064c1765f77f946da194626ca1'))
paddle.fluid.layers.unstack (ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b0c4ca08d4eb295189e1b107c920d093'))
paddle.fluid.layers.unstack (ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b9d343d8961dfa30d65b1e59d86f53cd'))
paddle.fluid.layers.sequence_enumerate (ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b870fed41abd2aecf929ece65f555fa1'))
paddle.fluid.layers.unique (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', 'cab0b06e5683875f12f0efc62fa230a9'))
paddle.fluid.layers.unique_with_counts (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', '4496682f302007019e458a2f30d8a7c3'))
......
......@@ -39,17 +39,23 @@ class TransposeOp : public framework::OperatorWithKernel {
size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(x_rank, axis_size,
"The input tensor's rank(%d) "
"should be equal to the axis's size(%d)",
"ShapeError: The input tensor's dimension "
"should be equal to the axis's size. "
"But received input tensor's dimension is %d, "
"axis's size is %d",
x_rank, axis_size);
std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
PADDLE_ENFORCE(
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1,
"Each element of Attribute axis should be a unique value "
"range from 0 to (dims - 1), "
"where the dims is the axis's size");
"ValueError: Each element of Attribute axis should "
"be a unique value range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"unique value means this axis value can appear only once. "
"But received axis[%d] is %d, axis_size is %d, "
"count[axis[%d]] is %d",
i, axis[i], axis_size, i, count[axis[i]]);
}
framework::DDim out_dims(x_dims);
......
......@@ -7720,40 +7720,82 @@ def hsigmoid(input,
def transpose(x, perm, name=None):
"""
Permute the dimensions of `input` according to `perm`.
Permute the data dimensions of `input` according to `perm`.
The `i`-th dimension of the returned tensor will correspond to the
perm[i]-th dimension of `input`.
Args:
x (Variable): The input Tensor.
perm (list): A permutation of the dimensions of `input`.
x (Variable): The input Tensor. It is a N-D Tensor of data types float32, float64, int32.
perm (list): Permute the input accoring to the data of perm.
name (str): The name of this layer. It is optional.
Returns:
Variable: A transposed Tensor.
Variable: A transposed n-D Tensor, with data type being float32, float64, int32, int64.
For Example:
.. code-block:: text
x = [[[ 1 2 3 4] [ 5 6 7 8] [ 9 10 11 12]]
[[13 14 15 16] [17 18 19 20] [21 22 23 24]]]
shape(x) = [2,3,4]
# Example 1
perm0 = [1,0,2]
y_perm0 = [[[ 1 2 3 4] [13 14 15 16]]
[[ 5 6 7 8] [17 18 19 20]]
[[ 9 10 11 12] [21 22 23 24]]]
shape(y_perm0) = [3,2,4]
# Example 2
perm1 = [2,1,0]
y_perm1 = [[[ 1 13] [ 5 17] [ 9 21]]
[[ 2 14] [ 6 18] [10 22]]
[[ 3 15] [ 7 19] [11 23]]
[[ 4 16] [ 8 20] [12 24]]]
shape(y_perm1) = [4,3,2]
Examples:
.. code-block:: python
# use append_batch_size=False to avoid prepending extra
# batch size in shape
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[5, 10, 15],
x = fluid.layers.data(name='x', shape=[2, 3, 4],
dtype='float32', append_batch_size=False)
x_transposed = fluid.layers.transpose(x, perm=[1, 0, 2])
"""
print x_transposed.shape
#(3L, 2L, 4L)
"""
if not isinstance(x, Variable):
raise TypeError(
"The type of Input(x) in transpose must be Variable, but received %s"
% (type(x)))
if convert_dtype(x.dtype) not in [
"float16", "float32", "float64", "int32", "int64"
]:
raise TypeError(
"The data type of Input(x) in transpose must be one of [float16, float32, float64, int32, int64], but received %s."
% (convert_dtype(x.dtype)))
if not isinstance(perm, list):
raise TypeError(
"The type of Input(perm) in transpose must be list, but received %s"
% (type(perm)))
if len(perm) != len(x.shape):
raise ValueError(
"Input(perm) is the permutation of dimensions of Input(input). "
"Its length should be equal to Input(input)'s rank.")
"Input(perm) is the permutation of dimensions of Input(x), "
"its length should be equal to dimensions of Input(x), "
"but received dimension of Input(x) is %s, "
"the length of Input(perm) is %s." % (len(x.shape), len(perm)))
for idx, dim in enumerate(perm):
if dim >= len(x.shape):
raise ValueError(
"Each element in perm should be less than x's rank. "
"%d-th element in perm is %d which accesses x's rank %d." %
(idx, perm[idx], len(x.shape)))
"Each element in Input(perm) should be less than Input(x)'s dimension, "
"but %d-th element in Input(perm) is %d which exceeds Input(x)'s "
"dimension %d." % (idx, perm[idx], len(x.shape)))
helper = LayerHelper('transpose', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......@@ -7952,58 +7994,61 @@ def row_conv(input, future_context_size, param_attr=None, act=None):
@templatedoc()
def multiplex(inputs, index):
"""
${comment}
For Example:
Based on the given index parameter, the OP selects a specific row from each input Tensor to construct the output Tensor.
.. code-block:: text
If the input of this OP contains :math:`m` Tensors, where :math:`I_{i}` means the i-th input Tensor, :math:`i` between :math:`[0,m)` .
case 1:
And :math:`O` means the output, where :math:`O[i]` means the i-th row of the output, then the output satisfies that :math:`O[i] = I_{index[i]}[i]` .
Given:
For Example:
X = [[[0,0,3,4], [0,1,3,4], [0,2,4,4], [0,3,3,4]],
[[1,0,3,4], [1,1,7,8], [1,2,4,2], [1,3,3,4]],
[[2,0,3,4], [2,1,7,8], [2,2,4,2], [2,3,3,4]],
[[3,0,3,4], [3,1,7,8], [3,2,4,2], [3,3,3,4]]]
.. code-block:: text
index = [3,0,1,2]
Given:
out:[[3 0 3 4] // X[3,0] (3 = index[i], 0 = i); i=0
[0 1 3 4] // X[0,1] (0 = index[i], 1 = i); i=1
[1 2 4 2] // X[1,2] (0 = index[i], 2 = i); i=2
[2 3 3 4]] // X[2,3] (0 = index[i], 3 = i); i=3
inputs = [[[0,0,3,4], [0,1,3,4], [0,2,4,4], [0,3,3,4]],
[[1,0,3,4], [1,1,7,8], [1,2,4,2], [1,3,3,4]],
[[2,0,3,4], [2,1,7,8], [2,2,4,2], [2,3,3,4]],
[[3,0,3,4], [3,1,7,8], [3,2,4,2], [3,3,3,4]]]
case 2:
index = [[3],[0],[1],[2]]
Given:
out = [[3,0,3,4], # out[0] = inputs[index[0]][0] = inputs[3][0] = [3,0,3,4]
[0,1,3,4], # out[1] = inputs[index[1]][1] = inputs[0][1] = [0,1,3,4]
[1,2,4,2], # out[2] = inputs[index[2]][2] = inputs[1][2] = [1,2,4,2]
[2,3,3,4]] # out[3] = inputs[index[3]][3] = inputs[2][3] = [2,3,3,4]
X = [[[0,0,3,4], [0,1,3,4], [0,2,4,4], [0,3,3,4]],
[[1,0,3,4], [1,1,7,8], [1,2,4,2], [1,3,3,4]]]
index = [1,0]
Args:
inputs (list): The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64. All input Tensor shapes should be the same and rank must be at least 2.
index (Variable): Used to select some rows in the input Tensor to construct an index of the output Tensor. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of input Tensors.
out:[[1 0 3 4] // X[1,0] (3 = index[0], 0 = i); i=1
[0 1 3 4] // X[0,1] (0 = index[1], 1 = i); i=2
[0 2 4 4] // X[0,2] (0 = 0, 2 = i); i=3
[0 3 3 4]] // X[0,3] (0 = 0, 3 = i); i=4
Returns:
Variable(Tensor): Output of multiplex OP, with data type being float32, float64, int32, int64.
Examples:
.. code-block:: python
.. code-block:: python
import paddle.fluid as fluid
x1 = fluid.layers.data(name='x1', shape=[4], dtype='float32')
x2 = fluid.layers.data(name='x2', shape=[4], dtype='float32')
index = fluid.layers.data(name='index', shape=[1], dtype='int32')
out = fluid.layers.multiplex(inputs=[x1, x2], index=index)
import paddle.fluid as fluid
import numpy as np
Args:
inputs (list): ${x_comment}.
index (${ids_type}): ${ids_comment}.
x1 = fluid.data(name='x1', shape=[None, 2], dtype='float32')
x2 = fluid.data(name='x2', shape=[None, 2], dtype='float32')
index = fluid.data(name='index', shape=[None, 1], dtype='int32')
out = fluid.layers.multiplex(inputs=[x1, x2], index=index)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
img1 = np.array([[1, 2], [3, 4]]).astype(np.float32)
img2 = np.array([[5, 6], [7, 8]]).astype(np.float32)
index = np.array([[1], [0]]).astype(np.int32)
res = exe.run(fluid.default_main_program(), feed={'x1':img1, 'x2':img2, 'index':index}, fetch_list=[out])
print(res) # [array([[5., 6.], [3., 4.]], dtype=float32)]
Returns:
${out_comment}.
"""
helper = LayerHelper('multiplex', **locals())
......@@ -8774,7 +8819,7 @@ def squeeze(input, axes, name=None):
def unsqueeze(input, axes, name=None):
"""
Insert single-dimensional entries to the shape of a tensor. Takes one
Insert single-dimensional entries to the shape of a Tensor. Takes one
required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
......@@ -8786,12 +8831,12 @@ def unsqueeze(input, axes, name=None):
then Unsqueezed tensor with axes=[0, 4] has shape [1, 3, 4, 5, 1].
Args:
input (Variable): The input variable to be unsqueezed.
input (Variable): The input Tensor to be unsqueezed. It is a N-D Tensor of data types float32, float64, int32.
axes (list): List of integers, indicating the dimensions to be inserted.
name (str|None): Name for this layer.
Returns:
Variable: Output unsqueezed variable.
Variable: Output unsqueezed Tensor, with data type being float32, float64, int32, int64.
Examples:
.. code-block:: python
......@@ -8799,6 +8844,7 @@ def unsqueeze(input, axes, name=None):
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[5, 10])
y = fluid.layers.unsqueeze(input=x, axes=[1])
"""
helper = LayerHelper("unsqueeze", **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
......@@ -12518,7 +12564,7 @@ def unstack(x, axis=0, num=None):
"""
**UnStack Layer**
This layer unstacks input :code:`x` into several tensors along axis.
This layer unstacks input Tensor :code:`x` into several Tensors along :code:`axis`.
If :code:`axis` < 0, it would be replaced with :code:`axis+rank(x)`.
If :code:`num` is None, it would be inferred from :code:`x.shape[axis]`,
......@@ -12526,21 +12572,24 @@ def unstack(x, axis=0, num=None):
raised.
Args:
x (Variable): Input variable.
x (Variable): Input Tensor. It is a N-D Tensors of data types float32, float64, int32, int64.
axis (int): The axis along which the input is unstacked.
num (int|None): The number of output variables.
Returns:
list(Variable): The unstacked variables.
list(Variable): The unstacked Tensors list. The list elements are N-D Tensors of data types float32, float64, int32, int64.
Raises:
ValueError: If x.shape[axis] <= 0 or axis is not in range [-D, D).
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[5, 10], dtype='float32')
y = fluid.layers.unstack(x, axis=1)
"""
x = fluid.layers.data(name='x', shape=[2, 3, 5], dtype='float32') # create a tensor with shape=[2, 3, 5]
y = fluid.layers.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5]
"""
helper = LayerHelper('unstack', **locals())
if num is None:
if axis is None or x.shape[axis] <= 0:
......
......@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestTransposeOp(OpTest):
......@@ -78,5 +80,44 @@ class TestCase4(TestTransposeOp):
self.axis = (4, 2, 3, 1, 0, 5)
class TestTransposeOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[10, 5, 3], dtype='float32')
def test_x_Variable_check():
# the Input(x)'s type must be Variable
fluid.layers.transpose("not_variable", perm=[1, 0, 2])
self.assertRaises(TypeError, test_x_Variable_check)
def test_x_dtype_check():
# the Input(x)'s dtype must be one of [float16, float32, float64, int32, int64]
x1 = fluid.layers.data(
name='x1', shape=[10, 5, 3], dtype='bool')
fluid.layers.transpose(x1, perm=[1, 0, 2])
self.assertRaises(TypeError, test_x_dtype_check)
def test_perm_list_check():
# Input(perm)'s type must be list
fluid.layers.transpose(x, perm="[1, 0, 2]")
self.assertRaises(TypeError, test_perm_list_check)
def test_perm_length_and_x_dim_check():
# Input(perm) is the permutation of dimensions of Input(input)
# its length should be equal to dimensions of Input(input)
fluid.layers.transpose(x, perm=[1, 0, 2, 3, 4])
self.assertRaises(ValueError, test_perm_length_and_x_dim_check)
def test_each_elem_value_check():
# Each element in Input(perm) should be less than Input(x)'s dimension
fluid.layers.transpose(x, perm=[3, 5, 7])
self.assertRaises(ValueError, test_each_elem_value_check)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册