diff --git a/paddle/phi/kernels/cpu/argsort_kernel.cc b/paddle/phi/kernels/cpu/argsort_kernel.cc index 0e69afe38c9ad4d1ccfbd42fcde06562d97a4e3f..8621a717e1018f8a6e9a73b6e7440a1331cc63bd 100644 --- a/paddle/phi/kernels/cpu/argsort_kernel.cc +++ b/paddle/phi/kernels/cpu/argsort_kernel.cc @@ -51,9 +51,13 @@ static void FullSort(Type input_height, col_vec.end(), [&](const std::pair& l, const std::pair& r) { if (descending) - return l.first > r.first; + return (std::isnan(static_cast(l.first)) && + !std::isnan(static_cast(r.first))) || + (l.first > r.first); else - return l.first < r.first; + return (!std::isnan(static_cast(l.first)) && + std::isnan(static_cast(r.first))) || + (l.first < r.first); }); for (Type j = 0; j < input_width; ++j) { diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py index ee2d65684f9ba5d7234c5203e3970c7bd89ed673..874d66112bdbbd611bf06056b8c2ee9f564c658c 100644 --- a/python/paddle/fluid/tests/unittests/test_argsort_op.py +++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py @@ -442,5 +442,28 @@ class TestArgsortImperative4(TestArgsortImperative): self.axis = 1 +class TestArgsortWithInputNaN(unittest.TestCase): + def init(self): + self.axis = 0 + + def setUp(self): + self.init() + self.input_data = np.array([1.0, np.nan, 3.0, 2.0]) + if core.is_compiled_with_cuda(): + self.place = core.CUDAPlace(0) + else: + self.place = core.CPUPlace() + + def test_api(self): + paddle.disable_static(self.place) + var_x = paddle.to_tensor(self.input_data) + out = paddle.argsort(var_x, axis=self.axis) + self.assertEqual((out.numpy() == np.array([0, 3, 2, 1])).all(), True) + + out = paddle.argsort(var_x, axis=self.axis, descending=True) + self.assertEqual((out.numpy() == np.array([1, 2, 3, 0])).all(), True) + paddle.enable_static() + + if __name__ == "__main__": unittest.main()