diff --git a/python/paddle/fluid/tests/unittests/test_sort_op.py b/python/paddle/fluid/tests/unittests/test_sort_op.py index 366e0c7a3fa3ee714162e6041aa0d52dbfb30746..d678aa835d544d76916526f53661c39d691d878e 100644 --- a/python/paddle/fluid/tests/unittests/test_sort_op.py +++ b/python/paddle/fluid/tests/unittests/test_sort_op.py @@ -21,6 +21,7 @@ import paddle.fluid.layers as layers import numpy as np import six import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard class TestSortOnCPU(unittest.TestCase): @@ -70,14 +71,19 @@ class TestSortDygraph(unittest.TestCase): else: self.place = core.CPUPlace() - def test_api_0(self): + def func_api_0(self): paddle.disable_static(self.place) var_x = paddle.to_tensor(self.input_data) out = paddle.sort(var_x) self.assertEqual((np.sort(self.input_data) == out.numpy()).all(), True) paddle.enable_static() - def test_api_1(self): + def test_api_0(self): + with _test_eager_guard(): + self.func_api_0() + self.func_api_0() + + def func_api_1(self): paddle.disable_static(self.place) var_x = paddle.to_tensor(self.input_data) out = paddle.sort(var_x, axis=-1) @@ -85,3 +91,8 @@ class TestSortDygraph(unittest.TestCase): (np.sort( self.input_data, axis=-1) == out.numpy()).all(), True) paddle.enable_static() + + def test_api_1(self): + with _test_eager_guard(): + self.func_api_1() + self.func_api_1() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 04704981c89a19a2e7588e75e2cca0e1b4957541..d86a6a3f627b34ec3ea0d9a233ba6de9edb50a57 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -92,7 +92,7 @@ def argsort(x, axis=-1, descending=False, name=None): # [0 2 1 1]]] """ if in_dygraph_mode(): - _, ids, = _C_ops.final_state_argsort(x, axis, descending) + _, ids = _C_ops.final_state_argsort(x, axis, descending) return ids if _in_legacy_dygraph(): @@ -482,9 +482,13 @@ def sort(x, axis=-1, descending=False, name=None): # [4. 7. 4. 6.] # [5. 7. 7. 9.]]] """ - if paddle.in_dynamic_mode(): - out, _ = _C_ops.argsort(x, 'axis', axis, 'descending', descending) - return out + if in_dygraph_mode(): + outs, _ = _C_ops.final_state_argsort(x, axis, descending) + return outs + + if _in_legacy_dygraph(): + outs, _ = _C_ops.argsort(x, 'axis', axis, 'descending', descending) + return outs helper = LayerHelper("sort", **locals()) out = helper.create_variable_for_type_inference( dtype=x.dtype, stop_gradient=False)