未验证 提交 32b9577b 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the split op for API 2.0 test=develop (#25320)

上级 ce506930
...@@ -150,4 +150,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -150,4 +150,5 @@ REGISTER_OP_CPU_KERNEL(
ops::SplitOpKernel<plat::CPUDeviceContext, float>, ops::SplitOpKernel<plat::CPUDeviceContext, float>,
ops::SplitOpKernel<plat::CPUDeviceContext, int64_t>, ops::SplitOpKernel<plat::CPUDeviceContext, int64_t>,
ops::SplitOpKernel<plat::CPUDeviceContext, int>, ops::SplitOpKernel<plat::CPUDeviceContext, int>,
ops::SplitOpKernel<plat::CPUDeviceContext, bool>,
ops::SplitOpKernel<plat::CPUDeviceContext, plat::float16>); ops::SplitOpKernel<plat::CPUDeviceContext, plat::float16>);
...@@ -20,4 +20,5 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -20,4 +20,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::SplitOpKernel<plat::CUDADeviceContext, float>, ops::SplitOpKernel<plat::CUDADeviceContext, float>,
ops::SplitOpKernel<plat::CUDADeviceContext, int64_t>, ops::SplitOpKernel<plat::CUDADeviceContext, int64_t>,
ops::SplitOpKernel<plat::CUDADeviceContext, int>, ops::SplitOpKernel<plat::CUDADeviceContext, int>,
ops::SplitOpKernel<plat::CUDADeviceContext, bool>,
ops::SplitOpKernel<plat::CUDADeviceContext, plat::float16>); ops::SplitOpKernel<plat::CUDADeviceContext, plat::float16>);
...@@ -4813,47 +4813,57 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -4813,47 +4813,57 @@ def split(input, num_or_sections, dim=-1, name=None):
Split the input tensor into multiple sub-Tensors. Split the input tensor into multiple sub-Tensors.
Args: Args:
input (Variable): The input variable which is an N-D Tensor or LoDTensor, data type being float32, float64, int32 or int64. input (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, int32 or int64.
num_or_sections (int|list|tuple): If :attr:`num_or_sections` is an integer, num_or_sections (int|list|tuple): If ``num_or_sections`` is int, then the ``num_or_sections``
then the integer indicates the number of equal sized sub-Tensors indicates the number of equal sized sub-Tensors that the ``input``
that the Tensor will be divided into. If :attr:`num_or_sections` will be divided into. If ``num_or_sections`` is a list or tuple, the length of it
is a list or tuple, the length of it indicates the number of indicates the number of sub-Tensors and the elements in it indicate the sizes of sub-Tensors'
sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly. The length of the list mustn't be larger than the ``input`` 's size of specified dim.
:attr:`dim` dimension orderly. The length of the list mustn't be larger than the Tensor's size of :attr:`dim` . dim (int|Tensor, optional): The dimension along which to split, it can be a scalar with type ``int`` or
dim (int32|Varible, optional): A scalar with type ``int32`` or a ``Tensor`` with shape [1] and type ``int32``. The dimension along which to split. If :math:`dim < 0`, the a ``Tensor`` with shape [1] and data type ``int32`` or ``int64``. If :math:`dim < 0`,
dimension to split along is :math:`rank(input) + dim`. Default is -1. the dimension to split along is :math:`rank(input) + dim`. Default is -1.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns: Returns:
list(Variable): The list of segmented Tensor variables. list(Tensor): The list of segmented Tensors.
Raises: Raises:
TypeError: num_or_sections is not int, list or tuple. TypeError: The data type of ``input`` must be one of bool, float16, float32, float64, int32, int64.
TypeError: dim is not int or Variable. TypeError: ``num_or_sections`` is not int, list or tuple.
TypeError: ``dim`` is not int or Tensor. The data type of ``dim`` must be int32 or int64 when it's a Tensor.
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
# input is a variable which shape is [3, 9, 5] # input is a Tensor which shape is [3, 9, 5]
input = fluid.data( input = fluid.data(
name="input", shape=[3, 9, 5], dtype="float32") name="input", shape=[3, 9, 5], dtype="float32")
x0, x1, x2 = fluid.layers.split(input, num_or_sections=3, dim=1) out0, out1, out2 = fluid.layers.split(input, num_or_sections=3, dim=1)
# x0.shape [3, 3, 5] # out0.shape [3, 3, 5]
# x1.shape [3, 3, 5] # out1.shape [3, 3, 5]
# x2.shape [3, 3, 5] # out2.shape [3, 3, 5]
x0, x1, x2 = fluid.layers.split(input, num_or_sections=[2, 3, 4], dim=1) out0, out1, out2 = fluid.layers.split(input, num_or_sections=[2, 3, 4], dim=1)
# x0.shape [3, 2, 5] # out0.shape [3, 2, 5]
# x1.shape [3, 3, 5] # out1.shape [3, 3, 5]
# x2.shape [3, 4, 5] # out2.shape [3, 4, 5]
out0, out1, out2 = fluid.layers.split(input, num_or_sections=[2, 3, -1], dim=1)
# out0.shape [3, 2, 5]
# out1.shape [3, 3, 5]
# out2.shape [3, 4, 5]
# dim is negative, the real dim is (rank(input) + axis) which real
# value is 1.
out0, out1, out2 = fluid.layers.split(input, num_or_sections=3, dim=-2)
# out0.shape [3, 3, 5]
# out1.shape [3, 3, 5]
# out2.shape [3, 3, 5]
x0, x1, x2 = fluid.layers.split(input, num_or_sections=[2, 3, -1], dim=1)
# x0.shape [3, 2, 5]
# x1.shape [3, 3, 5]
# x2.shape [3, 4, 5]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
num = None num = None
...@@ -4861,8 +4871,6 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -4861,8 +4871,6 @@ def split(input, num_or_sections, dim=-1, name=None):
if isinstance(dim, Variable): if isinstance(dim, Variable):
dim = dim.numpy() dim = dim.numpy()
assert dim.shape == (1,
), "dim of type Variable should have shape [1]"
dim = dim[0] dim = dim[0]
dim = (len(input.shape) + dim) if dim < 0 else dim dim = (len(input.shape) + dim) if dim < 0 else dim
attrs += ('axis', dim) attrs += ('axis', dim)
...@@ -4873,28 +4881,29 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -4873,28 +4881,29 @@ def split(input, num_or_sections, dim=-1, name=None):
elif isinstance(num_or_sections, (list, tuple)): elif isinstance(num_or_sections, (list, tuple)):
num = len(num_or_sections) num = len(num_or_sections)
if utils._contain_var(num_or_sections): if utils._contain_var(num_or_sections):
raise TypeError( for index, item in enumerate(num_or_sections):
"The type of 'num_or_sections' in split must be int or list[int] or tuple[int] in Dygraph mode, but " if isinstance(item, Variable):
"received %s, which contains Variable." % num_or_sections[index] = num_or_sections[index].numpy()[
(type(num_or_sections))) 0]
attrs += ('sections', list(num_or_sections))
else: else:
attrs += ('sections', list(num_or_sections)) attrs += ('sections', list(num_or_sections))
else: else:
raise TypeError( raise TypeError(
"The type of 'num_or_sections' in split must be int or list in Dygraph mode, but " "The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but "
"received %s." % (type(num_or_sections))) "received %s." % (type(num_or_sections)))
return core.ops.split(input, num, *attrs) return core.ops.split(input, num, *attrs)
if not isinstance(num_or_sections, (int, list, tuple)): check_variable_and_dtype(
raise TypeError( input, 'input',
"The type of 'num_or_sections' in split must be int, list or " ['bool', 'float16', 'float32', 'float64', 'int32', 'in64'], 'split')
"tuple, but received %s." % (type(num_or_sections))) check_type(num_or_sections, 'num_or_sections', (list, int, tuple), 'split')
if not isinstance(dim, (int, Variable)): check_type(dim, 'dim', (int, Variable), 'split')
raise TypeError( if isinstance(dim, Variable):
"The type of 'dim' in split must be int or Variable, but " check_dtype(dim.dtype, 'dim', ['int32', 'int64'], 'split')
"received %s." % (type(dim)))
helper = LayerHelper('split', **locals()) helper = LayerHelper('split', **locals())
input_shape = input.shape input_shape = input.shape
inputs = {'X': input} inputs = {'X': input}
attrs = {'num': num_or_sections if isinstance(num_or_sections, int) else 0} attrs = {'num': num_or_sections if isinstance(num_or_sections, int) else 0}
......
...@@ -271,6 +271,14 @@ class TestSplitOpError(unittest.TestCase): ...@@ -271,6 +271,14 @@ class TestSplitOpError(unittest.TestCase):
self.assertRaises(TypeError, test_axis_type) self.assertRaises(TypeError, test_axis_type)
# The type of axis in split_op should be int or Variable.
def test_axis_variable_type():
x9 = fluid.layers.data(shape=[4], dtype='float16', name='x9')
x10 = fluid.layers.data(shape=[1], dtype='float16', name='x10')
fluid.layers.split(input=x9, num_or_sections=2, dim=x10)
self.assertRaises(TypeError, test_axis_variable_type)
# The type of num_or_sections in split_op should be int, tuple or list. # The type of num_or_sections in split_op should be int, tuple or list.
def test_num_or_sections_type(): def test_num_or_sections_type():
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x4') x6 = fluid.layers.data(shape=[4], dtype='float16', name='x4')
...@@ -296,7 +304,7 @@ class API_TestSplit(unittest.TestCase): ...@@ -296,7 +304,7 @@ class API_TestSplit(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.layers.data('data1', shape=[4, 6, 6], dtype='float64') data1 = fluid.layers.data('data1', shape=[4, 6, 6], dtype='float64')
data2 = fluid.layers.data('data2', shape=[1], dtype='int32') data2 = fluid.layers.data('data2', shape=[1], dtype='int32')
x0, x1, x2 = paddle.split(data1, num_or_sections=3, dim=data2) x0, x1, x2 = paddle.split(data1, num_or_sections=3, axis=data2)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
input1 = np.random.random([4, 6, 6]).astype('float64') input1 = np.random.random([4, 6, 6]).astype('float64')
...@@ -314,7 +322,7 @@ class API_TestSplit2(unittest.TestCase): ...@@ -314,7 +322,7 @@ class API_TestSplit2(unittest.TestCase):
def test_out(self): def test_out(self):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.layers.data('data1', shape=[4, 6, 6], dtype='float64') data1 = fluid.layers.data('data1', shape=[4, 6, 6], dtype='float64')
x0, x1, x2 = paddle.split(data1, num_or_sections=3, dim=2) x0, x1, x2 = paddle.split(data1, num_or_sections=3, axis=2)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
input1 = np.random.random([4, 6, 6]).astype('float64') input1 = np.random.random([4, 6, 6]).astype('float64')
...@@ -330,7 +338,7 @@ class API_TestSplit3(unittest.TestCase): ...@@ -330,7 +338,7 @@ class API_TestSplit3(unittest.TestCase):
def test_out(self): def test_out(self):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.layers.data('data', shape=[-1, 10], dtype='float64') data = fluid.layers.data('data', shape=[-1, 10], dtype='float64')
x0, x1 = paddle.split(data, num_or_sections=(3, 7), dim=1) x0, x1 = paddle.split(data, num_or_sections=(3, 7), axis=1)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
input1 = np.random.random([1, 10]).astype('float64') input1 = np.random.random([1, 10]).astype('float64')
...@@ -345,7 +353,7 @@ class API_TestSplit4(unittest.TestCase): ...@@ -345,7 +353,7 @@ class API_TestSplit4(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.layers.data('data', shape=[-1, 10], dtype='float64') data = fluid.layers.data('data', shape=[-1, 10], dtype='float64')
index = fluid.layers.data('index', shape=[1], dtype='int32') index = fluid.layers.data('index', shape=[1], dtype='int32')
x0, x1 = paddle.split(data, num_or_sections=(3, index), dim=1) x0, x1 = paddle.split(data, num_or_sections=(3, index), axis=1)
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
input1 = np.random.random([1, 10]).astype('float64') input1 = np.random.random([1, 10]).astype('float64')
...@@ -359,12 +367,58 @@ class API_TestSplit4(unittest.TestCase): ...@@ -359,12 +367,58 @@ class API_TestSplit4(unittest.TestCase):
class API_TestDygraphSplit(unittest.TestCase): class API_TestDygraphSplit(unittest.TestCase):
def test_out(self): def test_out1(self):
with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("int32")
# input is a variable which shape is [4, 6, 6]
input = fluid.dygraph.to_variable(input_1)
x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1)
x0_out = x0.numpy()
x1_out = x1.numpy()
x2_out = x2.numpy()
ex_x0, ex_x1, ex_x2 = np.split(input_1, 3, axis=1)
self.assertTrue(np.allclose(ex_x0, x0_out))
self.assertTrue(np.allclose(ex_x1, x1_out))
self.assertTrue(np.allclose(ex_x2, x2_out))
def test_out2(self):
with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("bool")
# input is a variable which shape is [4, 6, 6]
input = fluid.dygraph.to_variable(input_1)
x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1)
x0_out = x0.numpy()
x1_out = x1.numpy()
x2_out = x2.numpy()
ex_x0, ex_x1, ex_x2 = np.split(input_1, 3, axis=1)
self.assertTrue(np.allclose(ex_x0, x0_out))
self.assertTrue(np.allclose(ex_x1, x1_out))
self.assertTrue(np.allclose(ex_x2, x2_out))
def test_out_tensor_input(self):
with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("int32")
# input is a variable which shape is [4, 6, 6]
input = fluid.dygraph.to_variable(input_1)
num1 = paddle.full(shape=[1], fill_value=2, dtype='int32')
x0, x1, x2 = paddle.split(
input, num_or_sections=[num1, 2, 2], axis=1)
x0_out = x0.numpy()
x1_out = x1.numpy()
x2_out = x2.numpy()
ex_x0, ex_x1, ex_x2 = np.split(input_1, 3, axis=1)
self.assertTrue(np.allclose(ex_x0, x0_out))
self.assertTrue(np.allclose(ex_x1, x1_out))
self.assertTrue(np.allclose(ex_x2, x2_out))
def test_axis_tensor_input(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("int32") input_1 = np.random.random([4, 6, 6]).astype("int32")
# input is a variable which shape is [4, 6, 6] # input is a variable which shape is [4, 6, 6]
input = fluid.dygraph.to_variable(input_1) input = fluid.dygraph.to_variable(input_1)
x0, x1, x2 = paddle.split(input, num_or_sections=3, dim=1) num1 = paddle.full(shape=[1], fill_value=1, dtype='int32')
x0, x1, x2 = paddle.split(
input, num_or_sections=[2, 2, 2], axis=num1)
x0_out = x0.numpy() x0_out = x0.numpy()
x1_out = x1.numpy() x1_out = x1.numpy()
x2_out = x2.numpy() x2_out = x2.numpy()
......
...@@ -363,143 +363,66 @@ def stack(x, axis=0, out=None, name=None): ...@@ -363,143 +363,66 @@ def stack(x, axis=0, out=None, name=None):
return out return out
def split(input, num_or_sections, dim=-1, name=None): def split(x, num_or_sections, axis=0, name=None):
""" """
:alias_main: paddle.split :alias_main: paddle.split
:alias: paddle.split,paddle.tensor.split,paddle.tensor.manipulation.split :alias: paddle.tensor.split, paddle.tensor.manipulation.split
Split the input tensor into multiple sub-Tensors. Split the input tensor into multiple sub-Tensors.
Args: Args:
input (Variable): The input variable which is an N-D Tensor or LoDTensor, data type being float32, float64, int32 or int64. x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, int32 or int64.
num_or_sections (int|list|tuple): If :attr:`num_or_sections` is an integer, num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
then the integer indicates the number of equal sized sub-Tensors indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
that the Tensor will be divided into. If :attr:`num_or_sections` If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
is a list or tuple, the length of it indicates the number of sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly.
sub-Tensors and the elements in it indicate the sizes of sub-Tensors' The length of the list must not be larger than the ``x`` 's size of specified ``axis``.
:attr:`dim` dimension orderly. The length of the list mustn't be larger than the Tensor's size of :attr:`dim` . axis (int|Tensor, optional): The axis along which to split, it can be a scalar with type
dim (int32|Varible, optional): A scalar with type ``int32`` or a ``Tensor`` with shape [1] and type ``int32``. The dimension along which to split. If :math:`dim < 0`, the ``int`` or a ``Tensor`` with shape [1] and data type ``int32`` or ``int64``.
dimension to split along is :math:`rank(input) + dim`. Default is -1. If :math::`axis < 0`, the axis to split along is :math:`rank(x) + axis`. Default is 0.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns: Returns:
list(Variable): The list of segmented Tensor variables. list(Tensor): The list of segmented Tensors.
Raises: Raises:
TypeError: num_or_sections is not int, list or tuple. TypeError: The data type of ``x`` must be one of bool, float16, float32, float64, int32, int64.
TypeError: dim is not int or Variable. TypeError: ``num_or_sections`` is not int, list or tuple.
TypeError: ``axis`` is not int or Tensor. the data type of ``axis`` must be int32 or int64 when it's a Tensor.
Example: Example:
.. code-block:: python .. code-block:: python
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid
with fluid.dygraph.guard():
input_1 = np.random.random([4, 6, 6]).astype("int32")
# input is a variable which shape is [4, 6, 6]
input = fluid.dygraph.to_variable(input_1)
x0, x1, x2 = paddle.split(input, num_or_sections=3, dim=1) paddle.enable_imperative()
# x0.shape [4, 2, 6] # x is a Tensor which shape is [3, 9, 5]
# x1.shape [4, 2, 6] x_np = np.random.random([3, 9, 5]).astype("int32")
# x2.shape [4, 2, 6] x = paddle.imperative.to_variable(x_np)
out0, out1, out22 = paddle.split(x, num_or_sections=3, axis=1)
# out0.shape [3, 3, 5]
# out1.shape [3, 3, 5]
# out2.shape [3, 3, 5]
out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, 4], axis=1)
# out0.shape [3, 2, 5]
# out1.shape [3, 3, 5]
# out2.shape [3, 4, 5]
out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, -1], axis=1)
# out0.shape [3, 2, 5]
# out1.shape [3, 3, 5]
# out2.shape [3, 4, 5]
# axis is negative, the real axis is (rank(x) + axis) which real
# value is 1.
out0, out1, out2 = paddle.split(x, num_or_sections=3, axis=-2)
# out0.shape [3, 3, 5]
# out1.shape [3, 3, 5]
# out2.shape [3, 3, 5]
""" """
if in_dygraph_mode(): return paddle.fluid.layers.split(
num = None input=x, num_or_sections=num_or_sections, dim=axis, name=name)
attrs = ()
if isinstance(dim, Variable):
dim = dim.numpy()
assert dim.shape == (1,
), "dim of type Variable should have shape [1]"
dim = dim[0]
dim = (len(input.shape) + dim) if dim < 0 else dim
attrs += ('axis', dim)
if isinstance(num_or_sections, int):
num = num_or_sections
attrs += ('num', num_or_sections)
elif isinstance(num_or_sections, (list, tuple)):
num = len(num_or_sections)
if utils._contain_var(num_or_sections):
raise TypeError(
"The type of 'num_or_sections' in split must be int or list[int] or tuple[int] in Dygraph mode, but "
"received %s, which contains Variable." %
(type(num_or_sections)))
else:
attrs += ('sections', list(num_or_sections))
else:
raise TypeError(
"The type of 'num_or_sections' in split must be int or list in Dygraph mode, but "
"received %s." % (type(num_or_sections)))
return core.ops.split(input, num, *attrs)
if not isinstance(num_or_sections, (int, list, tuple)):
raise TypeError(
"The type of 'num_or_sections' in split must be int, list or "
"tuple, but received %s." % (type(num_or_sections)))
if not isinstance(dim, (int, Variable)):
raise TypeError(
"The type of 'dim' in split must be int or Variable, but "
"received %s." % (type(dim)))
helper = LayerHelper('split', **locals())
input_shape = input.shape
inputs = {'X': input}
attrs = {'num': num_or_sections if isinstance(num_or_sections, int) else 0}
def _get_SectionsTensorList(one_list):
tensor_list = []
unk_dim_idx = -1
for idx, dim_size in enumerate(one_list):
if isinstance(dim_size, Variable):
dim_size.stop_gradient = True
tensor_list.append(dim_size)
else:
assert (isinstance(dim_size, int))
if dim_size == -1:
assert unk_dim_idx == -1, (
"Only one value of 'num_or_section' in split can "
"be -1. But received num_or_section[%d] is also -1." %
idx)
unk_dim_idx = idx
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant(
[1], 'int32', dim_size, force_cpu=True, out=temp_out)
tensor_list.append(temp_out)
return tensor_list
if isinstance(dim, Variable):
dim.stop_gradient = True
inputs['AxisTensor'] = dim
else:
dim = (len(input_shape) + dim) if dim < 0 else dim
attrs['axis'] = dim
if isinstance(num_or_sections, int):
assert num_or_sections > 1, 'num_or_sections must be more than 1.'
if isinstance(dim, int) and input_shape[dim] > 0:
assert input_shape[dim] % num_or_sections ==0, \
"The input's size along the split dimension " \
"must be evenly divisible by Attr(num_or_sections). " \
"But %d is not evenly divisible by %d. " % (num_or_sections,input_shape[dim])
num = num_or_sections
else:
if isinstance(dim, int) and input_shape[dim] > 0:
assert len(num_or_sections) <= input_shape[
dim], 'len(num_or_sections) must not be more than input.shape[dim].'
num = len(num_or_sections)
attrs['sections'] = list(
map(lambda ele: -1 if isinstance(ele, Variable) else ele,
num_or_sections))
if utils._contain_var(num_or_sections):
inputs['SectionsTensorList'] = _get_SectionsTensorList(
num_or_sections)
outs = [
helper.create_variable_for_type_inference(dtype=helper.input_dtype())
for i in range(num)
]
helper.append_op(
type='split', inputs=inputs, outputs={'Out': outs}, attrs=attrs)
return outs
def squeeze(x, axis=None, name=None): def squeeze(x, axis=None, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册