diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index b93ce0709973ea555e7342832d5a06628025afd7..4e8ad010a2a624b97d1ac06271c08bea16c9380c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -177,14 +177,14 @@ paddle.fluid.layers.matmul (ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y' paddle.fluid.layers.topk (ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e645c2f6c24cf076260d380df929e243')) paddle.fluid.layers.warpctc (ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'input_length', 'label_length'], varargs=None, keywords=None, defaults=(0, False, None, None)), ('document', '79aaea078ddea57a82ed7906d71dedc7')) paddle.fluid.layers.sequence_reshape (ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None), ('document', 'eeb1591cfc854c6ffdac77b376313c44')) -paddle.fluid.layers.transpose (ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '8e72db173d4c082e27cb11f31d8c9bfa')) +paddle.fluid.layers.transpose (ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ae5c346abc8a7d85fc3ebe2e1ba0f428')) paddle.fluid.layers.im2sequence (ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)), ('document', 'fe352915a543cec434f74e9b32ac49da')) paddle.fluid.layers.nce (ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False)), ('document', '38297567127888e01542857839058d52')) paddle.fluid.layers.sampled_softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'num_samples', 'num_true', 'remove_accidental_hits', 'use_customized_samples', 'customized_samples', 'customized_probabilities', 'seed'], varargs=None, keywords=None, defaults=(1, True, False, None, None, 0)), ('document', 'd4435a63d34203339831ee6a86ef9242')) paddle.fluid.layers.hsigmoid (ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)), ('document', '247de339879885526e7f4d271967088f')) paddle.fluid.layers.beam_search (ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False)), ('document', '2b505ddaa309fd7b9be5445e41ca76d5')) paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'a6477957b44907787b3c74157400b80c')) -paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '2c4d1ae83da6ed35e3b36ba1b3b51d23')) +paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '8dba76e9b1521b4ab62e38608b6aa3f6')) paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', '678de6d6d0c93da74189990b039daae8')) paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '87dd4b818f102bc1a780e1804c28bd38')) paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '7b3d14d6707d878923847ec617d7d521')) @@ -194,7 +194,7 @@ paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth', 'allow_out_of_range paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', 'd016c137beb9a4528b7378b437d00151')) paddle.fluid.layers.reshape (ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', 'd7a6d59e464a7ef1184eb6caefeb49f1')) paddle.fluid.layers.squeeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '61d8be0c5af7b9313b0bdb8697c7d4de')) -paddle.fluid.layers.unsqueeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b9bd3129d36a70e7c4385df51ff71c62')) +paddle.fluid.layers.unsqueeze (ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd1f4c0d1284315066210ff0b33adf747')) paddle.fluid.layers.lod_reset (ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None)), ('document', 'f1f04ae9bdcf8f3adc0658db6904aa0e')) paddle.fluid.layers.lod_append (ArgSpec(args=['x', 'level'], varargs=None, keywords=None, defaults=None), ('document', '37663c7c179e920838a250ea0e28d909')) paddle.fluid.layers.lrn (ArgSpec(args=['input', 'n', 'k', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(5, 1.0, 0.0001, 0.75, None)), ('document', 'fa565b65fb98d3ca82361c79f41b06b2')) @@ -238,7 +238,7 @@ paddle.fluid.layers.flatten (ArgSpec(args=['x', 'axis', 'name'], varargs=None, k paddle.fluid.layers.sequence_mask (ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None)), ('document', '6c3f916921b24edaad220f1fcbf039de')) paddle.fluid.layers.stack (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', '666d995b36e9f287d77f09189370fb3a')) paddle.fluid.layers.pad2d (ArgSpec(args=['input', 'paddings', 'mode', 'pad_value', 'data_format', 'name'], varargs=None, keywords=None, defaults=([0, 0, 0, 0], 'constant', 0.0, 'NCHW', None)), ('document', '4e277f064c1765f77f946da194626ca1')) -paddle.fluid.layers.unstack (ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b0c4ca08d4eb295189e1b107c920d093')) +paddle.fluid.layers.unstack (ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b9d343d8961dfa30d65b1e59d86f53cd')) paddle.fluid.layers.sequence_enumerate (ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b870fed41abd2aecf929ece65f555fa1')) paddle.fluid.layers.unique (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', 'cab0b06e5683875f12f0efc62fa230a9')) paddle.fluid.layers.unique_with_counts (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', '4496682f302007019e458a2f30d8a7c3')) diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 47840d71a3d1fddced1fc9f37174915a89f17aa4..226aad03845d7629d7be556b394ebe06abba44d5 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -39,17 +39,23 @@ class TransposeOp : public framework::OperatorWithKernel { size_t axis_size = axis.size(); PADDLE_ENFORCE_EQ(x_rank, axis_size, - "The input tensor's rank(%d) " - "should be equal to the axis's size(%d)", + "ShapeError: The input tensor's dimension " + "should be equal to the axis's size. " + "But received input tensor's dimension is %d, " + "axis's size is %d", x_rank, axis_size); std::vector count(axis_size, 0); for (size_t i = 0; i < axis_size; i++) { PADDLE_ENFORCE( axis[i] < static_cast(axis_size) && ++count[axis[i]] == 1, - "Each element of Attribute axis should be a unique value " - "range from 0 to (dims - 1), " - "where the dims is the axis's size"); + "ValueError: Each element of Attribute axis should " + "be a unique value range from 0 to (dims - 1), " + "where the dims is the axis's size, " + "unique value means this axis value can appear only once. " + "But received axis[%d] is %d, axis_size is %d, " + "count[axis[%d]] is %d", + i, axis[i], axis_size, i, count[axis[i]]); } framework::DDim out_dims(x_dims); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3ae2967ac32bbe1d917af44b53b3917279eb1c7e..5b68959ffddbc4b94b5423237eda9bce3db10420 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7720,40 +7720,82 @@ def hsigmoid(input, def transpose(x, perm, name=None): """ - Permute the dimensions of `input` according to `perm`. + Permute the data dimensions of `input` according to `perm`. The `i`-th dimension of the returned tensor will correspond to the perm[i]-th dimension of `input`. Args: - x (Variable): The input Tensor. - perm (list): A permutation of the dimensions of `input`. + x (Variable): The input Tensor. It is a N-D Tensor of data types float32, float64, int32. + perm (list): Permute the input accoring to the data of perm. name (str): The name of this layer. It is optional. Returns: - Variable: A transposed Tensor. + Variable: A transposed n-D Tensor, with data type being float32, float64, int32, int64. + + For Example: + + .. code-block:: text + + x = [[[ 1 2 3 4] [ 5 6 7 8] [ 9 10 11 12]] + [[13 14 15 16] [17 18 19 20] [21 22 23 24]]] + shape(x) = [2,3,4] + + # Example 1 + perm0 = [1,0,2] + y_perm0 = [[[ 1 2 3 4] [13 14 15 16]] + [[ 5 6 7 8] [17 18 19 20]] + [[ 9 10 11 12] [21 22 23 24]]] + shape(y_perm0) = [3,2,4] + + # Example 2 + perm1 = [2,1,0] + y_perm1 = [[[ 1 13] [ 5 17] [ 9 21]] + [[ 2 14] [ 6 18] [10 22]] + [[ 3 15] [ 7 19] [11 23]] + [[ 4 16] [ 8 20] [12 24]]] + shape(y_perm1) = [4,3,2] Examples: + .. code-block:: python # use append_batch_size=False to avoid prepending extra # batch size in shape import paddle.fluid as fluid - x = fluid.layers.data(name='x', shape=[5, 10, 15], + x = fluid.layers.data(name='x', shape=[2, 3, 4], dtype='float32', append_batch_size=False) x_transposed = fluid.layers.transpose(x, perm=[1, 0, 2]) - """ + print x_transposed.shape + #(3L, 2L, 4L) + """ + if not isinstance(x, Variable): + raise TypeError( + "The type of Input(x) in transpose must be Variable, but received %s" + % (type(x))) + if convert_dtype(x.dtype) not in [ + "float16", "float32", "float64", "int32", "int64" + ]: + raise TypeError( + "The data type of Input(x) in transpose must be one of [float16, float32, float64, int32, int64], but received %s." + % (convert_dtype(x.dtype))) + if not isinstance(perm, list): + raise TypeError( + "The type of Input(perm) in transpose must be list, but received %s" + % (type(perm))) if len(perm) != len(x.shape): raise ValueError( - "Input(perm) is the permutation of dimensions of Input(input). " - "Its length should be equal to Input(input)'s rank.") + "Input(perm) is the permutation of dimensions of Input(x), " + "its length should be equal to dimensions of Input(x), " + "but received dimension of Input(x) is %s, " + "the length of Input(perm) is %s." % (len(x.shape), len(perm))) for idx, dim in enumerate(perm): if dim >= len(x.shape): raise ValueError( - "Each element in perm should be less than x's rank. " - "%d-th element in perm is %d which accesses x's rank %d." % - (idx, perm[idx], len(x.shape))) + "Each element in Input(perm) should be less than Input(x)'s dimension, " + "but %d-th element in Input(perm) is %d which exceeds Input(x)'s " + "dimension %d." % (idx, perm[idx], len(x.shape))) helper = LayerHelper('transpose', **locals()) out = helper.create_variable_for_type_inference(x.dtype) @@ -7952,58 +7994,61 @@ def row_conv(input, future_context_size, param_attr=None, act=None): @templatedoc() def multiplex(inputs, index): """ - ${comment} - For Example: + Based on the given index parameter, the OP selects a specific row from each input Tensor to construct the output Tensor. - .. code-block:: text + If the input of this OP contains :math:`m` Tensors, where :math:`I_{i}` means the i-th input Tensor, :math:`i` between :math:`[0,m)` . - case 1: + And :math:`O` means the output, where :math:`O[i]` means the i-th row of the output, then the output satisfies that :math:`O[i] = I_{index[i]}[i]` . - Given: + For Example: - X = [[[0,0,3,4], [0,1,3,4], [0,2,4,4], [0,3,3,4]], - [[1,0,3,4], [1,1,7,8], [1,2,4,2], [1,3,3,4]], - [[2,0,3,4], [2,1,7,8], [2,2,4,2], [2,3,3,4]], - [[3,0,3,4], [3,1,7,8], [3,2,4,2], [3,3,3,4]]] + .. code-block:: text - index = [3,0,1,2] + Given: - out:[[3 0 3 4] // X[3,0] (3 = index[i], 0 = i); i=0 - [0 1 3 4] // X[0,1] (0 = index[i], 1 = i); i=1 - [1 2 4 2] // X[1,2] (0 = index[i], 2 = i); i=2 - [2 3 3 4]] // X[2,3] (0 = index[i], 3 = i); i=3 + inputs = [[[0,0,3,4], [0,1,3,4], [0,2,4,4], [0,3,3,4]], + [[1,0,3,4], [1,1,7,8], [1,2,4,2], [1,3,3,4]], + [[2,0,3,4], [2,1,7,8], [2,2,4,2], [2,3,3,4]], + [[3,0,3,4], [3,1,7,8], [3,2,4,2], [3,3,3,4]]] - case 2: + index = [[3],[0],[1],[2]] - Given: + out = [[3,0,3,4], # out[0] = inputs[index[0]][0] = inputs[3][0] = [3,0,3,4] + [0,1,3,4], # out[1] = inputs[index[1]][1] = inputs[0][1] = [0,1,3,4] + [1,2,4,2], # out[2] = inputs[index[2]][2] = inputs[1][2] = [1,2,4,2] + [2,3,3,4]] # out[3] = inputs[index[3]][3] = inputs[2][3] = [2,3,3,4] - X = [[[0,0,3,4], [0,1,3,4], [0,2,4,4], [0,3,3,4]], - [[1,0,3,4], [1,1,7,8], [1,2,4,2], [1,3,3,4]]] - index = [1,0] + Args: + inputs (list): The input Tensor list. The list elements are N-D Tensors of data types float32, float64, int32, int64. All input Tensor shapes should be the same and rank must be at least 2. + index (Variable): Used to select some rows in the input Tensor to construct an index of the output Tensor. It is a 2-D Tensor with data type int32 or int64 and shape [M, 1], where M is the number of input Tensors. - out:[[1 0 3 4] // X[1,0] (3 = index[0], 0 = i); i=1 - [0 1 3 4] // X[0,1] (0 = index[1], 1 = i); i=2 - [0 2 4 4] // X[0,2] (0 = 0, 2 = i); i=3 - [0 3 3 4]] // X[0,3] (0 = 0, 3 = i); i=4 + Returns: + Variable(Tensor): Output of multiplex OP, with data type being float32, float64, int32, int64. Examples: - .. code-block:: python + .. code-block:: python - import paddle.fluid as fluid - x1 = fluid.layers.data(name='x1', shape=[4], dtype='float32') - x2 = fluid.layers.data(name='x2', shape=[4], dtype='float32') - index = fluid.layers.data(name='index', shape=[1], dtype='int32') - out = fluid.layers.multiplex(inputs=[x1, x2], index=index) + import paddle.fluid as fluid + import numpy as np - Args: - inputs (list): ${x_comment}. - index (${ids_type}): ${ids_comment}. + x1 = fluid.data(name='x1', shape=[None, 2], dtype='float32') + x2 = fluid.data(name='x2', shape=[None, 2], dtype='float32') + index = fluid.data(name='index', shape=[None, 1], dtype='int32') + out = fluid.layers.multiplex(inputs=[x1, x2], index=index) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + + img1 = np.array([[1, 2], [3, 4]]).astype(np.float32) + img2 = np.array([[5, 6], [7, 8]]).astype(np.float32) + index = np.array([[1], [0]]).astype(np.int32) + + res = exe.run(fluid.default_main_program(), feed={'x1':img1, 'x2':img2, 'index':index}, fetch_list=[out]) + print(res) # [array([[5., 6.], [3., 4.]], dtype=float32)] - Returns: - ${out_comment}. """ helper = LayerHelper('multiplex', **locals()) @@ -8774,7 +8819,7 @@ def squeeze(input, axes, name=None): def unsqueeze(input, axes, name=None): """ - Insert single-dimensional entries to the shape of a tensor. Takes one + Insert single-dimensional entries to the shape of a Tensor. Takes one required argument axes, a list of dimensions that will be inserted. Dimension indices in axes are as seen in the output tensor. @@ -8786,12 +8831,12 @@ def unsqueeze(input, axes, name=None): then Unsqueezed tensor with axes=[0, 4] has shape [1, 3, 4, 5, 1]. Args: - input (Variable): The input variable to be unsqueezed. + input (Variable): The input Tensor to be unsqueezed. It is a N-D Tensor of data types float32, float64, int32. axes (list): List of integers, indicating the dimensions to be inserted. name (str|None): Name for this layer. Returns: - Variable: Output unsqueezed variable. + Variable: Output unsqueezed Tensor, with data type being float32, float64, int32, int64. Examples: .. code-block:: python @@ -8799,6 +8844,7 @@ def unsqueeze(input, axes, name=None): import paddle.fluid as fluid x = fluid.layers.data(name='x', shape=[5, 10]) y = fluid.layers.unsqueeze(input=x, axes=[1]) + """ helper = LayerHelper("unsqueeze", **locals()) out = helper.create_variable_for_type_inference(dtype=input.dtype) @@ -12518,7 +12564,7 @@ def unstack(x, axis=0, num=None): """ **UnStack Layer** - This layer unstacks input :code:`x` into several tensors along axis. + This layer unstacks input Tensor :code:`x` into several Tensors along :code:`axis`. If :code:`axis` < 0, it would be replaced with :code:`axis+rank(x)`. If :code:`num` is None, it would be inferred from :code:`x.shape[axis]`, @@ -12526,21 +12572,24 @@ def unstack(x, axis=0, num=None): raised. Args: - x (Variable): Input variable. + x (Variable): Input Tensor. It is a N-D Tensors of data types float32, float64, int32, int64. axis (int): The axis along which the input is unstacked. num (int|None): The number of output variables. Returns: - list(Variable): The unstacked variables. + list(Variable): The unstacked Tensors list. The list elements are N-D Tensors of data types float32, float64, int32, int64. + + Raises: + ValueError: If x.shape[axis] <= 0 or axis is not in range [-D, D). Examples: .. code-block:: python import paddle.fluid as fluid - x = fluid.layers.data(name='x', shape=[5, 10], dtype='float32') - y = fluid.layers.unstack(x, axis=1) - """ + x = fluid.layers.data(name='x', shape=[2, 3, 5], dtype='float32') # create a tensor with shape=[2, 3, 5] + y = fluid.layers.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] + """ helper = LayerHelper('unstack', **locals()) if num is None: if axis is None or x.shape[axis] <= 0: diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index a38540a7240636415ef4703609c5a3e8e83ed1da..37d153aaad884796a9339692ce76811522ff2419 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -17,6 +17,8 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard class TestTransposeOp(OpTest): @@ -78,5 +80,44 @@ class TestCase4(TestTransposeOp): self.axis = (4, 2, 3, 1, 0, 5) +class TestTransposeOpError(OpTest): + def test_errors(self): + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[10, 5, 3], dtype='float32') + + def test_x_Variable_check(): + # the Input(x)'s type must be Variable + fluid.layers.transpose("not_variable", perm=[1, 0, 2]) + + self.assertRaises(TypeError, test_x_Variable_check) + + def test_x_dtype_check(): + # the Input(x)'s dtype must be one of [float16, float32, float64, int32, int64] + x1 = fluid.layers.data( + name='x1', shape=[10, 5, 3], dtype='bool') + fluid.layers.transpose(x1, perm=[1, 0, 2]) + + self.assertRaises(TypeError, test_x_dtype_check) + + def test_perm_list_check(): + # Input(perm)'s type must be list + fluid.layers.transpose(x, perm="[1, 0, 2]") + + self.assertRaises(TypeError, test_perm_list_check) + + def test_perm_length_and_x_dim_check(): + # Input(perm) is the permutation of dimensions of Input(input) + # its length should be equal to dimensions of Input(input) + fluid.layers.transpose(x, perm=[1, 0, 2, 3, 4]) + + self.assertRaises(ValueError, test_perm_length_and_x_dim_check) + + def test_each_elem_value_check(): + # Each element in Input(perm) should be less than Input(x)'s dimension + fluid.layers.transpose(x, perm=[3, 5, 7]) + + self.assertRaises(ValueError, test_each_elem_value_check) + + if __name__ == '__main__': unittest.main()