From 1d9ee667850e9a29ef2e1c4622bd258e36486b4a Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Tue, 19 Apr 2022 10:04:46 +0800 Subject: [PATCH] [Eager] paddle.sort interface use final_state (#41934) * [Eager] paddle.sort interface use final_state * Add eager test case for paddle.sort() --- .../paddle/fluid/tests/unittests/test_sort_op.py | 15 +++++++++++++-- python/paddle/tensor/search.py | 12 ++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_sort_op.py b/python/paddle/fluid/tests/unittests/test_sort_op.py index 366e0c7a3f..d678aa835d 100644 --- a/python/paddle/fluid/tests/unittests/test_sort_op.py +++ b/python/paddle/fluid/tests/unittests/test_sort_op.py @@ -21,6 +21,7 @@ import paddle.fluid.layers as layers import numpy as np import six import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard class TestSortOnCPU(unittest.TestCase): @@ -70,14 +71,19 @@ class TestSortDygraph(unittest.TestCase): else: self.place = core.CPUPlace() - def test_api_0(self): + def func_api_0(self): paddle.disable_static(self.place) var_x = paddle.to_tensor(self.input_data) out = paddle.sort(var_x) self.assertEqual((np.sort(self.input_data) == out.numpy()).all(), True) paddle.enable_static() - def test_api_1(self): + def test_api_0(self): + with _test_eager_guard(): + self.func_api_0() + self.func_api_0() + + def func_api_1(self): paddle.disable_static(self.place) var_x = paddle.to_tensor(self.input_data) out = paddle.sort(var_x, axis=-1) @@ -85,3 +91,8 @@ class TestSortDygraph(unittest.TestCase): (np.sort( self.input_data, axis=-1) == out.numpy()).all(), True) paddle.enable_static() + + def test_api_1(self): + with _test_eager_guard(): + self.func_api_1() + self.func_api_1() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 04704981c8..d86a6a3f62 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -92,7 +92,7 @@ def argsort(x, axis=-1, descending=False, name=None): # [0 2 1 1]]] """ if in_dygraph_mode(): - _, ids, = _C_ops.final_state_argsort(x, axis, descending) + _, ids = _C_ops.final_state_argsort(x, axis, descending) return ids if _in_legacy_dygraph(): @@ -482,9 +482,13 @@ def sort(x, axis=-1, descending=False, name=None): # [4. 7. 4. 6.] # [5. 7. 7. 9.]]] """ - if paddle.in_dynamic_mode(): - out, _ = _C_ops.argsort(x, 'axis', axis, 'descending', descending) - return out + if in_dygraph_mode(): + outs, _ = _C_ops.final_state_argsort(x, axis, descending) + return outs + + if _in_legacy_dygraph(): + outs, _ = _C_ops.argsort(x, 'axis', axis, 'descending', descending) + return outs helper = LayerHelper("sort", **locals()) out = helper.create_variable_for_type_inference( dtype=x.dtype, stop_gradient=False) -- GitLab