提交 c083a60d 编写于 作者: G guosheng

Add python split and glu

上级 9867a379
......@@ -348,3 +348,9 @@ reduce_min
.. autofunction:: paddle.v2.fluid.layers.reduce_min
:noindex:
split
-----
.. autofunction:: paddle.v2.fluid.layers.split
:noindex:
......@@ -20,3 +20,8 @@ sequence_conv_pool
:noindex:
glu
---
.. autofunction:: paddle.v2.fluid.nets.glu
:noindex:
......@@ -60,6 +60,12 @@ class SplitOp : public framework::OperatorWithKernel {
}
}
ctx->SetOutputsDim("Out", outs_dims);
if (axis != 0) {
// Only pass LoD when not spliting along the first dim.
for (size_t i = 0; i < outs_number; ++i) {
ctx->ShareLoD("X", "Out", 0, i);
}
}
}
};
......
......@@ -14,7 +14,7 @@ __all__ = [
'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
'sequence_first_step', 'sequence_last_step', 'dropout'
'sequence_first_step', 'sequence_last_step', 'dropout', 'split'
]
......@@ -1504,3 +1504,63 @@ def reduce_min(input, dim=None, keep_dim=False):
'reduce_all': True if dim == None else False
})
return out
def split(input, num_or_sections, dim=-1):
"""
Splits the tensor into multiple sub-tensors.
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
num_or_sections (int|list): If :attr:`num_or_sections` is an integer,
then the integer indicates the number of equal sized sub-tensors
that the tensor will be divided into. If :attr:`num_or_sections`
is a list of integers, the length of list indicates the number of
sub-tensors and the integers indicate the sizes of sub-tensors'
:attr:`dim` dimension orderly.
dim (int): The dimension along which to split. If :math:`dim < 0`, the
dimension to split along is :math:`rank(input) + dim`.
Returns:
List: The list of segmented tensor variables.
Examples:
.. code-block:: python
# x is a Tensor variable with shape [3, 9, 5]:
x0, x1, x2 = fluid.layers.split(x, num_or_sections=3, dim=1)
x0.shape # [3, 3, 5]
x1.shape # [3, 3, 5]
x2.shape # [3, 3, 5]
x0, x1, x2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1)
x0.shape # [3, 2, 5]
x1.shape # [3, 3, 5]
x2.shape # [3, 4, 5]
"""
helper = LayerHelper('split', **locals())
input_shape = input.shape
dim = (len(input_shape) + dim) if dim < 0 else dim
if isinstance(num_or_sections, int):
assert num_or_sections > 1, 'num_or_sections must be more than 1.'
assert input_shape[
dim] % num_or_sections == 0, 'num_or_sections must evenly divide input.shape[dim].'
num = num_or_sections
else:
assert len(num_or_sections) < input_shape[
dim], 'len(num_or_sections) must not be more than input.shape[dim].'
num = len(num_or_sections)
outs = [
helper.create_tmp_variable(dtype=helper.input_dtype())
for i in range(num)
]
helper.append_op(
type='split',
inputs={'X': input},
outputs={'Out': outs},
attrs={
'num': num_or_sections if isinstance(num_or_sections, int) else 0,
'sections': num_or_sections
if isinstance(num_or_sections, list) else [],
'axis': dim
})
return outs
import layers
__all__ = ["simple_img_conv_pool", "sequence_conv_pool"]
__all__ = ["simple_img_conv_pool", "sequence_conv_pool", "glu"]
def simple_img_conv_pool(input,
......@@ -98,3 +98,36 @@ def sequence_conv_pool(input,
pool_out = layers.sequence_pool(input=conv_out, pool_type=pool_type)
return pool_out
def glu(input, dim=-1):
"""
The gated linear unit composed by split and elementwise multiplication.
Specifically, Split the input into two equal sized parts :math:`a` and
:math:`b` along the given dimension and then compute as following:
.. math::
{GLU}(a, b)= a \otimes \sigma(b)
Refer to `Language Modeling with Gated Convolutional Networks
<https://arxiv.org/pdf/1612.08083.pdf>`_.
Args:
input (Variable): The input variable which is a Tensor or LoDTensor.
dim (int): The dimension along which to split. If :math:`dim < 0`, the
dimension to split along is :math:`rank(input) + dim`.
Returns:
Variable: The Tensor variable with half the size of input.
Examples:
.. code-block:: python
# x is a Tensor variable with shape [3, 6, 9]
fluid.nets.glu(input=x, dim=-1) # shape of output: [3, 3, 9]
"""
a, b = layers.split(input, num_or_sections=2, dim=dim)
out = layers.elementwise_mul(x=a, y=b)
return out
......@@ -6,8 +6,8 @@ import numpy
class TestReorderLoDTensor(unittest.TestCase):
num_seq = 5
# [name, dim, lod_level] pair indicating data info of source and target
data_desc = (['input', 9, 0], ['ref', 5, 1])
# [name, shape, lod_level] pair indicating data info of source and target
data_desc = (['input', [9], 0], ['ref', [5], 1])
@classmethod
def setUpClass(cls):
......@@ -16,10 +16,10 @@ class TestReorderLoDTensor(unittest.TestCase):
@classmethod
def set_program(cls):
dat = fluid.layers.data(
name=cls.data_desc[0][0], shape=[cls.data_desc[0][1]])
name=cls.data_desc[0][0], shape=cls.data_desc[0][1])
dat.stop_gradient = False
rank_dat = fluid.layers.data(
name=cls.data_desc[1][0], shape=[cls.data_desc[1][1]])
name=cls.data_desc[1][0], shape=cls.data_desc[1][1])
table = fluid.layers.lod_rank_table(rank_dat)
new_dat = fluid.layers.reorder_lod_tensor_by_rank(
x=dat, rank_table=table)
......@@ -49,7 +49,7 @@ class TestReorderLoDTensor(unittest.TestCase):
self.data = {}
for desc in self.data_desc:
data_name = desc[0]
data_dim = desc[1]
data_shape = desc[1]
data_lod_level = desc[2]
data_lod = []
for i in range(data_lod_level):
......@@ -59,9 +59,9 @@ class TestReorderLoDTensor(unittest.TestCase):
size=self.num_seq if i == 0 else lod_level_i[-1])
lod_level_i = [0] + numpy.cumsum(lod_level_i).tolist()
data_lod.append(lod_level_i)
data_value = numpy.random.random(size=[
data_lod[-1][-1] if data_lod else self.num_seq, data_dim
]).astype('float32')
data_value = numpy.random.random(
size=[data_lod[-1][-1] if data_lod else self.num_seq
] + data_shape).astype('float32')
self.data[data_name] = (data_value, data_lod)
def set_inputs(self, place):
......@@ -163,8 +163,6 @@ class TestReorderLoDTensor(unittest.TestCase):
numpy.allclose(
numpy.array(actual_grad), expect_grad, atol=0.001))
self.assertEqual(expect_grad_lod, actual_grad.lod())
global outputs_from_tensor_implicit_lod
outputs_from_tensor_implicit_lod = self.actual_outputs
# compare outputs between LodTensors with explicit and implicit lod
# use the same data but set the input lod explicitly
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册