未验证 提交 a43b0d15 编写于 作者: W wawltor 提交者: GitHub

Update the code for the sort api

update the sort api, delete unused ouput index tensor
上级 dca56f47
......@@ -37,7 +37,7 @@ class TestSortOnCPU(unittest.TestCase):
[[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]],
[[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]],
dtype='float32')
result, = exe.run(feed={'input': data}, fetch_list=[output[0]])
result, = exe.run(feed={'input': data}, fetch_list=[output])
np_result = np.sort(result)
self.assertEqual((result == np_result).all(), True)
......@@ -50,7 +50,7 @@ class TestSortOnCPU(unittest.TestCase):
[[[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]],
[[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]]],
dtype='float32')
result, = exe.run(feed={'input': data}, fetch_list=[output[0]])
result, = exe.run(feed={'input': data}, fetch_list=[output])
np_result = np.sort(result, axis=1)
self.assertEqual((result == np_result).all(), True)
......@@ -75,7 +75,7 @@ class TestSortDygraph(unittest.TestCase):
with imperative.guard(self.place):
var_x = imperative.to_variable(self.input_data)
out = paddle.sort(var_x)
self.assertEqual((np.sort(self.input_data) == out[0].numpy()).all(),
self.assertEqual((np.sort(self.input_data) == out.numpy()).all(),
True)
def test_api_1(self):
......@@ -84,5 +84,4 @@ class TestSortDygraph(unittest.TestCase):
out = paddle.sort(var_x, axis=-1)
self.assertEqual(
(np.sort(
self.input_data, axis=-1) == out[0].numpy()).all(),
True)
self.input_data, axis=-1) == out.numpy()).all(), True)
......@@ -46,8 +46,7 @@ def argsort(x, axis=-1, descending=False, name=None):
:alias_main: paddle.argsort
:alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort
This OP sorts the input along the given axis, and returns sorted output
data Varibale and its corresponding index Variable with the same shape as ``x``.
This OP sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.
Args:
x(Tensor): An input N-D Tensor with type float32, float64, int16,
......@@ -381,8 +380,7 @@ def sort(x, axis=-1, descending=False, name=None):
:alias_main: paddle.sort
:alias: paddle.sort,paddle.tensor.sort,paddle.tensor.search.sort
This OP sorts the input along the given axis, and returns sorted output
data Tensor and its corresponding index Tensor with the same shape as ``x``.
This OP sorts the input along the given axis, and returns the sorted output tensor. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.
Args:
x(Tensor): An input N-D Tensor with type float32, float64, int16,
......@@ -397,9 +395,7 @@ def sort(x, axis=-1, descending=False, name=None):
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
tuple: A tuple of sorted data tensor(with the same shape and data
type as ``x``) and the sorted indices(with the same shape as ``x``
and with data type int64).
Tensor: sorted tensor(with the same shape and data type as ``x``).
Examples:
.. code-block:: python
import paddle
......@@ -417,28 +413,21 @@ def sort(x, axis=-1, descending=False, name=None):
out1 = paddle.sort(x=x, axis=-1)
out2 = paddle.sort(x=x, axis=0)
out3 = paddle.sort(x=x, axis=1)
print(out1[0].numpy())
print(out1.numpy())
#[[[5. 5. 8. 9.]
# [0. 0. 1. 7.]
# [2. 4. 6. 9.]]
# [[2. 2. 4. 5.]
# [4. 7. 7. 9.]
# [0. 1. 6. 7.]]]
print(out1[1].numpy())
#[[[0 3 1 2]
# [0 1 2 3]
# [2 3 0 1]]
# [[1 3 2 0]
# [0 1 2 3]
# [2 0 3 1]]]
print(out2[0].numpy())
print(out2.numpy())
#[[[5. 2. 4. 2.]
# [0. 0. 1. 7.]
# [1. 7. 0. 4.]]
# [[5. 8. 9. 5.]
# [4. 7. 7. 9.]
# [6. 9. 2. 6.]]]
print(out3[0].numpy())
print(out3.numpy())
#[[[0. 0. 1. 4.]
# [5. 8. 2. 5.]
# [6. 9. 9. 7.]]
......@@ -447,8 +436,8 @@ def sort(x, axis=-1, descending=False, name=None):
# [5. 7. 7. 9.]]]
"""
if in_dygraph_mode():
out, ids = core.ops.argsort(x, 'axis', axis, 'descending', descending)
return out, ids
out, _ = core.ops.argsort(x, 'axis', axis, 'descending', descending)
return out
helper = LayerHelper("sort", **locals())
out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=False)
......@@ -461,7 +450,7 @@ def sort(x, axis=-1, descending=False, name=None):
'Indices': ids},
attrs={'axis': axis,
'descending': descending})
return out, ids
return out
def where(condition, x, y, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册