From 32b9577b2a7a6b53555f30aa7cc45e0907b2b366 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Thu, 30 Jul 2020 21:02:28 +0800 Subject: [PATCH] refine the split op for API 2.0 test=develop (#25320) --- paddle/fluid/operators/split_op.cc | 1 + paddle/fluid/operators/split_op.cu.cc | 1 + python/paddle/fluid/layers/nn.py | 91 ++++----- .../fluid/tests/unittests/test_split_op.py | 66 ++++++- python/paddle/tensor/manipulation.py | 173 +++++------------- 5 files changed, 160 insertions(+), 172 deletions(-) diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index abb21acb62d..0157f0635b8 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -150,4 +150,5 @@ REGISTER_OP_CPU_KERNEL( ops::SplitOpKernel, ops::SplitOpKernel, ops::SplitOpKernel, + ops::SplitOpKernel, ops::SplitOpKernel); diff --git a/paddle/fluid/operators/split_op.cu.cc b/paddle/fluid/operators/split_op.cu.cc index bbdac686a29..d1da64b158c 100644 --- a/paddle/fluid/operators/split_op.cu.cc +++ b/paddle/fluid/operators/split_op.cu.cc @@ -20,4 +20,5 @@ REGISTER_OP_CUDA_KERNEL( ops::SplitOpKernel, ops::SplitOpKernel, ops::SplitOpKernel, + ops::SplitOpKernel, ops::SplitOpKernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b73d10ff4e7..46fb61745ae 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4813,47 +4813,57 @@ def split(input, num_or_sections, dim=-1, name=None): Split the input tensor into multiple sub-Tensors. Args: - input (Variable): The input variable which is an N-D Tensor or LoDTensor, data type being float32, float64, int32 or int64. - num_or_sections (int|list|tuple): If :attr:`num_or_sections` is an integer, - then the integer indicates the number of equal sized sub-Tensors - that the Tensor will be divided into. If :attr:`num_or_sections` - is a list or tuple, the length of it indicates the number of - sub-Tensors and the elements in it indicate the sizes of sub-Tensors' - :attr:`dim` dimension orderly. The length of the list mustn't be larger than the Tensor's size of :attr:`dim` . - dim (int32|Varible, optional): A scalar with type ``int32`` or a ``Tensor`` with shape [1] and type ``int32``. The dimension along which to split. If :math:`dim < 0`, the - dimension to split along is :math:`rank(input) + dim`. Default is -1. - name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . + input (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, int32 or int64. + num_or_sections (int|list|tuple): If ``num_or_sections`` is int, then the ``num_or_sections`` + indicates the number of equal sized sub-Tensors that the ``input`` + will be divided into. If ``num_or_sections`` is a list or tuple, the length of it + indicates the number of sub-Tensors and the elements in it indicate the sizes of sub-Tensors' + dimension orderly. The length of the list mustn't be larger than the ``input`` 's size of specified dim. + dim (int|Tensor, optional): The dimension along which to split, it can be a scalar with type ``int`` or + a ``Tensor`` with shape [1] and data type ``int32`` or ``int64``. If :math:`dim < 0`, + the dimension to split along is :math:`rank(input) + dim`. Default is -1. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . Returns: - list(Variable): The list of segmented Tensor variables. + list(Tensor): The list of segmented Tensors. Raises: - TypeError: num_or_sections is not int, list or tuple. - TypeError: dim is not int or Variable. + TypeError: The data type of ``input`` must be one of bool, float16, float32, float64, int32, int64. + TypeError: ``num_or_sections`` is not int, list or tuple. + TypeError: ``dim`` is not int or Tensor. The data type of ``dim`` must be int32 or int64 when it's a Tensor. Example: .. code-block:: python import paddle.fluid as fluid - # input is a variable which shape is [3, 9, 5] + # input is a Tensor which shape is [3, 9, 5] input = fluid.data( name="input", shape=[3, 9, 5], dtype="float32") - x0, x1, x2 = fluid.layers.split(input, num_or_sections=3, dim=1) - # x0.shape [3, 3, 5] - # x1.shape [3, 3, 5] - # x2.shape [3, 3, 5] + out0, out1, out2 = fluid.layers.split(input, num_or_sections=3, dim=1) + # out0.shape [3, 3, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 3, 5] - x0, x1, x2 = fluid.layers.split(input, num_or_sections=[2, 3, 4], dim=1) - # x0.shape [3, 2, 5] - # x1.shape [3, 3, 5] - # x2.shape [3, 4, 5] + out0, out1, out2 = fluid.layers.split(input, num_or_sections=[2, 3, 4], dim=1) + # out0.shape [3, 2, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 4, 5] + + out0, out1, out2 = fluid.layers.split(input, num_or_sections=[2, 3, -1], dim=1) + # out0.shape [3, 2, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 4, 5] + + # dim is negative, the real dim is (rank(input) + axis) which real + # value is 1. + out0, out1, out2 = fluid.layers.split(input, num_or_sections=3, dim=-2) + # out0.shape [3, 3, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 3, 5] - x0, x1, x2 = fluid.layers.split(input, num_or_sections=[2, 3, -1], dim=1) - # x0.shape [3, 2, 5] - # x1.shape [3, 3, 5] - # x2.shape [3, 4, 5] """ if in_dygraph_mode(): num = None @@ -4861,8 +4871,6 @@ def split(input, num_or_sections, dim=-1, name=None): if isinstance(dim, Variable): dim = dim.numpy() - assert dim.shape == (1, - ), "dim of type Variable should have shape [1]" dim = dim[0] dim = (len(input.shape) + dim) if dim < 0 else dim attrs += ('axis', dim) @@ -4873,28 +4881,29 @@ def split(input, num_or_sections, dim=-1, name=None): elif isinstance(num_or_sections, (list, tuple)): num = len(num_or_sections) if utils._contain_var(num_or_sections): - raise TypeError( - "The type of 'num_or_sections' in split must be int or list[int] or tuple[int] in Dygraph mode, but " - "received %s, which contains Variable." % - (type(num_or_sections))) + for index, item in enumerate(num_or_sections): + if isinstance(item, Variable): + num_or_sections[index] = num_or_sections[index].numpy()[ + 0] + attrs += ('sections', list(num_or_sections)) else: attrs += ('sections', list(num_or_sections)) else: raise TypeError( - "The type of 'num_or_sections' in split must be int or list in Dygraph mode, but " + "The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but " "received %s." % (type(num_or_sections))) return core.ops.split(input, num, *attrs) - if not isinstance(num_or_sections, (int, list, tuple)): - raise TypeError( - "The type of 'num_or_sections' in split must be int, list or " - "tuple, but received %s." % (type(num_or_sections))) - if not isinstance(dim, (int, Variable)): - raise TypeError( - "The type of 'dim' in split must be int or Variable, but " - "received %s." % (type(dim))) + check_variable_and_dtype( + input, 'input', + ['bool', 'float16', 'float32', 'float64', 'int32', 'in64'], 'split') + check_type(num_or_sections, 'num_or_sections', (list, int, tuple), 'split') + check_type(dim, 'dim', (int, Variable), 'split') + if isinstance(dim, Variable): + check_dtype(dim.dtype, 'dim', ['int32', 'int64'], 'split') helper = LayerHelper('split', **locals()) + input_shape = input.shape inputs = {'X': input} attrs = {'num': num_or_sections if isinstance(num_or_sections, int) else 0} diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index 2fa6c7735c5..b261ce93c0a 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -271,6 +271,14 @@ class TestSplitOpError(unittest.TestCase): self.assertRaises(TypeError, test_axis_type) + # The type of axis in split_op should be int or Variable. + def test_axis_variable_type(): + x9 = fluid.layers.data(shape=[4], dtype='float16', name='x9') + x10 = fluid.layers.data(shape=[1], dtype='float16', name='x10') + fluid.layers.split(input=x9, num_or_sections=2, dim=x10) + + self.assertRaises(TypeError, test_axis_variable_type) + # The type of num_or_sections in split_op should be int, tuple or list. def test_num_or_sections_type(): x6 = fluid.layers.data(shape=[4], dtype='float16', name='x4') @@ -296,7 +304,7 @@ class API_TestSplit(unittest.TestCase): with fluid.program_guard(fluid.Program(), fluid.Program()): data1 = fluid.layers.data('data1', shape=[4, 6, 6], dtype='float64') data2 = fluid.layers.data('data2', shape=[1], dtype='int32') - x0, x1, x2 = paddle.split(data1, num_or_sections=3, dim=data2) + x0, x1, x2 = paddle.split(data1, num_or_sections=3, axis=data2) place = fluid.CPUPlace() exe = fluid.Executor(place) input1 = np.random.random([4, 6, 6]).astype('float64') @@ -314,7 +322,7 @@ class API_TestSplit2(unittest.TestCase): def test_out(self): with fluid.program_guard(fluid.Program(), fluid.Program()): data1 = fluid.layers.data('data1', shape=[4, 6, 6], dtype='float64') - x0, x1, x2 = paddle.split(data1, num_or_sections=3, dim=2) + x0, x1, x2 = paddle.split(data1, num_or_sections=3, axis=2) place = fluid.CPUPlace() exe = fluid.Executor(place) input1 = np.random.random([4, 6, 6]).astype('float64') @@ -330,7 +338,7 @@ class API_TestSplit3(unittest.TestCase): def test_out(self): with fluid.program_guard(fluid.Program(), fluid.Program()): data = fluid.layers.data('data', shape=[-1, 10], dtype='float64') - x0, x1 = paddle.split(data, num_or_sections=(3, 7), dim=1) + x0, x1 = paddle.split(data, num_or_sections=(3, 7), axis=1) place = fluid.CPUPlace() exe = fluid.Executor(place) input1 = np.random.random([1, 10]).astype('float64') @@ -345,7 +353,7 @@ class API_TestSplit4(unittest.TestCase): with fluid.program_guard(fluid.Program(), fluid.Program()): data = fluid.layers.data('data', shape=[-1, 10], dtype='float64') index = fluid.layers.data('index', shape=[1], dtype='int32') - x0, x1 = paddle.split(data, num_or_sections=(3, index), dim=1) + x0, x1 = paddle.split(data, num_or_sections=(3, index), axis=1) place = fluid.CPUPlace() exe = fluid.Executor(place) input1 = np.random.random([1, 10]).astype('float64') @@ -359,12 +367,58 @@ class API_TestSplit4(unittest.TestCase): class API_TestDygraphSplit(unittest.TestCase): - def test_out(self): + def test_out1(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("int32") + # input is a variable which shape is [4, 6, 6] + input = fluid.dygraph.to_variable(input_1) + x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1) + x0_out = x0.numpy() + x1_out = x1.numpy() + x2_out = x2.numpy() + ex_x0, ex_x1, ex_x2 = np.split(input_1, 3, axis=1) + self.assertTrue(np.allclose(ex_x0, x0_out)) + self.assertTrue(np.allclose(ex_x1, x1_out)) + self.assertTrue(np.allclose(ex_x2, x2_out)) + + def test_out2(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("bool") + # input is a variable which shape is [4, 6, 6] + input = fluid.dygraph.to_variable(input_1) + x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1) + x0_out = x0.numpy() + x1_out = x1.numpy() + x2_out = x2.numpy() + ex_x0, ex_x1, ex_x2 = np.split(input_1, 3, axis=1) + self.assertTrue(np.allclose(ex_x0, x0_out)) + self.assertTrue(np.allclose(ex_x1, x1_out)) + self.assertTrue(np.allclose(ex_x2, x2_out)) + + def test_out_tensor_input(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("int32") + # input is a variable which shape is [4, 6, 6] + input = fluid.dygraph.to_variable(input_1) + num1 = paddle.full(shape=[1], fill_value=2, dtype='int32') + x0, x1, x2 = paddle.split( + input, num_or_sections=[num1, 2, 2], axis=1) + x0_out = x0.numpy() + x1_out = x1.numpy() + x2_out = x2.numpy() + ex_x0, ex_x1, ex_x2 = np.split(input_1, 3, axis=1) + self.assertTrue(np.allclose(ex_x0, x0_out)) + self.assertTrue(np.allclose(ex_x1, x1_out)) + self.assertTrue(np.allclose(ex_x2, x2_out)) + + def test_axis_tensor_input(self): with fluid.dygraph.guard(): input_1 = np.random.random([4, 6, 6]).astype("int32") # input is a variable which shape is [4, 6, 6] input = fluid.dygraph.to_variable(input_1) - x0, x1, x2 = paddle.split(input, num_or_sections=3, dim=1) + num1 = paddle.full(shape=[1], fill_value=1, dtype='int32') + x0, x1, x2 = paddle.split( + input, num_or_sections=[2, 2, 2], axis=num1) x0_out = x0.numpy() x1_out = x1.numpy() x2_out = x2.numpy() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 642b5db2837..e894375aedf 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -363,143 +363,66 @@ def stack(x, axis=0, out=None, name=None): return out -def split(input, num_or_sections, dim=-1, name=None): +def split(x, num_or_sections, axis=0, name=None): """ :alias_main: paddle.split - :alias: paddle.split,paddle.tensor.split,paddle.tensor.manipulation.split - + :alias: paddle.tensor.split, paddle.tensor.manipulation.split + Split the input tensor into multiple sub-Tensors. + Args: - input (Variable): The input variable which is an N-D Tensor or LoDTensor, data type being float32, float64, int32 or int64. - num_or_sections (int|list|tuple): If :attr:`num_or_sections` is an integer, - then the integer indicates the number of equal sized sub-Tensors - that the Tensor will be divided into. If :attr:`num_or_sections` - is a list or tuple, the length of it indicates the number of - sub-Tensors and the elements in it indicate the sizes of sub-Tensors' - :attr:`dim` dimension orderly. The length of the list mustn't be larger than the Tensor's size of :attr:`dim` . - dim (int32|Varible, optional): A scalar with type ``int32`` or a ``Tensor`` with shape [1] and type ``int32``. The dimension along which to split. If :math:`dim < 0`, the - dimension to split along is :math:`rank(input) + dim`. Default is -1. - name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . + x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, int32 or int64. + num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections`` + indicates the number of equal sized sub-Tensors that the ``x`` will be divided into. + If ``num_or_sections`` is a list or tuple, the length of it indicates the number of + sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly. + The length of the list must not be larger than the ``x`` 's size of specified ``axis``. + axis (int|Tensor, optional): The axis along which to split, it can be a scalar with type + ``int`` or a ``Tensor`` with shape [1] and data type ``int32`` or ``int64``. + If :math::`axis < 0`, the axis to split along is :math:`rank(x) + axis`. Default is 0. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . Returns: - list(Variable): The list of segmented Tensor variables. + list(Tensor): The list of segmented Tensors. Raises: - TypeError: num_or_sections is not int, list or tuple. - TypeError: dim is not int or Variable. + TypeError: The data type of ``x`` must be one of bool, float16, float32, float64, int32, int64. + TypeError: ``num_or_sections`` is not int, list or tuple. + TypeError: ``axis`` is not int or Tensor. the data type of ``axis`` must be int32 or int64 when it's a Tensor. Example: .. code-block:: python + import numpy as np import paddle - import paddle.fluid as fluid - with fluid.dygraph.guard(): - input_1 = np.random.random([4, 6, 6]).astype("int32") - # input is a variable which shape is [4, 6, 6] - input = fluid.dygraph.to_variable(input_1) - - x0, x1, x2 = paddle.split(input, num_or_sections=3, dim=1) - # x0.shape [4, 2, 6] - # x1.shape [4, 2, 6] - # x2.shape [4, 2, 6] + paddle.enable_imperative() + # x is a Tensor which shape is [3, 9, 5] + x_np = np.random.random([3, 9, 5]).astype("int32") + x = paddle.imperative.to_variable(x_np) + + out0, out1, out22 = paddle.split(x, num_or_sections=3, axis=1) + # out0.shape [3, 3, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 3, 5] + + out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, 4], axis=1) + # out0.shape [3, 2, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 4, 5] + + out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, -1], axis=1) + # out0.shape [3, 2, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 4, 5] + + # axis is negative, the real axis is (rank(x) + axis) which real + # value is 1. + out0, out1, out2 = paddle.split(x, num_or_sections=3, axis=-2) + # out0.shape [3, 3, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 3, 5] """ - if in_dygraph_mode(): - num = None - attrs = () - - if isinstance(dim, Variable): - dim = dim.numpy() - assert dim.shape == (1, - ), "dim of type Variable should have shape [1]" - dim = dim[0] - dim = (len(input.shape) + dim) if dim < 0 else dim - attrs += ('axis', dim) - - if isinstance(num_or_sections, int): - num = num_or_sections - attrs += ('num', num_or_sections) - elif isinstance(num_or_sections, (list, tuple)): - num = len(num_or_sections) - if utils._contain_var(num_or_sections): - raise TypeError( - "The type of 'num_or_sections' in split must be int or list[int] or tuple[int] in Dygraph mode, but " - "received %s, which contains Variable." % - (type(num_or_sections))) - else: - attrs += ('sections', list(num_or_sections)) - else: - raise TypeError( - "The type of 'num_or_sections' in split must be int or list in Dygraph mode, but " - "received %s." % (type(num_or_sections))) - return core.ops.split(input, num, *attrs) - - if not isinstance(num_or_sections, (int, list, tuple)): - raise TypeError( - "The type of 'num_or_sections' in split must be int, list or " - "tuple, but received %s." % (type(num_or_sections))) - if not isinstance(dim, (int, Variable)): - raise TypeError( - "The type of 'dim' in split must be int or Variable, but " - "received %s." % (type(dim))) - - helper = LayerHelper('split', **locals()) - input_shape = input.shape - inputs = {'X': input} - attrs = {'num': num_or_sections if isinstance(num_or_sections, int) else 0} - - def _get_SectionsTensorList(one_list): - tensor_list = [] - unk_dim_idx = -1 - for idx, dim_size in enumerate(one_list): - if isinstance(dim_size, Variable): - dim_size.stop_gradient = True - tensor_list.append(dim_size) - else: - assert (isinstance(dim_size, int)) - if dim_size == -1: - assert unk_dim_idx == -1, ( - "Only one value of 'num_or_section' in split can " - "be -1. But received num_or_section[%d] is also -1." % - idx) - unk_dim_idx = idx - temp_out = helper.create_variable_for_type_inference('int32') - fill_constant( - [1], 'int32', dim_size, force_cpu=True, out=temp_out) - tensor_list.append(temp_out) - return tensor_list - - if isinstance(dim, Variable): - dim.stop_gradient = True - inputs['AxisTensor'] = dim - else: - dim = (len(input_shape) + dim) if dim < 0 else dim - attrs['axis'] = dim - - if isinstance(num_or_sections, int): - assert num_or_sections > 1, 'num_or_sections must be more than 1.' - if isinstance(dim, int) and input_shape[dim] > 0: - assert input_shape[dim] % num_or_sections ==0, \ - "The input's size along the split dimension " \ - "must be evenly divisible by Attr(num_or_sections). " \ - "But %d is not evenly divisible by %d. " % (num_or_sections,input_shape[dim]) - num = num_or_sections - else: - if isinstance(dim, int) and input_shape[dim] > 0: - assert len(num_or_sections) <= input_shape[ - dim], 'len(num_or_sections) must not be more than input.shape[dim].' - num = len(num_or_sections) - attrs['sections'] = list( - map(lambda ele: -1 if isinstance(ele, Variable) else ele, - num_or_sections)) - if utils._contain_var(num_or_sections): - inputs['SectionsTensorList'] = _get_SectionsTensorList( - num_or_sections) - - outs = [ - helper.create_variable_for_type_inference(dtype=helper.input_dtype()) - for i in range(num) - ] - helper.append_op( - type='split', inputs=inputs, outputs={'Out': outs}, attrs=attrs) - return outs + return paddle.fluid.layers.split( + input=x, num_or_sections=num_or_sections, dim=axis, name=name) def squeeze(x, axis=None, name=None): -- GitLab