From a43b0d155de115e3bfb551c6b86321fc01ad6a8f Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 5 Aug 2020 21:55:20 +0800 Subject: [PATCH] Update the code for the sort api update the sort api, delete unused ouput index tensor --- .../fluid/tests/unittests/test_sort_op.py | 9 +- python/paddle/tensor/search.py | 95 ++++++++----------- 2 files changed, 46 insertions(+), 58 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_sort_op.py b/python/paddle/fluid/tests/unittests/test_sort_op.py index 990c7a8b2df..087586aa896 100644 --- a/python/paddle/fluid/tests/unittests/test_sort_op.py +++ b/python/paddle/fluid/tests/unittests/test_sort_op.py @@ -37,7 +37,7 @@ class TestSortOnCPU(unittest.TestCase): [[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]], [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]], dtype='float32') - result, = exe.run(feed={'input': data}, fetch_list=[output[0]]) + result, = exe.run(feed={'input': data}, fetch_list=[output]) np_result = np.sort(result) self.assertEqual((result == np_result).all(), True) @@ -50,7 +50,7 @@ class TestSortOnCPU(unittest.TestCase): [[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]], [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]], dtype='float32') - result, = exe.run(feed={'input': data}, fetch_list=[output[0]]) + result, = exe.run(feed={'input': data}, fetch_list=[output]) np_result = np.sort(result, axis=1) self.assertEqual((result == np_result).all(), True) @@ -75,7 +75,7 @@ class TestSortDygraph(unittest.TestCase): with imperative.guard(self.place): var_x = imperative.to_variable(self.input_data) out = paddle.sort(var_x) - self.assertEqual((np.sort(self.input_data) == out[0].numpy()).all(), + self.assertEqual((np.sort(self.input_data) == out.numpy()).all(), True) def test_api_1(self): @@ -84,5 +84,4 @@ class TestSortDygraph(unittest.TestCase): out = paddle.sort(var_x, axis=-1) self.assertEqual( (np.sort( - self.input_data, axis=-1) == out[0].numpy()).all(), - True) + self.input_data, axis=-1) == out.numpy()).all(), True) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index cffaae6153c..1cb775c9d4b 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -46,8 +46,7 @@ def argsort(x, axis=-1, descending=False, name=None): :alias_main: paddle.argsort :alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort - This OP sorts the input along the given axis, and returns sorted output - data Varibale and its corresponding index Variable with the same shape as ``x``. + This OP sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True. Args: x(Tensor): An input N-D Tensor with type float32, float64, int16, @@ -84,26 +83,26 @@ def argsort(x, axis=-1, descending=False, name=None): out2 = paddle.argsort(x=x, axis=0) out3 = paddle.argsort(x=x, axis=1) print(out1.numpy()) - #[[[0 3 1 2] - # [0 1 2 3] - # [2 3 0 1]] + #[[[0 3 1 2] + # [0 1 2 3] + # [2 3 0 1]] # [[1 3 2 0] - # [0 1 2 3] - # [2 0 3 1]]] + # [0 1 2 3] + # [2 0 3 1]]] print(out2.numpy()) - #[[[0 1 1 1] - # [0 0 0 0] - # [1 1 1 0]] - # [[1 0 0 0] - # [1 1 1 1] - # [0 0 0 1]]] + #[[[0 1 1 1] + # [0 0 0 0] + # [1 1 1 0]] + # [[1 0 0 0] + # [1 1 1 1] + # [0 0 0 1]]] print(out3.numpy()) - #[[[1 1 1 2] - # [0 0 2 0] - # [2 2 0 1]] - # [[2 0 2 0] - # [1 1 0 2] - # [0 2 1 1]]] + #[[[1 1 1 2] + # [0 0 2 0] + # [2 2 0 1]] + # [[2 0 2 0] + # [1 1 0 2] + # [0 2 1 1]]] """ if in_dygraph_mode(): _, ids = core.ops.argsort(x, 'axis', axis, 'descending', descending) @@ -381,8 +380,7 @@ def sort(x, axis=-1, descending=False, name=None): :alias_main: paddle.sort :alias: paddle.sort,paddle.tensor.sort,paddle.tensor.search.sort - This OP sorts the input along the given axis, and returns sorted output - data Tensor and its corresponding index Tensor with the same shape as ``x``. + This OP sorts the input along the given axis, and returns the sorted output tensor. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True. Args: x(Tensor): An input N-D Tensor with type float32, float64, int16, @@ -397,9 +395,7 @@ def sort(x, axis=-1, descending=False, name=None): need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: - tuple: A tuple of sorted data tensor(with the same shape and data - type as ``x``) and the sorted indices(with the same shape as ``x`` - and with data type int64). + Tensor: sorted tensor(with the same shape and data type as ``x``). Examples: .. code-block:: python import paddle @@ -417,38 +413,31 @@ def sort(x, axis=-1, descending=False, name=None): out1 = paddle.sort(x=x, axis=-1) out2 = paddle.sort(x=x, axis=0) out3 = paddle.sort(x=x, axis=1) - print(out1[0].numpy()) - #[[[5. 5. 8. 9.] - # [0. 0. 1. 7.] - # [2. 4. 6. 9.]] - # [[2. 2. 4. 5.] - # [4. 7. 7. 9.] - # [0. 1. 6. 7.]]] - print(out1[1].numpy()) - #[[[0 3 1 2] - # [0 1 2 3] - # [2 3 0 1]] - # [[1 3 2 0] - # [0 1 2 3] - # [2 0 3 1]]] - print(out2[0].numpy()) + print(out1.numpy()) + #[[[5. 5. 8. 9.] + # [0. 0. 1. 7.] + # [2. 4. 6. 9.]] + # [[2. 2. 4. 5.] + # [4. 7. 7. 9.] + # [0. 1. 6. 7.]]] + print(out2.numpy()) #[[[5. 2. 4. 2.] - # [0. 0. 1. 7.] - # [1. 7. 0. 4.]] - # [[5. 8. 9. 5.] - # [4. 7. 7. 9.] - # [6. 9. 2. 6.]]] - print(out3[0].numpy()) + # [0. 0. 1. 7.] + # [1. 7. 0. 4.]] + # [[5. 8. 9. 5.] + # [4. 7. 7. 9.] + # [6. 9. 2. 6.]]] + print(out3.numpy()) #[[[0. 0. 1. 4.] - # [5. 8. 2. 5.] - # [6. 9. 9. 7.]] - # [[1. 2. 0. 2.] - # [4. 7. 4. 6.] - # [5. 7. 7. 9.]]] + # [5. 8. 2. 5.] + # [6. 9. 9. 7.]] + # [[1. 2. 0. 2.] + # [4. 7. 4. 6.] + # [5. 7. 7. 9.]]] """ if in_dygraph_mode(): - out, ids = core.ops.argsort(x, 'axis', axis, 'descending', descending) - return out, ids + out, _ = core.ops.argsort(x, 'axis', axis, 'descending', descending) + return out helper = LayerHelper("sort", **locals()) out = helper.create_variable_for_type_inference( dtype=x.dtype, stop_gradient=False) @@ -461,7 +450,7 @@ def sort(x, axis=-1, descending=False, name=None): 'Indices': ids}, attrs={'axis': axis, 'descending': descending}) - return out, ids + return out def where(condition, x, y, name=None): -- GitLab