From c083a60d7a100d5ebafa16be46ce7335eae25e69 Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 15 Jan 2018 14:43:31 +0800 Subject: [PATCH] Add python split and glu --- doc/api/v2/fluid/layers.rst | 6 ++ doc/api/v2/fluid/nets.rst | 5 ++ paddle/operators/split_op.cc | 6 ++ python/paddle/v2/fluid/layers/nn.py | 62 ++++++++++++++++++- python/paddle/v2/fluid/nets.py | 35 ++++++++++- .../v2/fluid/tests/test_reorder_lod_tensor.py | 18 +++--- 6 files changed, 120 insertions(+), 12 deletions(-) diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst index a7c8670f66..f1a2f7f880 100644 --- a/doc/api/v2/fluid/layers.rst +++ b/doc/api/v2/fluid/layers.rst @@ -348,3 +348,9 @@ reduce_min .. autofunction:: paddle.v2.fluid.layers.reduce_min :noindex: + +split +----- +.. autofunction:: paddle.v2.fluid.layers.split + :noindex: + diff --git a/doc/api/v2/fluid/nets.rst b/doc/api/v2/fluid/nets.rst index b792efb71f..cca0dcdf08 100644 --- a/doc/api/v2/fluid/nets.rst +++ b/doc/api/v2/fluid/nets.rst @@ -20,3 +20,8 @@ sequence_conv_pool :noindex: +glu +--- +.. autofunction:: paddle.v2.fluid.nets.glu + :noindex: + diff --git a/paddle/operators/split_op.cc b/paddle/operators/split_op.cc index 4dfae043cb..8d55ae5dd7 100644 --- a/paddle/operators/split_op.cc +++ b/paddle/operators/split_op.cc @@ -60,6 +60,12 @@ class SplitOp : public framework::OperatorWithKernel { } } ctx->SetOutputsDim("Out", outs_dims); + if (axis != 0) { + // Only pass LoD when not spliting along the first dim. + for (size_t i = 0; i < outs_number; ++i) { + ctx->ShareLoD("X", "Out", 0, i); + } + } } }; diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 48a6bee558..929249be40 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -14,7 +14,7 @@ __all__ = [ 'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d', 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', - 'sequence_first_step', 'sequence_last_step', 'dropout' + 'sequence_first_step', 'sequence_last_step', 'dropout', 'split' ] @@ -1504,3 +1504,63 @@ def reduce_min(input, dim=None, keep_dim=False): 'reduce_all': True if dim == None else False }) return out + + +def split(input, num_or_sections, dim=-1): + """ + Splits the tensor into multiple sub-tensors. + + Args: + input (Variable): The input variable which is a Tensor or LoDTensor. + num_or_sections (int|list): If :attr:`num_or_sections` is an integer, + then the integer indicates the number of equal sized sub-tensors + that the tensor will be divided into. If :attr:`num_or_sections` + is a list of integers, the length of list indicates the number of + sub-tensors and the integers indicate the sizes of sub-tensors' + :attr:`dim` dimension orderly. + dim (int): The dimension along which to split. If :math:`dim < 0`, the + dimension to split along is :math:`rank(input) + dim`. + + Returns: + List: The list of segmented tensor variables. + + Examples: + .. code-block:: python + + # x is a Tensor variable with shape [3, 9, 5]: + x0, x1, x2 = fluid.layers.split(x, num_or_sections=3, dim=1) + x0.shape # [3, 3, 5] + x1.shape # [3, 3, 5] + x2.shape # [3, 3, 5] + x0, x1, x2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1) + x0.shape # [3, 2, 5] + x1.shape # [3, 3, 5] + x2.shape # [3, 4, 5] + """ + helper = LayerHelper('split', **locals()) + input_shape = input.shape + dim = (len(input_shape) + dim) if dim < 0 else dim + if isinstance(num_or_sections, int): + assert num_or_sections > 1, 'num_or_sections must be more than 1.' + assert input_shape[ + dim] % num_or_sections == 0, 'num_or_sections must evenly divide input.shape[dim].' + num = num_or_sections + else: + assert len(num_or_sections) < input_shape[ + dim], 'len(num_or_sections) must not be more than input.shape[dim].' + num = len(num_or_sections) + outs = [ + helper.create_tmp_variable(dtype=helper.input_dtype()) + for i in range(num) + ] + helper.append_op( + type='split', + inputs={'X': input}, + outputs={'Out': outs}, + attrs={ + 'num': num_or_sections if isinstance(num_or_sections, int) else 0, + 'sections': num_or_sections + if isinstance(num_or_sections, list) else [], + 'axis': dim + }) + return outs diff --git a/python/paddle/v2/fluid/nets.py b/python/paddle/v2/fluid/nets.py index 54886a8f2c..afba32e7b6 100644 --- a/python/paddle/v2/fluid/nets.py +++ b/python/paddle/v2/fluid/nets.py @@ -1,6 +1,6 @@ import layers -__all__ = ["simple_img_conv_pool", "sequence_conv_pool"] +__all__ = ["simple_img_conv_pool", "sequence_conv_pool", "glu"] def simple_img_conv_pool(input, @@ -98,3 +98,36 @@ def sequence_conv_pool(input, pool_out = layers.sequence_pool(input=conv_out, pool_type=pool_type) return pool_out + + +def glu(input, dim=-1): + """ + The gated linear unit composed by split and elementwise multiplication. + Specifically, Split the input into two equal sized parts :math:`a` and + :math:`b` along the given dimension and then compute as following: + + .. math:: + + {GLU}(a, b)= a \otimes \sigma(b) + + Refer to `Language Modeling with Gated Convolutional Networks + `_. + + Args: + input (Variable): The input variable which is a Tensor or LoDTensor. + dim (int): The dimension along which to split. If :math:`dim < 0`, the + dimension to split along is :math:`rank(input) + dim`. + + Returns: + Variable: The Tensor variable with half the size of input. + + Examples: + .. code-block:: python + + # x is a Tensor variable with shape [3, 6, 9] + fluid.nets.glu(input=x, dim=-1) # shape of output: [3, 3, 9] + """ + + a, b = layers.split(input, num_or_sections=2, dim=dim) + out = layers.elementwise_mul(x=a, y=b) + return out diff --git a/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py b/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py index 8b79d448e2..215accd4c6 100644 --- a/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py +++ b/python/paddle/v2/fluid/tests/test_reorder_lod_tensor.py @@ -6,8 +6,8 @@ import numpy class TestReorderLoDTensor(unittest.TestCase): num_seq = 5 - # [name, dim, lod_level] pair indicating data info of source and target - data_desc = (['input', 9, 0], ['ref', 5, 1]) + # [name, shape, lod_level] pair indicating data info of source and target + data_desc = (['input', [9], 0], ['ref', [5], 1]) @classmethod def setUpClass(cls): @@ -16,10 +16,10 @@ class TestReorderLoDTensor(unittest.TestCase): @classmethod def set_program(cls): dat = fluid.layers.data( - name=cls.data_desc[0][0], shape=[cls.data_desc[0][1]]) + name=cls.data_desc[0][0], shape=cls.data_desc[0][1]) dat.stop_gradient = False rank_dat = fluid.layers.data( - name=cls.data_desc[1][0], shape=[cls.data_desc[1][1]]) + name=cls.data_desc[1][0], shape=cls.data_desc[1][1]) table = fluid.layers.lod_rank_table(rank_dat) new_dat = fluid.layers.reorder_lod_tensor_by_rank( x=dat, rank_table=table) @@ -49,7 +49,7 @@ class TestReorderLoDTensor(unittest.TestCase): self.data = {} for desc in self.data_desc: data_name = desc[0] - data_dim = desc[1] + data_shape = desc[1] data_lod_level = desc[2] data_lod = [] for i in range(data_lod_level): @@ -59,9 +59,9 @@ class TestReorderLoDTensor(unittest.TestCase): size=self.num_seq if i == 0 else lod_level_i[-1]) lod_level_i = [0] + numpy.cumsum(lod_level_i).tolist() data_lod.append(lod_level_i) - data_value = numpy.random.random(size=[ - data_lod[-1][-1] if data_lod else self.num_seq, data_dim - ]).astype('float32') + data_value = numpy.random.random( + size=[data_lod[-1][-1] if data_lod else self.num_seq + ] + data_shape).astype('float32') self.data[data_name] = (data_value, data_lod) def set_inputs(self, place): @@ -163,8 +163,6 @@ class TestReorderLoDTensor(unittest.TestCase): numpy.allclose( numpy.array(actual_grad), expect_grad, atol=0.001)) self.assertEqual(expect_grad_lod, actual_grad.lod()) - global outputs_from_tensor_implicit_lod - outputs_from_tensor_implicit_lod = self.actual_outputs # compare outputs between LodTensors with explicit and implicit lod # use the same data but set the input lod explicitly -- GitLab