未验证 提交 a43b0d15 编写于 作者: W wawltor 提交者: GitHub

Update the code for the sort api

update the sort api, delete unused ouput index tensor
上级 dca56f47
...@@ -37,7 +37,7 @@ class TestSortOnCPU(unittest.TestCase): ...@@ -37,7 +37,7 @@ class TestSortOnCPU(unittest.TestCase):
[[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]], [[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]],
[[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]], [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]],
dtype='float32') dtype='float32')
result, = exe.run(feed={'input': data}, fetch_list=[output[0]]) result, = exe.run(feed={'input': data}, fetch_list=[output])
np_result = np.sort(result) np_result = np.sort(result)
self.assertEqual((result == np_result).all(), True) self.assertEqual((result == np_result).all(), True)
...@@ -50,7 +50,7 @@ class TestSortOnCPU(unittest.TestCase): ...@@ -50,7 +50,7 @@ class TestSortOnCPU(unittest.TestCase):
[[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]], [[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]],
[[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]], [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]],
dtype='float32') dtype='float32')
result, = exe.run(feed={'input': data}, fetch_list=[output[0]]) result, = exe.run(feed={'input': data}, fetch_list=[output])
np_result = np.sort(result, axis=1) np_result = np.sort(result, axis=1)
self.assertEqual((result == np_result).all(), True) self.assertEqual((result == np_result).all(), True)
...@@ -75,7 +75,7 @@ class TestSortDygraph(unittest.TestCase): ...@@ -75,7 +75,7 @@ class TestSortDygraph(unittest.TestCase):
with imperative.guard(self.place): with imperative.guard(self.place):
var_x = imperative.to_variable(self.input_data) var_x = imperative.to_variable(self.input_data)
out = paddle.sort(var_x) out = paddle.sort(var_x)
self.assertEqual((np.sort(self.input_data) == out[0].numpy()).all(), self.assertEqual((np.sort(self.input_data) == out.numpy()).all(),
True) True)
def test_api_1(self): def test_api_1(self):
...@@ -84,5 +84,4 @@ class TestSortDygraph(unittest.TestCase): ...@@ -84,5 +84,4 @@ class TestSortDygraph(unittest.TestCase):
out = paddle.sort(var_x, axis=-1) out = paddle.sort(var_x, axis=-1)
self.assertEqual( self.assertEqual(
(np.sort( (np.sort(
self.input_data, axis=-1) == out[0].numpy()).all(), self.input_data, axis=-1) == out.numpy()).all(), True)
True)
...@@ -46,8 +46,7 @@ def argsort(x, axis=-1, descending=False, name=None): ...@@ -46,8 +46,7 @@ def argsort(x, axis=-1, descending=False, name=None):
:alias_main: paddle.argsort :alias_main: paddle.argsort
:alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort :alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort
This OP sorts the input along the given axis, and returns sorted output This OP sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.
data Varibale and its corresponding index Variable with the same shape as ``x``.
Args: Args:
x(Tensor): An input N-D Tensor with type float32, float64, int16, x(Tensor): An input N-D Tensor with type float32, float64, int16,
...@@ -381,8 +380,7 @@ def sort(x, axis=-1, descending=False, name=None): ...@@ -381,8 +380,7 @@ def sort(x, axis=-1, descending=False, name=None):
:alias_main: paddle.sort :alias_main: paddle.sort
:alias: paddle.sort,paddle.tensor.sort,paddle.tensor.search.sort :alias: paddle.sort,paddle.tensor.sort,paddle.tensor.search.sort
This OP sorts the input along the given axis, and returns sorted output This OP sorts the input along the given axis, and returns the sorted output tensor. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.
data Tensor and its corresponding index Tensor with the same shape as ``x``.
Args: Args:
x(Tensor): An input N-D Tensor with type float32, float64, int16, x(Tensor): An input N-D Tensor with type float32, float64, int16,
...@@ -397,9 +395,7 @@ def sort(x, axis=-1, descending=False, name=None): ...@@ -397,9 +395,7 @@ def sort(x, axis=-1, descending=False, name=None):
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
Returns: Returns:
tuple: A tuple of sorted data tensor(with the same shape and data Tensor: sorted tensor(with the same shape and data type as ``x``).
type as ``x``) and the sorted indices(with the same shape as ``x``
and with data type int64).
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
...@@ -417,28 +413,21 @@ def sort(x, axis=-1, descending=False, name=None): ...@@ -417,28 +413,21 @@ def sort(x, axis=-1, descending=False, name=None):
out1 = paddle.sort(x=x, axis=-1) out1 = paddle.sort(x=x, axis=-1)
out2 = paddle.sort(x=x, axis=0) out2 = paddle.sort(x=x, axis=0)
out3 = paddle.sort(x=x, axis=1) out3 = paddle.sort(x=x, axis=1)
print(out1[0].numpy()) print(out1.numpy())
#[[[5. 5. 8. 9.] #[[[5. 5. 8. 9.]
# [0. 0. 1. 7.] # [0. 0. 1. 7.]
# [2. 4. 6. 9.]] # [2. 4. 6. 9.]]
# [[2. 2. 4. 5.] # [[2. 2. 4. 5.]
# [4. 7. 7. 9.] # [4. 7. 7. 9.]
# [0. 1. 6. 7.]]] # [0. 1. 6. 7.]]]
print(out1[1].numpy()) print(out2.numpy())
#[[[0 3 1 2]
# [0 1 2 3]
# [2 3 0 1]]
# [[1 3 2 0]
# [0 1 2 3]
# [2 0 3 1]]]
print(out2[0].numpy())
#[[[5. 2. 4. 2.] #[[[5. 2. 4. 2.]
# [0. 0. 1. 7.] # [0. 0. 1. 7.]
# [1. 7. 0. 4.]] # [1. 7. 0. 4.]]
# [[5. 8. 9. 5.] # [[5. 8. 9. 5.]
# [4. 7. 7. 9.] # [4. 7. 7. 9.]
# [6. 9. 2. 6.]]] # [6. 9. 2. 6.]]]
print(out3[0].numpy()) print(out3.numpy())
#[[[0. 0. 1. 4.] #[[[0. 0. 1. 4.]
# [5. 8. 2. 5.] # [5. 8. 2. 5.]
# [6. 9. 9. 7.]] # [6. 9. 9. 7.]]
...@@ -447,8 +436,8 @@ def sort(x, axis=-1, descending=False, name=None): ...@@ -447,8 +436,8 @@ def sort(x, axis=-1, descending=False, name=None):
# [5. 7. 7. 9.]]] # [5. 7. 7. 9.]]]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
out, ids = core.ops.argsort(x, 'axis', axis, 'descending', descending) out, _ = core.ops.argsort(x, 'axis', axis, 'descending', descending)
return out, ids return out
helper = LayerHelper("sort", **locals()) helper = LayerHelper("sort", **locals())
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=False) dtype=x.dtype, stop_gradient=False)
...@@ -461,7 +450,7 @@ def sort(x, axis=-1, descending=False, name=None): ...@@ -461,7 +450,7 @@ def sort(x, axis=-1, descending=False, name=None):
'Indices': ids}, 'Indices': ids},
attrs={'axis': axis, attrs={'axis': axis,
'descending': descending}) 'descending': descending})
return out, ids return out
def where(condition, x, y, name=None): def where(condition, x, y, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册