diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc index 11615d806a61b3525d2ed50f5ea5940e8d61c8f8..c7ae162638ca5e929cca14c841cc3eceeea5f64e 100644 --- a/paddle/operators/transpose_op.cc +++ b/paddle/operators/transpose_op.cc @@ -59,44 +59,39 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "X", - "(Tensor)The input tensor, tensors with rank at most 6 are supported"); - AddOutput("Out", "(Tensor)The output tensor"); + "(Tensor) The input tensor, tensors with rank up to 6 are supported."); + AddOutput("Out", "(Tensor)The output tensor."); AddAttr>( "axis", - "(vector)A list of values, and the size of the list should be " - "the same with the input tensor rank, the tensor will " - "permute the axes according the the values given"); + "(vector) A list of values, and the size of the list should be " + "the same with the input tensor rank. This operator permutes the input " + "tensor's axes according to the values given."); AddComment(R"DOC( Transpose Operator. -The input tensor will be permuted according to the axis values given. -The op functions is similar to how numpy.transpose works in python. +The input tensor will be permuted according to the axes given. +The behavior of this operator is similar to how `numpy.transpose` works. -For example: +- suppose the input `X` is a 2-D tensor: + $$ + X = \begin{pmatrix} + 0 &1 &2 \\ + 3 &4 &5 + \end{pmatrix}$$ - .. 
code-block:: text + the given `axes` is: $[1, 0]$, and $Y$ = transpose($X$, axis) - input = numpy.arange(6).reshape((2,3)) + then the output $Y$ is: - the input is: + $$ + Y = \begin{pmatrix} + 0 &3 \\ + 1 &4 \\ + 2 &5 + \end{pmatrix}$$ - array([[0, 1, 2], - [3, 4, 5]]) - - given axis is: - - [1, 0] - - output = input.transpose(axis) - - then the output is: - - array([[0, 3], - [1, 4], - [2, 5]]) - -So, given a input tensor of shape(N, C, H, W) and the axis is {0, 2, 3, 1}, -the output tensor shape will be (N, H, W, C) +- Given a input tensor with shape $(N, C, H, W)$ and the `axes` is +$[0, 2, 3, 1]$, then shape of the output tensor will be: $(N, H, W, C)$. )DOC"); } diff --git a/python/paddle/v2/dataset/wmt16.py b/python/paddle/v2/dataset/wmt16.py index bbc28a2da99052308471931122946d0d96b54da5..e2f463be2f7bcd667855f64206d78f387e92ef33 100644 --- a/python/paddle/v2/dataset/wmt16.py +++ b/python/paddle/v2/dataset/wmt16.py @@ -171,8 +171,9 @@ def train(src_dict_size, trg_dict_size, src_lang="en"): callable: The train reader. """ - assert (src_lang in ["en", "de"], ("An error language type. Only support: " - "en (for English); de(for Germany)")) + if src_lang not in ["en", "de"]: + raise ValueError("An error language type. Only support: " + "en (for English); de(for Germany).") src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, src_lang) @@ -218,9 +219,9 @@ def test(src_dict_size, trg_dict_size, src_lang="en"): callable: The test reader. """ - assert (src_lang in ["en", "de"], - ("An error language type. " - "Only support: en (for English); de(for Germany)")) + if src_lang not in ["en", "de"]: + raise ValueError("An error language type. " + "Only support: en (for English); de(for Germany).") src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, src_lang) @@ -266,9 +267,9 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"): Returns: callable: The validation reader. 
""" - assert (src_lang in ["en", "de"], - ("An error language type. " - "Only support: en (for English); de(for Germany)")) + if src_lang not in ["en", "de"]: + raise ValueError("An error language type. " + "Only support: en (for English); de(for Germany).") src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, src_lang) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index f0345512f5133573f3f946878af1939ad1d7fcd3..a01ccfa635301108e337668188dd14ed0c0b1d8a 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -22,14 +22,41 @@ from ..param_attr import ParamAttr from tensor import concat __all__ = [ - 'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf', - 'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy', - 'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d', - 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', - 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', - 'sequence_first_step', 'sequence_last_step', 'dropout', 'split', - 'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'warpctc', - 'sequence_reshape' + 'fc', + 'embedding', + 'dynamic_lstm', + 'gru_unit', + 'linear_chain_crf', + 'crf_decoding', + 'cos_sim', + 'cross_entropy', + 'square_error_cost', + 'accuracy', + 'chunk_eval', + 'sequence_conv', + 'conv2d', + 'sequence_pool', + 'pool2d', + 'batch_norm', + 'beam_search_decode', + 'conv2d_transpose', + 'sequence_expand', + 'lstm_unit', + 'reduce_sum', + 'reduce_mean', + 'reduce_max', + 'reduce_min', + 'sequence_first_step', + 'sequence_last_step', + 'dropout', + 'split', + 'ctc_greedy_decoder', + 'edit_distance', + 'l2_normalize', + 'matmul', + 'warpctc', + 'sequence_reshape', + 'transpose', ] @@ -44,14 +71,14 @@ def fc(input, **Fully Connected Layer** The fully connected layer can take multiple tensors as its inputs. 
It - creates a variable (one for each input tensor) called weights for each input - tensor, which represents a fully connected weight matrix from each input - unit to each output unit. The fully connected layer multiplies each input - tensor with its coresponding weight to produce an output Tensor. If - multiple input tensors are given, the results of multiple multiplications - will be sumed up. If bias_attr is not None, a biases variable will be - created and added to the output. Finally, if activation is not None, - it will be applied to the output as well. + creates a variable (one for each input tensor) called weights for each + input tensor, which represents a fully connected weight matrix from + each input unit to each output unit. The fully connected layer + multiplies each input tensor with its corresponding weight to produce + an output Tensor. If multiple input tensors are given, the results of + multiple multiplications will be summed up. If bias_attr is not None, + a biases variable will be created and added to the output. Finally, + if activation is not None, it will be applied to the output as well. This process can be formulated as follows: @@ -1814,11 +1841,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None): - If both are 2-D, they are multiplied like conventional matrices. - If either is n-D, it is treated as a stack of matrices residing in the - last two dimensions and a batched matrix multiply supporting broadcast + last two dimensions and a batched matrix multiply supporting broadcast applies on the two tensors. - Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and - nontransposed, the prepended or appended dimension :math:`1` will be + Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and + nontransposed, the prepended or appended dimension :math:`1` will be removed after matrix multiplication. 
Args: @@ -2112,3 +2139,41 @@ def sequence_reshape(input, new_dim): outputs={'Out': [out]}, attrs={'new_dim': new_dim}) return out + + +def transpose(x, perm, name=None): + """ + **transpose Layer** + + Permute the dimensions of `x` according to `perm`. + + The `i`-th dimension of the returned tensor will correspond to the + perm[i]-th dimension of `x`. + + Args: + x (Variable): (Tensor), A Tensor. + perm (list): A permutation of the dimensions of `x`. + + Returns: + Variable: A transposed Tensor. + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[5, 10, 15], dtype='float32') + x_transposed = fluid.layers.transpose(x, perm=[1, 0, 2]) + """ + + if len(perm) != len(x.shape): + raise ValueError( + "Input(perm) is the permutation of dimensions of Input(input). " + "It's length shoud be equal to Input(input)'s rank.") + + helper = LayerHelper('transpose', **locals()) + out = helper.create_tmp_variable(x.dtype) + helper.append_op( + type='transpose', + inputs={'X': [x]}, + outputs={'Out': [out]}, + attrs={'axis': perm}) + return out diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index b517f8be6a3e5558dd01afe094fb3989cfb3af44..022a94cad440f13383a927233195bb008a688843 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -45,10 +45,20 @@ __activations__ = [ ] __all__ = [ - 'mean', 'mul', 'reshape', 'scale', 'transpose', - 'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div', - 'elementwise_sub', 'elementwise_mul', 'elementwise_max', 'elementwise_min', - 'clip', 'clip_by_norm', 'sequence_softmax' + 'mean', + 'mul', + 'reshape', + 'scale', + 'sigmoid_cross_entropy_with_logits', + 'elementwise_add', + 'elementwise_div', + 'elementwise_sub', + 'elementwise_mul', + 'elementwise_max', + 'elementwise_min', + 'clip', + 'clip_by_norm', + 'sequence_softmax', ] + __activations__ for _OP in set(__all__): diff --git 
a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py index 618191424150eb7c5a24407fc2e106ee8825fedb..117f74c59ad5bf6bb67711801cd7b9a41f39f1f8 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py @@ -65,13 +65,13 @@ def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50): emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) emb = fluid.layers.reshape(x=emb, shape=[batch_size, seq_len, emb_dim]) - emb = fluid.layers.transpose(x=emb, axis=[1, 0, 2]) + emb = fluid.layers.transpose(x=emb, perm=[1, 0, 2]) c_pre_init = fluid.layers.fill_constant( dtype=emb.dtype, shape=[batch_size, emb_dim], value=0.0) c_pre_init.stop_gradient = False layer_1_out = lstm(emb, c_pre_init=c_pre_init, hidden_dim=emb_dim) - layer_1_out = fluid.layers.transpose(x=layer_1_out, axis=[1, 0, 2]) + layer_1_out = fluid.layers.transpose(x=layer_1_out, perm=[1, 0, 2]) prediction = fluid.layers.fc(input=layer_1_out, size=class_dim,