From 284519561db9f551a3c94978e9dfa8fd55358054 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 17 Sep 2018 14:53:55 +0800 Subject: [PATCH] add doc --- paddle/fluid/operators/sequence_concat_op.cc | 8 +++++-- python/paddle/fluid/layers/nn.py | 25 ++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/sequence_concat_op.cc b/paddle/fluid/operators/sequence_concat_op.cc index 1be236e2a..c989c1cb3 100644 --- a/paddle/fluid/operators/sequence_concat_op.cc +++ b/paddle/fluid/operators/sequence_concat_op.cc @@ -38,9 +38,13 @@ class SeqConcatShapeInferer : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { try { - PADDLE_ENFORCE(context->HasInputs("X")); - PADDLE_ENFORCE(context->HasOutput("Out")); + PADDLE_ENFORCE(context->HasInputs("X"), + "Input(X) of Sequence Concat Op should not be null."); + PADDLE_ENFORCE(context->HasOutput("Out"), + "Output(Out) of Sequence Concat Op should not be null."); + PADDLE_ENFORCE_GT(context->HasInputs("X"), 1, + "The number of input sequences is at least two."); auto x_dims = context->GetInputsDim("X"); int64_t batch_size = 0; int64_t feature_size = 0; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4a2e6025c..f148fddae 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1781,6 +1781,31 @@ def sequence_pool(input, pool_type): return pool_out +@templatedoc() +def sequence_concat(input, name=None): + """ + ${comment} + + Args: + input(list): List of Variables to be concatenated. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: Output variable of the concatenation. + + Examples: + .. code-block:: python + + out = fluid.layers.sequence_concat(input=[seq1, seq2, seq3]) + """ + helper = LayerHelper('sequence_concat', **locals()) + out = helper.create_tmp_variable(dtype=helper.input_dtype()) + helper.append_op( + type='sequence_concat', inputs={'X': input}, outputs={'Out': [out]}) + return out + + def sequence_first_step(input): """ This function gets the first step of sequence. -- GitLab