未验证 提交 5e8d7804 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] paddle.sort interface use final_state (#41934) (#41955)

* [Eager] paddle.sort interface use final_state

* Add eager test case for paddle.sort()
上级 21c333df
......@@ -21,6 +21,7 @@ import paddle.fluid.layers as layers
import numpy as np
import six
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
class TestSortOnCPU(unittest.TestCase):
......@@ -70,14 +71,19 @@ class TestSortDygraph(unittest.TestCase):
else:
self.place = core.CPUPlace()
def test_api_0(self):
def func_api_0(self):
paddle.disable_static(self.place)
var_x = paddle.to_tensor(self.input_data)
out = paddle.sort(var_x)
self.assertEqual((np.sort(self.input_data) == out.numpy()).all(), True)
paddle.enable_static()
def test_api_1(self):
def test_api_0(self):
with _test_eager_guard():
self.func_api_0()
self.func_api_0()
def func_api_1(self):
paddle.disable_static(self.place)
var_x = paddle.to_tensor(self.input_data)
out = paddle.sort(var_x, axis=-1)
......@@ -85,3 +91,8 @@ class TestSortDygraph(unittest.TestCase):
(np.sort(
self.input_data, axis=-1) == out.numpy()).all(), True)
paddle.enable_static()
def test_api_1(self):
with _test_eager_guard():
self.func_api_1()
self.func_api_1()
......@@ -92,7 +92,7 @@ def argsort(x, axis=-1, descending=False, name=None):
# [0 2 1 1]]]
"""
if in_dygraph_mode():
_, ids, = _C_ops.final_state_argsort(x, axis, descending)
_, ids = _C_ops.final_state_argsort(x, axis, descending)
return ids
if _in_legacy_dygraph():
......@@ -472,9 +472,13 @@ def sort(x, axis=-1, descending=False, name=None):
# [4. 7. 4. 6.]
# [5. 7. 7. 9.]]]
"""
if paddle.in_dynamic_mode():
out, _ = _C_ops.argsort(x, 'axis', axis, 'descending', descending)
return out
if in_dygraph_mode():
outs, _ = _C_ops.final_state_argsort(x, axis, descending)
return outs
if _in_legacy_dygraph():
outs, _ = _C_ops.argsort(x, 'axis', axis, 'descending', descending)
return outs
helper = LayerHelper("sort", **locals())
out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册