diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h index 41fcf3750878e61616caff84c4f44d18d1d36815..239083f88d9c63c4790fe4dd3060a5cf473ff73c 100644 --- a/paddle/fluid/operators/fill_constant_op.h +++ b/paddle/fluid/operators/fill_constant_op.h @@ -14,9 +14,11 @@ limitations under the License. */ #pragma once +#include #include #include #include + #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" @@ -45,15 +47,22 @@ class FillConstantKernel : public framework::OpKernel { if (str_value.empty()) { value = static_cast(float_value); } else { - std::stringstream convert_stream(str_value); - if (std::is_same::value) { - int64_t tmp_value; - convert_stream >> tmp_value; - value = static_cast(tmp_value); + // handle NaN/Inf first, which cannot be read from stream. + if (str_value == "inf") { + value = static_cast(std::numeric_limits::infinity()); + } else if (str_value == "nan") { + value = static_cast(std::numeric_limits::quiet_NaN()); } else { - double tmp_value; - convert_stream >> tmp_value; - value = static_cast(tmp_value); + std::stringstream convert_stream(str_value); + if (std::is_same::value) { + int64_t tmp_value; + convert_stream >> tmp_value; + value = static_cast(tmp_value); + } else { + double tmp_value; + convert_stream >> tmp_value; + value = static_cast(tmp_value); + } } } if (ctx.HasInput("ValueTensor")) { diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index 43069470680c7d49071ce54bf3649962c56f06ea..babfcdb9040df78f13204bbcad28ca01cff48040 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -330,6 +330,16 @@ class TestFillConstantImperative(unittest.TestCase): res4.numpy(), np.full( [1, 2], 88, dtype="int32")) + def test_nan(self): + with fluid.dygraph.guard(): + res = fluid.layers.fill_constant([1], 'float32', np.nan) + self.assertTrue(np.isnan(res.numpy().item(0))) + + def test_inf(self): + with fluid.dygraph.guard(): + res = fluid.layers.fill_constant([1], 'float32', np.inf) + self.assertTrue(np.isinf(res.numpy().item(0))) + class TestFillConstantOpError(unittest.TestCase): def test_errors(self):