From ec582895501b1ae4da110ce6b9fcb61ddcacb718 Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 9 Mar 2022 17:57:32 +0800 Subject: [PATCH] fix the full_like with fill the value of inf (#40232) * fix the full_like with fill the value of inf * update the test case for the fill_any_like * updae the comments for the full_like --- paddle/phi/kernels/cpu/full_kernel.cc | 20 ++++++++++++++----- paddle/phi/kernels/gpu/full_kernel.cu | 20 ++++++++++++++----- .../tests/unittests/test_fill_any_like_op.py | 13 ------------ .../tests/unittests/test_full_like_op.py | 9 +++++++++ 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/paddle/phi/kernels/cpu/full_kernel.cc b/paddle/phi/kernels/cpu/full_kernel.cc index 86576a861aa..556de3adcf4 100644 --- a/paddle/phi/kernels/cpu/full_kernel.cc +++ b/paddle/phi/kernels/cpu/full_kernel.cc @@ -54,12 +54,22 @@ void FullLikeKernel(const Context& dev_ctx, auto common_type_value = static_cast(value); - PADDLE_ENFORCE_EQ( - (common_type_value >= + // Check whether the filled value is valid + bool is_out_range = true; + if (std::isinf(value) || std::isnan(value)) { + is_out_range = false; + } + + if ((common_type_value >= static_cast(std::numeric_limits::lowest())) && - (common_type_value <= - static_cast(std::numeric_limits::max())), - true, + (common_type_value <= + static_cast(std::numeric_limits::max()))) { + is_out_range = false; + } + + PADDLE_ENFORCE_EQ( + is_out_range, + false, phi::errors::InvalidArgument( "The filled value is out of range for target type, " "current kernel type is %s, the range should between %f " diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu index a905979f08b..852d209ee01 100644 --- a/paddle/phi/kernels/gpu/full_kernel.cu +++ b/paddle/phi/kernels/gpu/full_kernel.cu @@ -71,12 +71,22 @@ void FullLikeKernel(const Context& dev_ctx, auto common_type_value = static_cast(value); - PADDLE_ENFORCE_EQ( - (common_type_value >= + // Check whether the filled value is valid + bool is_out_range = true; + if (std::isinf(value) || std::isnan(value)) { + is_out_range = false; + } + + if ((common_type_value >= static_cast(std::numeric_limits::lowest())) && - (common_type_value <= - static_cast(std::numeric_limits::max())), - true, + (common_type_value <= + static_cast(std::numeric_limits::max()))) { + is_out_range = false; + } + + PADDLE_ENFORCE_EQ( + is_out_range, + false, phi::errors::InvalidArgument( "The filled value is out of range for target type, " "current kernel type is %s, the range should between %f " diff --git a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py index 9be2e57ff0c..95537d43327 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py @@ -98,19 +98,6 @@ class TestFillAnyLikeOpType(TestFillAnyLikeOp): } -class TestFillAnyLikeOpOverflow(TestFillAnyLikeOp): - def init(self): - self.value = 1e100 - - def test_check_output(self): - exception = None - try: - self.check_output(check_dygraph=False) - except ValueError as ex: - exception = ex - self.assertIsNotNone(exception) - - class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp): def init(self): self.dtype = np.float16 diff --git a/python/paddle/fluid/tests/unittests/test_full_like_op.py b/python/paddle/fluid/tests/unittests/test_full_like_op.py index be6abb17c3c..3ae2e9ff6bd 100644 --- a/python/paddle/fluid/tests/unittests/test_full_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_full_like_op.py @@ -62,6 +62,15 @@ class TestFullOp(unittest.TestCase): self.assertTrue((out.numpy() == out_numpy).all(), True) paddle.enable_static() + def test_full_like_fill_inf(self): + paddle.disable_static() + input = paddle.arange(6, 10, dtype='float32') + out = paddle.full_like(input, fill_value=float('inf')) + out_numpy = np.random.random((4)).astype("float32") + out_numpy.fill(float('inf')) + self.assertTrue((out.numpy() == out_numpy).all(), True) + paddle.enable_static() + class TestFullOpError(unittest.TestCase): def test_errors(self): -- GitLab