diff --git a/paddle/phi/kernels/funcs/compare_functors.h b/paddle/phi/kernels/funcs/compare_functors.h index 569fed7b7fbab90e77c02beecbb18b486c801a4f..e16083506bbe054accad3a9b76ce533fb3d74ac9 100644 --- a/paddle/phi/kernels/funcs/compare_functors.h +++ b/paddle/phi/kernels/funcs/compare_functors.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include namespace phi { namespace funcs { @@ -35,6 +36,10 @@ template struct EqualFunctor { HOSTDEVICE OutT operator()(const InT a, const InT b) const { if (std::is_floating_point::value) { + if (isinf(static_cast(a)) || isinf(static_cast(b))) + return static_cast(a == b); + if (isnan(static_cast(a)) || isnan(static_cast(b))) + return static_cast(false); return static_cast(fabs(static_cast(a - b)) < 1e-8); } else { return static_cast(a == b); diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index f1bacfbb6f8f8a203ec0295ffd6a5ea43ffd7522..731eedfca60a08bdedff5192d2703d4b3a0e371d 100755 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -150,6 +150,102 @@ def create_paddle_case(op_type, callback): self.assertEqual((out.numpy() == self.real_result).all(), True) paddle.enable_static() + def test_dynamic_api_inf_1(self): + if self.op_type == "equal": + paddle.disable_static() + x1 = np.array([1, float('inf'), float('inf')]).astype(np.int64) + x = paddle.to_tensor(x1) + y1 = np.array([1, float('-inf'), float('inf')]).astype(np.int64) + y = paddle.to_tensor(y1) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + self.real_result = (x1 == y1).astype(np.int64) + self.assertEqual( + (out.numpy().astype(np.int64) == self.real_result).all(), + True) + paddle.enable_static() + + def test_dynamic_api_inf_2(self): + if self.op_type == "equal": + paddle.disable_static() + x1 = np.array([1, float('inf'), + float('inf')]).astype(np.float32) + x = paddle.to_tensor(x1) + y1 = np.array([1, float('-inf'), + float('inf')]).astype(np.float32) + y = paddle.to_tensor(y1) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + self.real_result = (x1 == y1).astype(np.int64) + self.assertEqual( + (out.numpy().astype(np.int64) == self.real_result).all(), + True) + paddle.enable_static() + + def test_dynamic_api_inf_3(self): + if self.op_type == "equal": + paddle.disable_static() + x1 = np.array([1, float('inf'), + float('-inf')]).astype(np.float32) + x = paddle.to_tensor(x1) + y1 = np.array([1, 2, 3]).astype(np.float32) + y = paddle.to_tensor(y1) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + self.real_result = (x1 == y1).astype(np.int64) + self.assertEqual( + (out.numpy().astype(np.int64) == self.real_result).all(), + True) + paddle.enable_static() + + def test_dynamic_api_nan_1(self): + if self.op_type == "equal": + paddle.disable_static() + x1 = np.array([1, float('nan'), float('nan')]).astype(np.int64) + x = paddle.to_tensor(x1) + y1 = np.array([1, float('-nan'), float('nan')]).astype(np.int64) + y = paddle.to_tensor(y1) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + self.real_result = (x1 == y1).astype(np.int64) + self.assertEqual( + (out.numpy().astype(np.int64) == self.real_result).all(), + True) + paddle.enable_static() + + def test_dynamic_api_nan_2(self): + if self.op_type == "equal": + paddle.disable_static() + x1 = np.array([1, float('nan'), + float('nan')]).astype(np.float32) + x = paddle.to_tensor(x1) + y1 = np.array([1, float('-nan'), + float('nan')]).astype(np.float32) + y = paddle.to_tensor(y1) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + self.real_result = (x1 == y1).astype(np.int64) + self.assertEqual( + (out.numpy().astype(np.int64) == self.real_result).all(), + True) + paddle.enable_static() + + def test_dynamic_api_nan_3(self): + if self.op_type == "equal": + paddle.disable_static() + x1 = np.array([1, float('-nan'), + float('nan')]).astype(np.float32) + x = paddle.to_tensor(x1) + y1 = np.array([1, 2, 1]).astype(np.float32) + y = paddle.to_tensor(y1) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + self.real_result = (x1 == y1).astype(np.int64) + self.assertEqual( + (out.numpy().astype(np.int64) == self.real_result).all(), + True) + paddle.enable_static() + def test_not_equal(self): if self.op_type == "not_equal": paddle.disable_static()