diff --git a/paddle/phi/kernels/cpu/full_kernel.cc b/paddle/phi/kernels/cpu/full_kernel.cc index 86576a861aa4834a4b39b50594565a2d4b3ac510..556de3adcf498a111bcaad2f5b849bfe79be8adf 100644 --- a/paddle/phi/kernels/cpu/full_kernel.cc +++ b/paddle/phi/kernels/cpu/full_kernel.cc @@ -54,12 +54,22 @@ void FullLikeKernel(const Context& dev_ctx, auto common_type_value = static_cast(value); - PADDLE_ENFORCE_EQ( - (common_type_value >= + // Check whether the filled value is valid + bool is_out_range = true; + if (std::isinf(value) || std::isnan(value)) { + is_out_range = false; + } + + if ((common_type_value >= static_cast(std::numeric_limits::lowest())) && - (common_type_value <= - static_cast(std::numeric_limits::max())), - true, + (common_type_value <= + static_cast(std::numeric_limits::max()))) { + is_out_range = false; + } + + PADDLE_ENFORCE_EQ( + is_out_range, + false, phi::errors::InvalidArgument( "The filled value is out of range for target type, " "current kernel type is %s, the range should between %f " diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu index a905979f08b5f1b7c619c0f6a4b330f145ed8674..852d209ee018598285b3dff35ffd51e289f07976 100644 --- a/paddle/phi/kernels/gpu/full_kernel.cu +++ b/paddle/phi/kernels/gpu/full_kernel.cu @@ -71,12 +71,22 @@ void FullLikeKernel(const Context& dev_ctx, auto common_type_value = static_cast(value); - PADDLE_ENFORCE_EQ( - (common_type_value >= + // Check whether the filled value is valid + bool is_out_range = true; + if (std::isinf(value) || std::isnan(value)) { + is_out_range = false; + } + + if ((common_type_value >= static_cast(std::numeric_limits::lowest())) && - (common_type_value <= - static_cast(std::numeric_limits::max())), - true, + (common_type_value <= + static_cast(std::numeric_limits::max()))) { + is_out_range = false; + } + + PADDLE_ENFORCE_EQ( + is_out_range, + false, phi::errors::InvalidArgument( "The filled value is out of range for target type, " "current kernel type is %s, the range should between %f " diff --git a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py index 9be2e57ff0cba10358a7873f701d93237c55fd16..95537d4332739be13ee1705dafeea6d3f0ac2762 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py @@ -98,19 +98,6 @@ class TestFillAnyLikeOpType(TestFillAnyLikeOp): } -class TestFillAnyLikeOpOverflow(TestFillAnyLikeOp): - def init(self): - self.value = 1e100 - - def test_check_output(self): - exception = None - try: - self.check_output(check_dygraph=False) - except ValueError as ex: - exception = ex - self.assertIsNotNone(exception) - - class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp): def init(self): self.dtype = np.float16 diff --git a/python/paddle/fluid/tests/unittests/test_full_like_op.py b/python/paddle/fluid/tests/unittests/test_full_like_op.py index be6abb17c3c316d30bdcfa5539ecd1e2549280a5..3ae2e9ff6bdaf7d5a4161eb70914191ace7df417 100644 --- a/python/paddle/fluid/tests/unittests/test_full_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_full_like_op.py @@ -62,6 +62,15 @@ class TestFullOp(unittest.TestCase): self.assertTrue((out.numpy() == out_numpy).all(), True) paddle.enable_static() + def test_full_like_fill_inf(self): + paddle.disable_static() + input = paddle.arange(6, 10, dtype='float32') + out = paddle.full_like(input, fill_value=float('inf')) + out_numpy = np.random.random((4)).astype("float32") + out_numpy.fill(float('inf')) + self.assertTrue((out.numpy() == out_numpy).all(), True) + paddle.enable_static() + class TestFullOpError(unittest.TestCase): def test_errors(self):