未验证 提交 5e8d7804 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] paddle.sort interface use final_state (#41934) (#41955)

* [Eager] paddle.sort interface use final_state

* Add eager test case for paddle.sort()
上级 21c333df
...@@ -21,6 +21,7 @@ import paddle.fluid.layers as layers ...@@ -21,6 +21,7 @@ import paddle.fluid.layers as layers
import numpy as np import numpy as np
import six import six
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
class TestSortOnCPU(unittest.TestCase): class TestSortOnCPU(unittest.TestCase):
...@@ -70,14 +71,19 @@ class TestSortDygraph(unittest.TestCase): ...@@ -70,14 +71,19 @@ class TestSortDygraph(unittest.TestCase):
else: else:
self.place = core.CPUPlace() self.place = core.CPUPlace()
def test_api_0(self): def func_api_0(self):
paddle.disable_static(self.place) paddle.disable_static(self.place)
var_x = paddle.to_tensor(self.input_data) var_x = paddle.to_tensor(self.input_data)
out = paddle.sort(var_x) out = paddle.sort(var_x)
self.assertEqual((np.sort(self.input_data) == out.numpy()).all(), True) self.assertEqual((np.sort(self.input_data) == out.numpy()).all(), True)
paddle.enable_static() paddle.enable_static()
def test_api_1(self): def test_api_0(self):
with _test_eager_guard():
self.func_api_0()
self.func_api_0()
def func_api_1(self):
paddle.disable_static(self.place) paddle.disable_static(self.place)
var_x = paddle.to_tensor(self.input_data) var_x = paddle.to_tensor(self.input_data)
out = paddle.sort(var_x, axis=-1) out = paddle.sort(var_x, axis=-1)
...@@ -85,3 +91,8 @@ class TestSortDygraph(unittest.TestCase): ...@@ -85,3 +91,8 @@ class TestSortDygraph(unittest.TestCase):
(np.sort( (np.sort(
self.input_data, axis=-1) == out.numpy()).all(), True) self.input_data, axis=-1) == out.numpy()).all(), True)
paddle.enable_static() paddle.enable_static()
def test_api_1(self):
with _test_eager_guard():
self.func_api_1()
self.func_api_1()
...@@ -92,7 +92,7 @@ def argsort(x, axis=-1, descending=False, name=None): ...@@ -92,7 +92,7 @@ def argsort(x, axis=-1, descending=False, name=None):
# [0 2 1 1]]] # [0 2 1 1]]]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
_, ids, = _C_ops.final_state_argsort(x, axis, descending) _, ids = _C_ops.final_state_argsort(x, axis, descending)
return ids return ids
if _in_legacy_dygraph(): if _in_legacy_dygraph():
...@@ -472,9 +472,13 @@ def sort(x, axis=-1, descending=False, name=None): ...@@ -472,9 +472,13 @@ def sort(x, axis=-1, descending=False, name=None):
# [4. 7. 4. 6.] # [4. 7. 4. 6.]
# [5. 7. 7. 9.]]] # [5. 7. 7. 9.]]]
""" """
if paddle.in_dynamic_mode(): if in_dygraph_mode():
out, _ = _C_ops.argsort(x, 'axis', axis, 'descending', descending) outs, _ = _C_ops.final_state_argsort(x, axis, descending)
return out return outs
if _in_legacy_dygraph():
outs, _ = _C_ops.argsort(x, 'axis', axis, 'descending', descending)
return outs
helper = LayerHelper("sort", **locals()) helper = LayerHelper("sort", **locals())
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=False) dtype=x.dtype, stop_gradient=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册