From 219a2369da4c80318e75020acdcafb2971398143 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Thu, 23 Aug 2018 16:09:45 +0000 Subject: [PATCH] feat: wrap sequence enumerate op --- paddle/fluid/API.spec | 1 + .../fluid/operators/sequence_enumerate_op.cc | 20 +-- .../fluid/operators/sequence_enumerate_op.cu | 2 - .../fluid/operators/sequence_enumerate_op.h | 2 - python/paddle/fluid/layers/nn.py | 136 ++++++++---------- .../fluid/tests/unittests/test_layers.py | 9 ++ 6 files changed, 85 insertions(+), 85 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index e963902a502..c2a08d2e53c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -161,6 +161,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) +paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_enumerate_op.cc index 
0d9fdf7d5ca..cacbb097771 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_enumerate_op.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/operators/sequence_enumerate_op.h" +#include namespace paddle { namespace operators { @@ -30,16 +31,21 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { "Output(X) of SequenceEnumerate operator should not be null."); const auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE( - x_dims.size() == 2 && x_dims[1] == 1, - "Input(X) of SequenceEnumerate operator should be a 2-D LoDTensor " - "with the 2nd dimension equal to 1."); + PADDLE_ENFORCE_EQ( + x_dims.size(), 2UL, + "Input(X) of SequenceEnumerate operator's rank should be 2."); const auto win_size = ctx->Attrs().Get("win_size"); - PADDLE_ENFORCE(win_size <= x_dims[0], + // TODO(chenweihang): unittest doesn't has batch size, but test_layers has + auto first_dim = x_dims[0] == -1 ? x_dims[1] : x_dims[0]; + PADDLE_ENFORCE(win_size <= first_dim, "The enumerate window size should be less than or equal to " "input sequence length."); - ctx->SetOutputDim("Out", {x_dims[0], win_size}); + + std::vector out_shape(x_dims.size() + 1, 0); + for (int i = 0; i < x_dims.size(); ++i) out_shape.emplace_back(x_dims[i]); + out_shape.emplace_back(win_size); + ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); ctx->ShareLoD("X", "Out"); } }; @@ -83,8 +89,6 @@ Case 1: Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]] Out.dims = [5, 2] - Currently, only 1-level LoDTensor is supported. 
- )DOC"); } }; diff --git a/paddle/fluid/operators/sequence_enumerate_op.cu b/paddle/fluid/operators/sequence_enumerate_op.cu index 2e2356e7eca..e680174a2cf 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cu +++ b/paddle/fluid/operators/sequence_enumerate_op.cu @@ -48,8 +48,6 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel { auto in_dims = in->dims(); auto in_lod = in->lod(); - PADDLE_ENFORCE_EQ(in_lod.size(), 1UL, - "Only support one level sequence now."); PADDLE_ENFORCE_EQ( static_cast(in_dims[0]), in_lod[0].back(), "The actual input data's size mismatched with LoD information."); diff --git a/paddle/fluid/operators/sequence_enumerate_op.h b/paddle/fluid/operators/sequence_enumerate_op.h index 8e9549508e9..8a30003b164 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.h +++ b/paddle/fluid/operators/sequence_enumerate_op.h @@ -32,8 +32,6 @@ class SequenceEnumerateKernel : public framework::OpKernel { auto in_dims = in->dims(); auto in_lod = in->lod(); - PADDLE_ENFORCE_EQ(in_lod.size(), 1UL, - "Only support one level sequence now."); PADDLE_ENFORCE_EQ( static_cast(in_dims[0]), in_lod[0].back(), "The actual input data's size mismatched with LoD information."); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bd2b950cffe..d4efb682d95 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -29,79 +29,22 @@ from .. 
import unique_name from functools import reduce __all__ = [ - 'fc', - 'embedding', - 'dynamic_lstm', - 'dynamic_lstmp', - 'dynamic_gru', - 'gru_unit', - 'linear_chain_crf', - 'crf_decoding', - 'cos_sim', - 'cross_entropy', - 'square_error_cost', - 'chunk_eval', - 'sequence_conv', - 'conv2d', - 'conv3d', - 'sequence_pool', - 'sequence_softmax', - 'softmax', - 'pool2d', - 'pool3d', - 'batch_norm', - 'beam_search_decode', - 'conv2d_transpose', - 'conv3d_transpose', - 'sequence_expand', - 'lstm_unit', - 'reduce_sum', - 'reduce_mean', - 'reduce_max', - 'reduce_min', - 'reduce_prod', - 'sequence_first_step', - 'sequence_last_step', - 'dropout', - 'split', - 'ctc_greedy_decoder', - 'edit_distance', - 'l2_normalize', - 'matmul', - 'topk', - 'warpctc', - 'sequence_reshape', - 'transpose', - 'im2sequence', - 'nce', - 'hsigmoid', - 'beam_search', - 'row_conv', - 'multiplex', - 'layer_norm', - 'softmax_with_cross_entropy', - 'smooth_l1', - 'one_hot', - 'autoincreased_step_counter', - 'reshape', - 'lod_reset', - 'lrn', - 'pad', - 'label_smooth', - 'roi_pool', - 'dice_loss', - 'image_resize', - 'image_resize_short', - 'resize_bilinear', - 'gather', - 'random_crop', - 'mean_iou', - 'relu', - 'log', - 'crop', - 'rank_loss', - 'prelu', - 'flatten', + 'fc', 'embedding', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru', + 'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy', + 'square_error_cost', 'chunk_eval', 'sequence_conv', 'conv2d', 'conv3d', + 'sequence_pool', 'sequence_softmax', 'softmax', 'pool2d', 'pool3d', + 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'conv3d_transpose', + 'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', + 'reduce_min', 'reduce_prod', 'sequence_first_step', 'sequence_last_step', + 'dropout', 'split', 'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', + 'matmul', 'topk', 'warpctc', 'sequence_reshape', 'transpose', 'im2sequence', + 'nce', 'hsigmoid', 'beam_search', 'row_conv', 'multiplex', 
'layer_norm', + 'softmax_with_cross_entropy', 'smooth_l1', 'one_hot', + 'autoincreased_step_counter', 'reshape', 'lod_reset', 'lrn', 'pad', + 'label_smooth', 'roi_pool', 'dice_loss', 'image_resize', + 'image_resize_short', 'resize_bilinear', 'gather', 'random_crop', + 'mean_iou', 'relu', 'log', 'crop', 'rank_loss', 'prelu', 'flatten', + 'sequence_enumerate' ] @@ -5475,3 +5418,50 @@ def flatten(x, axis=1, name=None): outputs={'Out': out}, attrs={"axis": axis}) return out + + +def sequence_enumerate(input, win_size, pad_value, name=None): + """ + Generate a new LoDTensor + with the same 1st dimension length as the original LoDTensor, + and with the 2nd dimension equal to the input window length, + the new sub-sequence on 2nd dimension is enumerated one by one on the original sequence. + The values of the last insufficient part are all filled with the input pad_value. + + Examples: + Case 1: + Input: + X.lod = [[0, 3, 5]] + X.data = [1, 2, 3, 4, 5] + X.dims = [5, 1] + Attrs: + win_size = 2 + pad_value = 0 + Output: + Out.lod = [[0, 3, 5]] + Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]] + Out.dims = [5, 2] + + Args: + input (Variable): The input variable which is a LoDTensor + win_size (int): The enumerate sequence window size. + pad_value (int): The enumerate sequence padding value. + + Returns: + Variable: The enumerate sequence variable which is a LoDTensor. + + Examples: + ..
code-block:: python + + x = fluid.layers.data(name='x', shape=[30, 1], dtype='int32', lod_level=1) + out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0) + """ + helper = LayerHelper('sequence_enumerate', **locals()) + out = helper.create_tmp_variable(helper.input_dtype()) + helper.append_op( + type='sequence_enumerate', + inputs={'X': input}, + outputs={'Out': out}, + attrs={'win_size': win_size, + 'pad_value': pad_value}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index e833a7db482..4994e11d1fb 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -500,6 +500,15 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_sequence_enumerate(self): + program = Program() + with program_guard(program): + x = layers.data( + name="input", shape=[30], dtype='int32', lod_level=1) + out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0) + self.assertIsNotNone(out) + print(str(program)) + if __name__ == '__main__': unittest.main() -- GitLab