diff --git a/paddle/fluid/operators/gather_nd_op.cc b/paddle/fluid/operators/gather_nd_op.cc index c22c8a18ca63a05265ac6991cf0e0cbd9e7ea5ed..1427bd04d3442be26be931ca31bf358ebd23efae 100644 --- a/paddle/fluid/operators/gather_nd_op.cc +++ b/paddle/fluid/operators/gather_nd_op.cc @@ -45,7 +45,7 @@ class GatherNdOp : public framework::OperatorWithKernel { index_dims[index_dims_size - 1], x_dims_size, platform::errors::InvalidArgument( "Input(Index).shape[-1] should be no greater than Input(X).rank")); - PADDLE_ENFORCE_GE(index_dims_size, 2UL, + PADDLE_ENFORCE_GE(index_dims_size, 1UL, platform::errors::InvalidArgument( "The rank of Input(Index) should be greater than 1")); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 20ccd91666190454e4711fd07bfe259c518f01d7..81dfc34f3d54c41b0fae2791ffe2810e0b77bbec 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -8323,14 +8323,18 @@ def gather_nd(input, index, name=None): = [23] Args: - input (Variable): The source input. Its dtype should be int32, int64, float32, float64. - index (Variable): The index input with rank > 1, index.shape[-1] <= input.rank. - Its dtype should be int32, int64. - name (str|None): A name for this layer(optional). If set None, the - layer will be named automatically. + input (Tensor): The source input. Its dtype should be bool, float32, float64, int32, int64. + index (Tensor): The index input with rank > 1, index.shape[-1] <= input.rank. + Its dtype should be int32, int64. + name(str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . Returns: - output (Variable): A tensor with the shape index.shape[:-1] + input.shape[index.shape[-1]:] + output (Tensor): A tensor with the shape index.shape[:-1] + input.shape[index.shape[-1]:] + + Raises: + TypeError: ``input`` must be a Tensor and the data type of ``input`` must be one of float32, float64, int32 and int64. + TypeError: ``index`` must be a Tensor and the data type of ``index`` must be one of int32 and int64. Examples: @@ -8342,6 +8346,12 @@ def gather_nd(input, index, name=None): output = fluid.layers.gather_nd(x, index) """ + if in_dygraph_mode(): + return core.ops.gather_nd(input, index) + check_variable_and_dtype(input, 'input', + ['bool', 'float32', 'float64', 'int32', 'int64'], + 'gather_np') + check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather_np') helper = LayerHelper('gather_nd', **locals()) dtype = helper.input_dtype() output = helper.create_variable_for_type_inference(dtype) diff --git a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py index 892f63bf15b742c51ddbc15262f888e43cdd03f3..bd934c76ebfa2ed7c9b11223b34c812e605ebe18 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py @@ -18,12 +18,11 @@ import unittest import numpy as np from op_test import OpTest import paddle.fluid as fluid +import paddle class TestGatherNdOpWithEmptyIndex(OpTest): - """ - Index has empty element, which means copy entire tensor - """ + #Index has empty element, which means copy entire tensor def setUp(self): self.op_type = "gather_nd" @@ -40,10 +39,22 @@ class TestGatherNdOpWithEmptyIndex(OpTest): self.check_grad(['X'], 'Out') +class TestGatherNdOpWithIndex1(OpTest): + def setUp(self): + self.op_type = "gather_nd" + xnp = np.random.random((5, 20)).astype("float64") + self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")} + self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + class TestGatherNdOpWithLowIndex(OpTest): - """ - Index has low rank, X has high rank - """ + #Index has low rank, X has high rank def setUp(self): self.op_type = "gather_nd" @@ -61,10 +72,27 @@ class TestGatherNdOpWithLowIndex(OpTest): self.check_grad(['X'], 'Out') +class TestGatherNdOpIndex1(OpTest): + #Index has low rank, X has high rank + + def setUp(self): + self.op_type = "gather_nd" + xnp = np.random.uniform(0, 100, (10, 10)).astype("float64") + index = np.array([1, 2]).astype("int64") + + self.inputs = {'X': xnp, 'Index': index} + + self.outputs = {'Out': xnp[tuple(index.T)]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + class TestGatherNdOpWithSameIndexAsX(OpTest): - """ - Index has same rank as X's rank - """ + #Index has same rank as X's rank def setUp(self): self.op_type = "gather_nd" @@ -82,9 +110,7 @@ class TestGatherNdOpWithSameIndexAsX(OpTest): class TestGatherNdOpWithHighRankSame(OpTest): - """ - Both Index and X have high rank, and Rank(Index) = Rank(X) - """ + #Both Index and X have high rank, and Rank(Index) = Rank(X) def setUp(self): self.op_type = "gather_nd" @@ -103,9 +129,7 @@ class TestGatherNdOpWithHighRankSame(OpTest): class TestGatherNdOpWithHighRankDiff(OpTest): - """ - Both Index and X have high rank, and Rank(Index) < Rank(X) - """ + #Both Index and X have high rank, and Rank(Index) < Rank(X) def setUp(self): self.op_type = "gather_nd" @@ -162,5 +186,63 @@ class TestGatherNdOpRaise(unittest.TestCase): self.assertRaises(IndexError, check_raise_is_test) +class TestGatherNdError(unittest.TestCase): + def test_error(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + + shape = [8, 9, 6] + x = paddle.data(shape=shape, dtype='float32', name='x') + index = paddle.data(shape=shape, dtype='bool', name='index') + index_float = paddle.data( + shape=shape, dtype='float32', name='index_float') + np_x = np.random.random(shape).astype('float32') + np_index = np.array(np.random.randint(2, size=shape, dtype=bool)) + + def test_x_type(): + paddle.gather_nd(np_x, index) + + self.assertRaises(TypeError, test_x_type) + + def test_index_type(): + paddle.gather_nd(x, np_index) + + self.assertRaises(TypeError, test_index_type) + + def test_index_dtype(): + paddle.gather_nd(x, index_float) + + self.assertRaises(TypeError, test_index_dtype) + + +class TestGatherNdAPI2(unittest.TestCase): + def test_static(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = fluid.layers.data('data1', shape=[-1, 2], dtype='float64') + index = fluid.layers.data('index', shape=[-1, 1], dtype='int32') + out = paddle.gather_nd(data1, index) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([[1]]) + result, = exe.run(feed={"data1": input, + "index": index_1}, + fetch_list=[out]) + expected_output = np.array([[3, 4]]) + self.assertTrue(np.allclose(result, expected_output)) + + def test_imperative(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([[1]]) + input = fluid.dygraph.to_variable(input_1) + index = fluid.dygraph.to_variable(index_1) + output = paddle.fluid.layers.gather(input, index) + output_np = output.numpy() + expected_output = np.array([3, 4]) + self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 9f419c6a454890b386a251943984ef668e46266e..dd161d8c5f8df232e482ead8c88b731ec40b03c8 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1222,3 +1222,88 @@ def reshape(x, shape, name=None): # the shape of out_2 is [8, 6]. """ return paddle.fluid.layers.reshape(x=x, shape=shape, name=name) + + +def gather_nd(x, index, name=None): + """ + **Gather Nd Layer** + + This function is actually a high-dimensional extension of :code:`gather` + and supports for simultaneous indexing by multiple axes. :attr:`index` is a + K-dimensional integer tensor, which is regarded as a (K-1)-dimensional + tensor of :attr:`index` into :attr:`input`, where each element defines + a slice of params: + + .. math:: + + output[(i_0, ..., i_{K-2})] = input[index[(i_0, ..., i_{K-2})]] + + Obviously, :code:`index.shape[-1] <= input.rank` . And, the output tensor has + shape :code:`index.shape[:-1] + input.shape[index.shape[-1]:]` . + + .. code-block:: text + + Given: + input = [[[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11]], + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]] + input.shape = (2, 3, 4) + + * Case 1: + index = [[1]] + + gather_nd(input, index) + = [input[1, :, :]] + = [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]] + + * Case 2: + index = [[0,2]] + + gather_nd(input, index) + = [input[0, 2, :]] + = [8, 9, 10, 11] + + * Case 3: + index = [[1, 2, 3]] + + gather_nd(input, index) + = [input[1, 2, 3]] + = [23] + + Args: + x (Tensor): The input Tensor which it's data type should be bool, float32, float64, int32, int64. + index (Tensor): The index input with rank > 1, index.shape[-1] <= input.rank. + Its dtype should be int32, int64. + name(str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + + Returns: + output (Tensor): A tensor with the shape index.shape[:-1] + input.shape[index.shape[-1]:] + + Raises: + TypeError: ``x`` must be a Tensor and the data type of ``x`` must be one of float32, float64, int32 and int64. + TypeError: ``index`` must be a Tensor and the data type of ``index`` must be one of int32 and int64. + + Examples: + + .. code-block:: python + import paddle + import numpy as np + + paddle.disable_static() + np_x = np.array([[[1, 2], [3, 4], [5, 6]], + [[7, 8], [9, 10], [11, 12]]]) + np_index = [[0, 1]] + x = paddle.to_tensor(np_x) + index = paddle.to_tensor(np_index) + + output = paddle.gather_nd(x, index) #[[3, 4]] + + """ + + return paddle.fluid.layers.gather_nd(input=x, index=index, name=name)