未验证 提交 85df6d73 编写于 作者: Y yeliang2258 提交者: GitHub

Add inf and nan support in equal OP (#44667)

* add inf and nan support in equal

* add header

* fix nan and update test

* update test

* update test

* update test

* update code

* update compare test

* update func

* update

* update

* fix

* update
上级 ffd8adca
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <math.h>
namespace phi {
namespace funcs {
......@@ -35,6 +36,10 @@ template <typename InT, typename OutT = bool>
struct EqualFunctor {
HOSTDEVICE OutT operator()(const InT a, const InT b) const {
if (std::is_floating_point<InT>::value) {
if (isinf(static_cast<float>(a)) || isinf(static_cast<float>(b)))
return static_cast<OutT>(a == b);
if (isnan(static_cast<float>(a)) || isnan(static_cast<float>(b)))
return static_cast<OutT>(false);
return static_cast<OutT>(fabs(static_cast<double>(a - b)) < 1e-8);
} else {
return static_cast<OutT>(a == b);
......
......@@ -150,6 +150,102 @@ def create_paddle_case(op_type, callback):
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()
def test_dynamic_api_inf_1(self):
if self.op_type == "equal":
paddle.disable_static()
x1 = np.array([1, float('inf'), float('inf')]).astype(np.int64)
x = paddle.to_tensor(x1)
y1 = np.array([1, float('-inf'), float('inf')]).astype(np.int64)
y = paddle.to_tensor(y1)
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
self.real_result = (x1 == y1).astype(np.int64)
self.assertEqual(
(out.numpy().astype(np.int64) == self.real_result).all(),
True)
paddle.enable_static()
def test_dynamic_api_inf_2(self):
if self.op_type == "equal":
paddle.disable_static()
x1 = np.array([1, float('inf'),
float('inf')]).astype(np.float32)
x = paddle.to_tensor(x1)
y1 = np.array([1, float('-inf'),
float('inf')]).astype(np.float32)
y = paddle.to_tensor(y1)
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
self.real_result = (x1 == y1).astype(np.int64)
self.assertEqual(
(out.numpy().astype(np.int64) == self.real_result).all(),
True)
paddle.enable_static()
def test_dynamic_api_inf_3(self):
if self.op_type == "equal":
paddle.disable_static()
x1 = np.array([1, float('inf'),
float('-inf')]).astype(np.float32)
x = paddle.to_tensor(x1)
y1 = np.array([1, 2, 3]).astype(np.float32)
y = paddle.to_tensor(y1)
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
self.real_result = (x1 == y1).astype(np.int64)
self.assertEqual(
(out.numpy().astype(np.int64) == self.real_result).all(),
True)
paddle.enable_static()
def test_dynamic_api_nan_1(self):
if self.op_type == "equal":
paddle.disable_static()
x1 = np.array([1, float('nan'), float('nan')]).astype(np.int64)
x = paddle.to_tensor(x1)
y1 = np.array([1, float('-nan'), float('nan')]).astype(np.int64)
y = paddle.to_tensor(y1)
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
self.real_result = (x1 == y1).astype(np.int64)
self.assertEqual(
(out.numpy().astype(np.int64) == self.real_result).all(),
True)
paddle.enable_static()
def test_dynamic_api_nan_2(self):
if self.op_type == "equal":
paddle.disable_static()
x1 = np.array([1, float('nan'),
float('nan')]).astype(np.float32)
x = paddle.to_tensor(x1)
y1 = np.array([1, float('-nan'),
float('nan')]).astype(np.float32)
y = paddle.to_tensor(y1)
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
self.real_result = (x1 == y1).astype(np.int64)
self.assertEqual(
(out.numpy().astype(np.int64) == self.real_result).all(),
True)
paddle.enable_static()
def test_dynamic_api_nan_3(self):
if self.op_type == "equal":
paddle.disable_static()
x1 = np.array([1, float('-nan'),
float('nan')]).astype(np.float32)
x = paddle.to_tensor(x1)
y1 = np.array([1, 2, 1]).astype(np.float32)
y = paddle.to_tensor(y1)
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
self.real_result = (x1 == y1).astype(np.int64)
self.assertEqual(
(out.numpy().astype(np.int64) == self.real_result).all(),
True)
paddle.enable_static()
def test_not_equal(self):
if self.op_type == "not_equal":
paddle.disable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册