diff --git a/python/paddle/fluid/tests/unittests/test_sort_op.py b/python/paddle/fluid/tests/unittests/test_sort_op.py index 990c7a8b2dfb68ef7d2365a8f2918cd68692a216..087586aa89607a58493c2d4427cbb6d30b31f0da 100644 --- a/python/paddle/fluid/tests/unittests/test_sort_op.py +++ b/python/paddle/fluid/tests/unittests/test_sort_op.py @@ -37,7 +37,7 @@ class TestSortOnCPU(unittest.TestCase): [[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]], [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]], dtype='float32') - result, = exe.run(feed={'input': data}, fetch_list=[output[0]]) + result, = exe.run(feed={'input': data}, fetch_list=[output]) np_result = np.sort(result) self.assertEqual((result == np_result).all(), True) @@ -50,7 +50,7 @@ class TestSortOnCPU(unittest.TestCase): [[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]], [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]], dtype='float32') - result, = exe.run(feed={'input': data}, fetch_list=[output[0]]) + result, = exe.run(feed={'input': data}, fetch_list=[output]) np_result = np.sort(result, axis=1) self.assertEqual((result == np_result).all(), True) @@ -75,7 +75,7 @@ class TestSortDygraph(unittest.TestCase): with imperative.guard(self.place): var_x = imperative.to_variable(self.input_data) out = paddle.sort(var_x) - self.assertEqual((np.sort(self.input_data) == out[0].numpy()).all(), + self.assertEqual((np.sort(self.input_data) == out.numpy()).all(), True) def test_api_1(self): @@ -84,5 +84,4 @@ class TestSortDygraph(unittest.TestCase): out = paddle.sort(var_x, axis=-1) self.assertEqual( (np.sort( - self.input_data, axis=-1) == out[0].numpy()).all(), - True) + self.input_data, axis=-1) == out.numpy()).all(), True) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index cffaae6153cf79b90e22afa103fcd11d8bfaa402..1cb775c9d4b73beaf0f2167fe7fc9909e91d116d 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -46,8 +46,7 @@ def argsort(x, axis=-1, descending=False, name=None): :alias_main: paddle.argsort :alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort - This OP sorts the input along the given axis, and returns sorted output - data Varibale and its corresponding index Variable with the same shape as ``x``. + This OP sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True. Args: x(Tensor): An input N-D Tensor with type float32, float64, int16, @@ -84,26 +83,26 @@ def argsort(x, axis=-1, descending=False, name=None): out2 = paddle.argsort(x=x, axis=0) out3 = paddle.argsort(x=x, axis=1) print(out1.numpy()) - #[[[0 3 1 2] - # [0 1 2 3] - # [2 3 0 1]] + #[[[0 3 1 2] + # [0 1 2 3] + # [2 3 0 1]] # [[1 3 2 0] - # [0 1 2 3] - # [2 0 3 1]]] + # [0 1 2 3] + # [2 0 3 1]]] print(out2.numpy()) - #[[[0 1 1 1] - # [0 0 0 0] - # [1 1 1 0]] - # [[1 0 0 0] - # [1 1 1 1] - # [0 0 0 1]]] + #[[[0 1 1 1] + # [0 0 0 0] + # [1 1 1 0]] + # [[1 0 0 0] + # [1 1 1 1] + # [0 0 0 1]]] print(out3.numpy()) - #[[[1 1 1 2] - # [0 0 2 0] - # [2 2 0 1]] - # [[2 0 2 0] - # [1 1 0 2] - # [0 2 1 1]]] + #[[[1 1 1 2] + # [0 0 2 0] + # [2 2 0 1]] + # [[2 0 2 0] + # [1 1 0 2] + # [0 2 1 1]]] """ if in_dygraph_mode(): _, ids = core.ops.argsort(x, 'axis', axis, 'descending', descending) @@ -381,8 +380,7 @@ def sort(x, axis=-1, descending=False, name=None): :alias_main: paddle.sort :alias: paddle.sort,paddle.tensor.sort,paddle.tensor.search.sort - This OP sorts the input along the given axis, and returns sorted output - data Tensor and its corresponding index Tensor with the same shape as ``x``. + This OP sorts the input along the given axis, and returns the sorted output tensor. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True. Args: x(Tensor): An input N-D Tensor with type float32, float64, int16, @@ -397,9 +395,7 @@ def sort(x, axis=-1, descending=False, name=None): need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - tuple: A tuple of sorted data tensor(with the same shape and data - type as ``x``) and the sorted indices(with the same shape as ``x`` - and with data type int64). + Tensor: sorted tensor(with the same shape and data type as ``x``). Examples: .. code-block:: python import paddle @@ -417,38 +413,31 @@ def sort(x, axis=-1, descending=False, name=None): out1 = paddle.sort(x=x, axis=-1) out2 = paddle.sort(x=x, axis=0) out3 = paddle.sort(x=x, axis=1) - print(out1[0].numpy()) - #[[[5. 5. 8. 9.] - # [0. 0. 1. 7.] - # [2. 4. 6. 9.]] - # [[2. 2. 4. 5.] - # [4. 7. 7. 9.] - # [0. 1. 6. 7.]]] - print(out1[1].numpy()) - #[[[0 3 1 2] - # [0 1 2 3] - # [2 3 0 1]] - # [[1 3 2 0] - # [0 1 2 3] - # [2 0 3 1]]] - print(out2[0].numpy()) + print(out1.numpy()) + #[[[5. 5. 8. 9.] + # [0. 0. 1. 7.] + # [2. 4. 6. 9.]] + # [[2. 2. 4. 5.] + # [4. 7. 7. 9.] + # [0. 1. 6. 7.]]] + print(out2.numpy()) #[[[5. 2. 4. 2.] - # [0. 0. 1. 7.] - # [1. 7. 0. 4.]] - # [[5. 8. 9. 5.] - # [4. 7. 7. 9.] - # [6. 9. 2. 6.]]] - print(out3[0].numpy()) + # [0. 0. 1. 7.] + # [1. 7. 0. 4.]] + # [[5. 8. 9. 5.] + # [4. 7. 7. 9.] + # [6. 9. 2. 6.]]] + print(out3.numpy()) #[[[0. 0. 1. 4.] - # [5. 8. 2. 5.] - # [6. 9. 9. 7.]] - # [[1. 2. 0. 2.] - # [4. 7. 4. 6.] - # [5. 7. 7. 9.]]] + # [5. 8. 2. 5.] + # [6. 9. 9. 7.]] + # [[1. 2. 0. 2.] + # [4. 7. 4. 6.] + # [5. 7. 7. 9.]]] """ if in_dygraph_mode(): - out, ids = core.ops.argsort(x, 'axis', axis, 'descending', descending) - return out, ids + out, _ = core.ops.argsort(x, 'axis', axis, 'descending', descending) + return out helper = LayerHelper("sort", **locals()) out = helper.create_variable_for_type_inference( dtype=x.dtype, stop_gradient=False) @@ -461,7 +450,7 @@ def sort(x, axis=-1, descending=False, name=None): 'Indices': ids}, attrs={'axis': axis, 'descending': descending}) - return out, ids + return out def where(condition, x, y, name=None):