未验证 提交 ec582895 编写于 作者: W wawltor 提交者: GitHub

fix the full_like with fill the value of inf (#40232)

* fix the full_like with fill the value of inf

* update the test case for the fill_any_like

* updae the comments for the full_like
上级 95c343d3
......@@ -54,12 +54,22 @@ void FullLikeKernel(const Context& dev_ctx,
auto common_type_value = static_cast<CommonType>(value);
PADDLE_ENFORCE_EQ(
(common_type_value >=
// Check whether the filled value is valid
bool is_out_range = true;
if (std::isinf(value) || std::isnan(value)) {
is_out_range = false;
}
if ((common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())),
true,
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max()))) {
is_out_range = false;
}
PADDLE_ENFORCE_EQ(
is_out_range,
false,
phi::errors::InvalidArgument(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
......
......@@ -71,12 +71,22 @@ void FullLikeKernel(const Context& dev_ctx,
auto common_type_value = static_cast<CommonType>(value);
PADDLE_ENFORCE_EQ(
(common_type_value >=
// Check whether the filled value is valid
bool is_out_range = true;
if (std::isinf(value) || std::isnan(value)) {
is_out_range = false;
}
if ((common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())),
true,
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max()))) {
is_out_range = false;
}
PADDLE_ENFORCE_EQ(
is_out_range,
false,
phi::errors::InvalidArgument(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
......
......@@ -98,19 +98,6 @@ class TestFillAnyLikeOpType(TestFillAnyLikeOp):
}
class TestFillAnyLikeOpOverflow(TestFillAnyLikeOp):
def init(self):
self.value = 1e100
def test_check_output(self):
exception = None
try:
self.check_output(check_dygraph=False)
except ValueError as ex:
exception = ex
self.assertIsNotNone(exception)
class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
def init(self):
self.dtype = np.float16
......
......@@ -62,6 +62,15 @@ class TestFullOp(unittest.TestCase):
self.assertTrue((out.numpy() == out_numpy).all(), True)
paddle.enable_static()
def test_full_like_fill_inf(self):
paddle.disable_static()
input = paddle.arange(6, 10, dtype='float32')
out = paddle.full_like(input, fill_value=float('inf'))
out_numpy = np.random.random((4)).astype("float32")
out_numpy.fill(float('inf'))
self.assertTrue((out.numpy() == out_numpy).all(), True)
paddle.enable_static()
class TestFullOpError(unittest.TestCase):
def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册