diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index a6b88413668b4ec5d70c3e4b918f8dc145cb6f56..dd875d21d584d527e64af256c8660abcd96c7b8a 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -166,7 +166,7 @@ from .tensor.linalg import cross #DEFINE_ALIAS # from .tensor.manipulation import expand #DEFINE_ALIAS # from .tensor.manipulation import expand_as #DEFINE_ALIAS # from .tensor.manipulation import flatten #DEFINE_ALIAS -# from .tensor.manipulation import gather #DEFINE_ALIAS +from .tensor.manipulation import gather #DEFINE_ALIAS # from .tensor.manipulation import gather_nd #DEFINE_ALIAS # from .tensor.manipulation import reshape #DEFINE_ALIAS # from .tensor.manipulation import reverse #DEFINE_ALIAS @@ -175,14 +175,14 @@ from .tensor.linalg import cross #DEFINE_ALIAS # from .tensor.manipulation import scatter_nd #DEFINE_ALIAS # from .tensor.manipulation import shard_index #DEFINE_ALIAS # from .tensor.manipulation import slice #DEFINE_ALIAS -# from .tensor.manipulation import split #DEFINE_ALIAS -# from .tensor.manipulation import squeeze #DEFINE_ALIAS -# from .tensor.manipulation import stack #DEFINE_ALIAS +from .tensor.manipulation import split #DEFINE_ALIAS +from .tensor.manipulation import squeeze #DEFINE_ALIAS +from .tensor.manipulation import stack #DEFINE_ALIAS # from .tensor.manipulation import strided_slice #DEFINE_ALIAS # from .tensor.manipulation import transpose #DEFINE_ALIAS # from .tensor.manipulation import unique #DEFINE_ALIAS # from .tensor.manipulation import unique_with_counts #DEFINE_ALIAS -# from .tensor.manipulation import unsqueeze #DEFINE_ALIAS +from .tensor.manipulation import unsqueeze #DEFINE_ALIAS # from .tensor.manipulation import unstack #DEFINE_ALIAS from .tensor.manipulation import flip #DEFINE_ALIAS # from .tensor.manipulation import unbind #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index 340be3238d1c626579abcaa3b64acb03a85e452e..f8763e731eeed3b36a6271167a57b9277479b5ba 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -17,6 +17,8 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle +import paddle.fluid as fluid class TestGatherOp(OpTest): @@ -106,5 +108,35 @@ class TestCase6(TestGatherOp): self.index_type = "int32" +class API_TestGather(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.layers.data('data1', shape=[-1, 2], dtype='float64') + index = fluid.layers.data('index', shape=[-1, 1], dtype='float64') + out = paddle.gather(data1, index) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([1, 2]) + result, = exe.run(feed={"data1": input, + "index": index_1}, + fetch_list=[out]) + expected_output = np.array([[3, 4], [5, 6]]) + self.assertTrue(np.allclose(result, expected_output)) + + +class API_TestDygraphGather(unittest.TestCase): + def test_out(self): + with fluid.dygraph.guard(): + input_1 = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([1, 2]) + input = fluid.dygraph.to_variable(input_1) + index = fluid.dygraph.to_variable(index_1) + output = paddle.fluid.layers.gather(input, index) + output_np = output.numpy() + expected_output = np.array([[3, 4], [5, 6]]) + self.assertTrue(np.allclose(output_np, expected_output)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index 714843bb8b215117d2f38b776cca49a6eb3bae5c..2fa6c7735c5fbace555aec45614a5e690a0aeb3c 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import print_function - +import paddle import unittest import numpy as np from op_test import OpTest @@ -278,6 +278,101 @@ class TestSplitOpError(unittest.TestCase): self.assertRaises(TypeError, test_num_or_sections_type) + def test_num_or_sections_type_tensor(): + x7 = fluid.layers.data(shape=[4], dtype='float16', name='x5') + paddle.split(input=x7, num_or_sections=2.1, dim=3) + + self.assertRaises(TypeError, test_num_or_sections_type_tensor) + + def test_axis_type_tensor(): + x8 = fluid.layers.data(shape=[4], dtype='float16', name='x6') + paddle.split(input=x8, num_or_sections=2, dim=3.2) + + self.assertRaises(TypeError, test_axis_type_tensor) + + +class API_TestSplit(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.layers.data('data1', shape=[4, 6, 6], dtype='float64') + data2 = fluid.layers.data('data2', shape=[1], dtype='int32') + x0, x1, x2 = paddle.split(data1, num_or_sections=3, dim=data2) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([4, 6, 6]).astype('float64') + input2 = np.array([2]).astype('int32') + r0, r1, r2, = exe.run(feed={"data1": input1, + "data2": input2}, + fetch_list=[x0, x1, x2]) + ex_x0, ex_x1, ex_x2 = np.split(input1, 3, axis=2) + self.assertTrue(np.allclose(ex_x0, r0)) + self.assertTrue(np.allclose(ex_x1, r1)) + self.assertTrue(np.allclose(ex_x2, r2)) + + +class API_TestSplit2(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.layers.data('data1', shape=[4, 6, 6], dtype='float64') + x0, x1, x2 = paddle.split(data1, num_or_sections=3, dim=2) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([4, 6, 6]).astype('float64') + r0, r1, r2, = exe.run(feed={"data1": input1}, + fetch_list=[x0, x1, x2]) + ex_x0, ex_x1, ex_x2 = np.split(input1, 3, axis=2) + self.assertTrue(np.allclose(ex_x0, r0)) + self.assertTrue(np.allclose(ex_x1, r1)) + self.assertTrue(np.allclose(ex_x2, r2)) + + +class API_TestSplit3(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.layers.data('data', shape=[-1, 10], dtype='float64') + x0, x1 = paddle.split(data, num_or_sections=(3, 7), dim=1) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([1, 10]).astype('float64') + r0, r1 = exe.run(feed={"data": input1}, fetch_list=[x0, x1]) + ex_x0, ex_x1 = np.split(input1, (3, ), axis=1) + self.assertTrue(np.allclose(ex_x0, r0)) + self.assertTrue(np.allclose(ex_x1, r1)) + + +class API_TestSplit4(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.layers.data('data', shape=[-1, 10], dtype='float64') + index = fluid.layers.data('index', shape=[1], dtype='int32') + x0, x1 = paddle.split(data, num_or_sections=(3, index), dim=1) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([1, 10]).astype('float64') + input2 = np.array([7]).astype('int32') + r0, r1 = exe.run(feed={"data": input1, + "index": input2}, + fetch_list=[x0, x1]) + ex_x0, ex_x1 = np.split(input1, (3, ), axis=1) + self.assertTrue(np.allclose(ex_x0, r0)) + self.assertTrue(np.allclose(ex_x1, r1)) + + +class API_TestDygraphSplit(unittest.TestCase): + def test_out(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("int32") + # input is a variable which shape is [4, 6, 6] + input = fluid.dygraph.to_variable(input_1) + x0, x1, x2 = paddle.split(input, num_or_sections=3, dim=1) + x0_out = x0.numpy() + x1_out = x1.numpy() + x2_out = x2.numpy() + ex_x0, ex_x1, ex_x2 = np.split(input_1, 3, axis=1) + self.assertTrue(np.allclose(ex_x0, x0_out)) + self.assertTrue(np.allclose(ex_x1, x1_out)) + self.assertTrue(np.allclose(ex_x2, x2_out)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_squeeze_op.py b/python/paddle/fluid/tests/unittests/test_squeeze_op.py index 187714ce49c8088d992fb7def7f5b89f81d27752..75f474052cc94c02b3e58899f38893754299a33f 100644 --- a/python/paddle/fluid/tests/unittests/test_squeeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_squeeze_op.py @@ -18,7 +18,7 @@ import unittest import numpy as np import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard - +import paddle from op_test import OpTest @@ -85,5 +85,31 @@ class TestSqueezeOpError(unittest.TestCase): self.assertRaises(TypeError, fluid.layers.squeeze, x3, axes=0) +class API_TestSqueeze(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.layers.data( + 'data1', shape=[-1, 1, 10], dtype='float64') + result_squeeze = paddle.squeeze(data1, axes=[1]) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([5, 1, 10]).astype('float64') + result, = exe.run(feed={"data1": input1}, + fetch_list=[result_squeeze]) + expected_result = np.squeeze(input1, axis=1) + self.assertTrue(np.allclose(expected_result, result)) + + +class API_TestDygraphSqueeze(unittest.TestCase): + def test_out(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([5, 1, 10]).astype("int32") + input = fluid.dygraph.to_variable(input_1) + output = paddle.squeeze(input, axes=[1]) + out_np = output.numpy() + expected_out = np.squeeze(input_1, axis=1) + self.assertTrue(np.allclose(expected_out, out_np)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_stack_op.py b/python/paddle/fluid/tests/unittests/test_stack_op.py index 2c81e2067593ea3ffce9355b2e9cdb4a908cdb15..fd5c02c55db4c22d9edd604b7998a5405961d596 100644 --- a/python/paddle/fluid/tests/unittests/test_stack_op.py +++ b/python/paddle/fluid/tests/unittests/test_stack_op.py @@ -14,6 +14,7 @@ import numpy as np import unittest +import paddle import paddle.fluid as fluid from op_test import OpTest @@ -125,5 +126,84 @@ class TestStackAPIWithLoDTensorArray(unittest.TestCase): [self.x] * self.iter_num, axis=self.axis))) +class TestTensorStackAPIWithLoDTensorArray(unittest.TestCase): + """ + Test stack api when the input(x) is a LoDTensorArray. + """ + + def setUp(self): + self.axis = 1 + self.iter_num = 3 + self.input_shape = [2, 3] + self.x = np.random.random(self.input_shape).astype("float32") + self.place = fluid.CUDAPlace(0) \ + if fluid.is_compiled_with_cuda() else fluid.CPUPlace() + self.set_program() + + def set_program(self): + self.program = fluid.Program() + with fluid.program_guard(self.program): + input = fluid.layers.assign(self.x) + tensor_array = fluid.layers.create_array(dtype='float32') + zero = fluid.layers.fill_constant(shape=[1], value=0, dtype="int64") + + for i in range(self.iter_num): + fluid.layers.array_write(input, zero + i, tensor_array) + + self.out_var = paddle.stack(tensor_array, axis=self.axis) + + def test_case(self): + self.assertTrue(self.out_var.shape[self.axis] == -1) + exe = fluid.Executor(self.place) + res = exe.run(self.program, fetch_list=self.out_var) + self.assertTrue( + np.array_equal( + res[0], np.stack( + [self.x] * self.iter_num, axis=self.axis))) + + +class API_test(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.layers.data('data1', shape=[1, 2], dtype='float64') + data2 = fluid.layers.data('data2', shape=[1, 2], dtype='float64') + data3 = fluid.layers.data('data3', shape=[1, 2], dtype='float64') + result_stack = paddle.stack([data1, data2, data3], axis=0) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([1, 2]).astype('float64') + input2 = np.random.random([1, 2]).astype('float64') + input3 = np.random.random([1, 2]).astype('float64') + result, = exe.run( + feed={"data1": input1, + "data2": input2, + "data3": input3}, + fetch_list=[result_stack]) + expected_result = np.stack([input1, input2, input3], axis=0) + self.assertTrue(np.allclose(expected_result, result)) + + +class API_DygraphTest(unittest.TestCase): + def test_out(self): + data1 = np.array([[1.0, 2.0]]) + data2 = np.array([[3.0, 4.0]]) + data3 = np.array([[5.0, 6.0]]) + with fluid.dygraph.guard(): + x1 = fluid.dygraph.to_variable(data1) + x2 = fluid.dygraph.to_variable(data2) + x3 = fluid.dygraph.to_variable(data3) + result = paddle.stack([x1, x2, x3], axis=0) + result_np = result.numpy() + expected_result = np.stack([data1, data2, data3], axis=0) + self.assertTrue(np.allclose(expected_result, result_np)) + + with fluid.dygraph.guard(): + y1 = fluid.dygraph.to_variable(data1) + result = paddle.stack(y1, axis=0) + result_np_2 = result.numpy() + expected_result_2 = np.stack(data1, axis=0) + self.assertTrue(np.allclose(expected_result_2, result_np_2)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 5f80734b30400ff0c1cc9f7ab888aa9fed4ba000..1b353e1379076cc71e8013487e0b22f5bf03dc09 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -16,7 +16,8 @@ from __future__ import print_function import unittest import numpy as np - +import paddle +import paddle.fluid as fluid from op_test import OpTest @@ -76,5 +77,87 @@ class TestUnsqueezeOp4(TestUnsqueezeOp): self.new_shape = (10, 1, 1, 2, 5, 1) +class API_TestUnsqueeze(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.layers.data('data1', shape=[-1, 10], dtype='float64') + result_squeeze = paddle.unsqueeze(data1, axes=[1]) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([5, 1, 10]).astype('float64') + input = np.squeeze(input1, axis=1) + result, = exe.run(feed={"data1": input}, + fetch_list=[result_squeeze]) + self.assertTrue(np.allclose(input1, result)) + + +class TestUnsqueezeOpError(unittest.TestCase): + def test_errors(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + # The type of axis in split_op should be int or Variable. + def test_axes_type(): + x6 = fluid.layers.data( + shape=[-1, 10], dtype='float16', name='x3') + paddle.unsqueeze(x6, axes=3.2) + + self.assertRaises(TypeError, test_axes_type) + + +class API_TestUnsqueeze2(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.data('data1', shape=[-1, 10], dtype='float64') + data2 = fluid.data('data2', shape=[1], dtype='int32') + result_squeeze = paddle.unsqueeze(data1, axes=data2) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([5, 1, 10]).astype('float64') + input2 = np.array([1]).astype('int32') + input = np.squeeze(input1, axis=1) + result1, = exe.run(feed={"data1": input, + "data2": input2}, + fetch_list=[result_squeeze]) + self.assertTrue(np.allclose(input1, result1)) + + +class API_TestUnsqueeze3(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.data('data1', shape=[-1, 10], dtype='float64') + data2 = fluid.data('data2', shape=[1], dtype='int32') + result_squeeze = paddle.unsqueeze(data1, axes=[data2, 3]) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([5, 1, 10, 1]).astype('float64') + input2 = np.array([1]).astype('int32') + input = np.squeeze(input1) + result1, = exe.run(feed={"data1": input, + "data2": input2}, + fetch_list=[result_squeeze]) + self.assertTrue(np.allclose(input1, result1)) + + +class API_TestDyUnsqueeze(unittest.TestCase): + def test_out(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([5, 1, 10]).astype("int32") + input1 = np.squeeze(input_1, axis=1) + input = fluid.dygraph.to_variable(input_1) + output = paddle.unsqueeze(input, axes=[1]) + out_np = output.numpy() + self.assertTrue(np.allclose(input1, out_np)) + + +class API_TestDyUnsqueeze2(unittest.TestCase): + def test_out(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([5, 1, 10]).astype("int32") + input1 = np.squeeze(input_1, axis=1) + input = fluid.dygraph.to_variable(input_1) + output = paddle.unsqueeze(input, axes=1) + out_np = output.numpy() + self.assertTrue(np.allclose(input1, out_np)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 922acb1e631fcea43dac3b31350b16359c225d70..4ce7725b3a34764ac0a2a6df2651b7427f8f6614 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -145,7 +145,7 @@ from .linalg import bmm #DEFINE_ALIAS # from .manipulation import expand #DEFINE_ALIAS # from .manipulation import expand_as #DEFINE_ALIAS # from .manipulation import flatten #DEFINE_ALIAS -# from .manipulation import gather #DEFINE_ALIAS +from .manipulation import gather #DEFINE_ALIAS # from .manipulation import gather_nd #DEFINE_ALIAS # from .manipulation import reshape #DEFINE_ALIAS # from .manipulation import reverse #DEFINE_ALIAS @@ -154,14 +154,14 @@ from .linalg import bmm #DEFINE_ALIAS # from .manipulation import scatter_nd #DEFINE_ALIAS # from .manipulation import shard_index #DEFINE_ALIAS # from .manipulation import slice #DEFINE_ALIAS -# from .manipulation import split #DEFINE_ALIAS -# from .manipulation import squeeze #DEFINE_ALIAS -# from .manipulation import stack #DEFINE_ALIAS +from .manipulation import split #DEFINE_ALIAS +from .manipulation import squeeze #DEFINE_ALIAS +from .manipulation import stack #DEFINE_ALIAS # from .manipulation import strided_slice #DEFINE_ALIAS # from .manipulation import transpose #DEFINE_ALIAS # from .manipulation import unique #DEFINE_ALIAS # from .manipulation import unique_with_counts #DEFINE_ALIAS -# from .manipulation import unsqueeze #DEFINE_ALIAS +from .manipulation import unsqueeze #DEFINE_ALIAS # from .manipulation import unstack #DEFINE_ALIAS from .manipulation import flip #DEFINE_ALIAS # from .manipulation import unbind #DEFINE_ALIAS diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 48f694f400e7d3661face49e4d848d7abd9cf6f2..4c982c4fabb46ccd181eef47229b8f76802a2a1e 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -18,7 +18,8 @@ from ..fluid.layers import core, reshape from ..fluid.layer_helper import LayerHelper from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_ from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype - +from ..fluid.layers.tensor import fill_constant +from ..fluid.layers import utils # TODO: define functions to manipulate a tensor __all__ = [ # 'cast', @@ -26,7 +27,7 @@ __all__ = [ # 'expand', # 'expand_as', # 'flatten', - # 'gather', + 'gather', # 'gather_nd', # 'reshape', # 'reverse', @@ -35,14 +36,14 @@ __all__ = [ # 'scatter_nd', # 'shard_index', # 'slice', - # 'split', - # 'squeeze', - # 'stack', + 'split', + 'squeeze', + 'stack', # 'strided_slice', # 'transpose', # 'unique', # 'unique_with_counts', - # 'unsqueeze', + 'unsqueeze', # 'unstack', 'flip', # 'unbind', @@ -169,3 +170,476 @@ def roll(input, shifts, dims=None): 'shifts': shifts}) out = reshape(out, shape=origin_shape, inplace=True) return out + + +def stack(x, axis=0, out=None, name=None): + """ + + This OP stacks all the inputs :code:`x` along axis. + + .. code-block:: text + + Case 1: + + Input: + x[0].shape = [1, 2] + x[0].data = [ [1.0 , 2.0 ] ] + x[1].shape = [1, 2] + x[1].data = [ [3.0 , 4.0 ] ] + x[2].shape = [1, 2] + x[2].data = [ [5.0 , 6.0 ] ] + + Attrs: + axis = 0 + + Output: + Out.dims = [3, 1, 2] + Out.data =[ [ [1.0, 2.0] ], + [ [3.0, 4.0] ], + [ [5.0, 6.0] ] ] + + + Case 2: + + + Input: + x[0].shape = [1, 2] + x[0].data = [ [1.0 , 2.0 ] ] + x[1].shape = [1, 2] + x[1].data = [ [3.0 , 4.0 ] ] + x[2].shape = [1, 2] + x[2].data = [ [5.0 , 6.0 ] ] + + + Attrs: + axis = 1 or axis = -2 + + Output: + Out.shape = [1, 3, 2] + Out.data =[ [ [1.0, 2.0] + [3.0, 4.0] + [5.0, 6.0] ] ] + + Args: + x (Variable|list(Variable)): Input :code:`x` can be a single Tensor, a :code:`list` of Tensors. + If :code:`x` is a :code:`list`, the shapes of all these Tensors + must be the same. Supposing input is N dims + Tensors :math:`[d_0, d_1, ..., d_{n-1}]`, the output is N+1 dims + Tensor :math:`[d_0, d_1, d_{axis-1}, len(x), d_{axis}, ..., d_{n-1}]`. + Support data types: float32, float64, int32, int64. + axis (int, optional): The axis along which all inputs are stacked. ``axis`` range is :math:`[-(R+1), R+1)`. + R is the first tensor of inputs. If ``axis`` < 0, :math:`axis=axis+rank(x[0])+1`. + The default value of axis is 0. + + Returns: + Variable: The stacked Tensor, has same data type with input Tensors. Output dim is :math:`rank(x[0])+1`. + + Example: + .. code-block:: python + import numpy as np + import paddle + import paddle.fluid as fluid + + data1 = np.array([[1.0, 2.0]]) + data2 = np.array([[3.0, 4.0]]) + data3 = np.array([[5.0, 6.0]]) + with fluid.dygraph.guard(): + x1 = fluid.dygraph.to_variable(data1) + x2 = fluid.dygraph.to_variable(data2) + x3 = fluid.dygraph.to_variable(data3) + result = paddle.stack([x1, x2, x3], axis=0) + # result shape: [3, 1, 2] + # result value: [[[1.0, 2.0]], + # [[3.0, 4.0]], + # [[5.0, 6.0]]] + """ + + helper = LayerHelper('stack', **locals()) + axis = 0 if axis is None else axis + + if not isinstance(x, list) and not isinstance(x, tuple): + x = [x] + out = helper.create_variable_for_type_inference(x[0].dtype) + if not in_dygraph_mode() and \ + x[0].desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY: + assert len(x) == 1, "If the elements of 'x' in stack are Variable(LoDTensorArray), " \ + "number of the elements must be 1, but received %s." % len(x) + out_index = helper.create_variable_for_type_inference(dtype="int32") + helper.append_op( + type='tensor_array_to_tensor', + inputs={'X': x[0]}, + outputs={'Out': [out], + 'OutIndex': [out_index]}, + attrs={'axis': axis, + 'use_stack': True}) + else: + helper.append_op( + type='stack', + inputs={'X': x}, + outputs={'Y': out}, + attrs={'axis': axis}) + + return out + + +def split(input, num_or_sections, dim=-1, name=None): + """ + Split the input tensor into multiple sub-Tensors. + Args: + input (Variable): The input variable which is an N-D Tensor or LoDTensor, data type being float32, float64, int32 or int64. + num_or_sections (int|list|tuple): If :attr:`num_or_sections` is an integer, + then the integer indicates the number of equal sized sub-Tensors + that the Tensor will be divided into. If :attr:`num_or_sections` + is a list or tuple, the length of it indicates the number of + sub-Tensors and the elements in it indicate the sizes of sub-Tensors' + :attr:`dim` dimension orderly. The length of the list mustn't be larger than the Tensor's size of :attr:`dim` . + dim (int32|Varible, optional): A scalar with type ``int32`` or a ``Tensor`` with shape [1] and type ``int32``. The dimension along which to split. If :math:`dim < 0`, the + dimension to split along is :math:`rank(input) + dim`. Default is -1. + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . + Returns: + list(Variable): The list of segmented Tensor variables. + Raises: + TypeError: num_or_sections is not int, list or tuple. + TypeError: dim is not int or Variable. + Example: + .. code-block:: python + import numpy as np + import paddle + import paddle.fluid as fluid + + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("int32") + # input is a variable which shape is [4, 6, 6] + input = fluid.dygraph.to_variable(input_1) + + x0, x1, x2 = paddle.split(input, num_or_sections=3, dim=1) + # x0.shape [4, 2, 6] + # x1.shape [4, 2, 6] + # x2.shape [4, 2, 6] + """ + if in_dygraph_mode(): + num = None + attrs = () + + if isinstance(dim, Variable): + dim = dim.numpy() + assert dim.shape == (1, + ), "dim of type Variable should have shape [1]" + dim = dim[0] + dim = (len(input.shape) + dim) if dim < 0 else dim + attrs += ('axis', dim) + + if isinstance(num_or_sections, int): + num = num_or_sections + attrs += ('num', num_or_sections) + elif isinstance(num_or_sections, (list, tuple)): + num = len(num_or_sections) + if utils._contain_var(num_or_sections): + raise TypeError( + "The type of 'num_or_sections' in split must be int or list[int] or tuple[int] in Dygraph mode, but " + "received %s, which contains Variable." % + (type(num_or_sections))) + else: + attrs += ('sections', list(num_or_sections)) + else: + raise TypeError( + "The type of 'num_or_sections' in split must be int or list in Dygraph mode, but " + "received %s." % (type(num_or_sections))) + return core.ops.split(input, num, *attrs) + + if not isinstance(num_or_sections, (int, list, tuple)): + raise TypeError( + "The type of 'num_or_sections' in split must be int, list or " + "tuple, but received %s." % (type(num_or_sections))) + if not isinstance(dim, (int, Variable)): + raise TypeError( + "The type of 'dim' in split must be int or Variable, but " + "received %s." % (type(dim))) + + helper = LayerHelper('split', **locals()) + input_shape = input.shape + inputs = {'X': input} + attrs = {'num': num_or_sections if isinstance(num_or_sections, int) else 0} + + def _get_SectionsTensorList(one_list): + tensor_list = [] + unk_dim_idx = -1 + for idx, dim_size in enumerate(one_list): + if isinstance(dim_size, Variable): + dim_size.stop_gradient = True + tensor_list.append(dim_size) + else: + assert (isinstance(dim_size, int)) + if dim_size == -1: + assert unk_dim_idx == -1, ( + "Only one value of 'num_or_section' in split can " + "be -1. But received num_or_section[%d] is also -1." % + idx) + unk_dim_idx = idx + temp_out = helper.create_variable_for_type_inference('int32') + fill_constant( + [1], 'int32', dim_size, force_cpu=True, out=temp_out) + tensor_list.append(temp_out) + return tensor_list + + if isinstance(dim, Variable): + dim.stop_gradient = True + inputs['AxisTensor'] = dim + else: + dim = (len(input_shape) + dim) if dim < 0 else dim + attrs['axis'] = dim + + if isinstance(num_or_sections, int): + assert num_or_sections > 1, 'num_or_sections must be more than 1.' + if isinstance(dim, int) and input_shape[dim] > 0: + assert input_shape[dim] % num_or_sections ==0, \ + "The input's size along the split dimension " \ + "must be evenly divisible by Attr(num_or_sections). " \ + "But %d is not evenly divisible by %d. " % (num_or_sections,input_shape[dim]) + num = num_or_sections + else: + if isinstance(dim, int) and input_shape[dim] > 0: + assert len(num_or_sections) <= input_shape[ + dim], 'len(num_or_sections) must not be more than input.shape[dim].' + num = len(num_or_sections) + attrs['sections'] = list( + map(lambda ele: -1 if isinstance(ele, Variable) else ele, + num_or_sections)) + if utils._contain_var(num_or_sections): + inputs['SectionsTensorList'] = _get_SectionsTensorList( + num_or_sections) + + outs = [ + helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + for i in range(num) + ] + helper.append_op( + type='split', inputs=inputs, outputs={'Out': outs}, attrs=attrs) + return outs + + +def squeeze(input, axes, out=None, name=None): + """ + This OP will squeeze single-dimensional entries of input tensor's shape. If axes is provided, will + remove the dims by axes, the dims selected by axes should be one. If not provide axes, all dims equal + to one will be deleted. + + + .. code-block:: text + + Case1: + + Input: + X.shape = (1, 3, 1, 5) + axes = [0] + Output: + Out.shape = (3, 1, 5) + + Case2: + + Input: + X.shape = (1, 3, 1, 5) + axes = [] + Output: + Out.shape = (3, 5) + + Case3: + + Input: + X.shape = [1,3,1,5] + axes = [-2] + Output: + Out.shape = [1,3,5] + + Args: + input (Variable): The input Tensor. Support data type: float32, float64, int8, int32, int64. + axes (list): One integer or List of integers, indicating the dimensions to be squeezed. + Axes range is :math:`[-rank(input), rank(input))`. + If axes is negative, :math:`axes=axes+rank(input)`. + name (str, optional): Please refer to :ref:`api_guide_Name`, Default None. + + Returns: + Variable: Output squeezed Tensor. Data type is same as input Tensor. + + Examples: + .. code-block:: python + import numpy as np + import paddle + import paddle.fluid as fluid + + with fluid.dygraph.guard(): + input_1 = np.random.random([5, 1, 10]).astype("int32") + # input is a variable which shape is [5, 1, 10] + input = fluid.dygraph.to_variable(input_1) + + output = paddle.squeeze(input, axes=[1]) + # output.shape [5, 10] + + """ + + helper = LayerHelper("squeeze", **locals()) + check_variable_and_dtype(input, 'input', + ['float32', 'float64', 'int8', 'int32', 'int64'], + 'squeeze') + check_type(axes, 'axes', list, 'squeeze') + out = helper.create_variable_for_type_inference(dtype=input.dtype) + x_shape = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type="squeeze2", + inputs={"X": input}, + attrs={"axes": axes}, + outputs={"Out": out, + "XShape": x_shape}) + + return out + + +def unsqueeze(input, axes, out=None, name=None): + """ + Insert single-dimensional entries to the shape of a Tensor. Takes one + required argument axes, a list of dimensions that will be inserted. + Dimension indices in axes are as seen in the output tensor. + + For example: + + .. code-block:: text + + Given a tensor such that tensor with shape [3, 4, 5], + then Unsqueezed tensor with axes=[0, 4] has shape [1, 3, 4, 5, 1]. + + Args: + input (Variable): The input Tensor to be unsqueezed. It is a N-D Tensor of data types float32, float64, int32. + axes (int|list|tuple|Variable): Indicates the dimensions to be inserted. The data type is ``int32`` . If ``axes`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If ``axes`` is an Variable, it should be an 1-D Tensor . + name (str|None): Name for this layer. + + Returns: + Variable: Output unsqueezed Tensor, with data type being float32, float64, int32, int64. + + Examples: + .. code-block:: python + import numpy as np + import paddle + import paddle.fluid as fluid + + with fluid.dygraph.guard(): + input_1 = np.random.random([5, 10]).astype("int32") + # input is a variable which shape is [5, 10] + input = fluid.dygraph.to_variable(input_1) + + output = paddle.unsqueeze(input, axes=[1]) + # output.shape [5, 1, 10] + """ + if not isinstance(axes, (int, list, tuple, Variable)): + raise TypeError( + "The type of 'axes' in unsqueeze must be int, list, tuple or Variable, but " + "received %s." % (type(axes))) + helper = LayerHelper("unsqueeze2", **locals()) + inputs = {"X": input} + attrs = {} + + def _to_Variable_list(one_list): + Variable_list = [] + for ele in one_list: + if isinstance(ele, Variable): + ele.stop_gradient = True + Variable_list.append(ele) + else: + assert (isinstance(ele, int)) + temp_out = helper.create_variable_for_type_inference('int32') + fill_constant([1], 'int32', ele, force_cpu=True, out=temp_out) + Variable_list.append(temp_out) + return Variable_list + + if isinstance(axes, int): + axes = [axes] + if isinstance(axes, Variable): + axes.stop_gradient = True + inputs["AxesTensor"] = axes + elif isinstance(axes, (list, tuple)): + contain_var = not all(not isinstance(ele, Variable) for ele in axes) + if contain_var: + inputs["AxesTensorList"] = _to_Variable_list(axes) + else: + attrs["axes"] = axes + + out = helper.create_variable_for_type_inference(dtype=input.dtype) + x_shape = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type="unsqueeze2", + inputs=inputs, + attrs=attrs, + outputs={"Out": out, + "XShape": x_shape}) + + return out + + +def gather(input, index, overwrite=True): + """ + **Gather Layer** + + Output is obtained by gathering entries of the outer-most dimension + of X indexed by `index` and concatenate them together. + + .. math:: + + Out = X[Index] + + + .. code-block:: text + + + Given: + + X = [[1, 2], + [3, 4], + [5, 6]] + + Index = [1, 2] + + Then: + + Out = [[3, 4], + [5, 6]] + Args: + input (Variable): The source input tensor with rank>=1. Supported data type is + int32, int64, float32, float64 and uint8 (only for CPU), + float16 (only for GPU). + index (Variable): The index input tensor with rank=1. Data type is int32 or int64. + overwrite (bool, optional): The mode that updating the grad when has same index. + If True, use the overwrite mode to update the grad of the same index, + if False, use the accumulate mode to update the grad of the same index. + Default value is True. + + + + Returns: + output (Variable): The output is a tensor with the same rank as input. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + import paddle.fluid as fluid + + + with fluid.dygraph.guard(): + input_1 = np.array([[1,2],[3,4],[5,6]]) + index_1 = np.array([0,1]) + input = fluid.dygraph.to_variable(input_1) + index = fluid.dygraph.to_variable(index_1) + output = paddle.gather(input, index) + # expected output: [[1,2],[3,4]] + """ + helper = LayerHelper('gather', **locals()) + dtype = helper.input_dtype() + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="gather", + inputs={"X": input, + "Index": index}, + outputs={"Out": out}, + attrs={'overwrite': overwrite}) + return out