未验证 提交 17af293f 编写于 作者: W wawltor 提交者: GitHub

Fix argsort cpu kernel when with input of NaN (#41070)

* fix the argosrt cpu

* add the test case for the paddle.argsort
上级 157c1a28
......@@ -51,9 +51,13 @@ static void FullSort(Type input_height,
col_vec.end(),
[&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (descending)
return l.first > r.first;
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
else
return l.first < r.first;
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
for (Type j = 0; j < input_width; ++j) {
......
......@@ -442,5 +442,28 @@ class TestArgsortImperative4(TestArgsortImperative):
self.axis = 1
class TestArgsortWithInputNaN(unittest.TestCase):
def init(self):
self.axis = 0
def setUp(self):
self.init()
self.input_data = np.array([1.0, np.nan, 3.0, 2.0])
if core.is_compiled_with_cuda():
self.place = core.CUDAPlace(0)
else:
self.place = core.CPUPlace()
def test_api(self):
paddle.disable_static(self.place)
var_x = paddle.to_tensor(self.input_data)
out = paddle.argsort(var_x, axis=self.axis)
self.assertEqual((out.numpy() == np.array([0, 3, 2, 1])).all(), True)
out = paddle.argsort(var_x, axis=self.axis, descending=True)
self.assertEqual((out.numpy() == np.array([1, 2, 3, 0])).all(), True)
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册