From a911c19eb03755431e6416c4da3423ebd9c7e716 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 20 Oct 2020 04:48:56 -0500 Subject: [PATCH] fill_constant op supports NaN and Inf (#28109) * fill_constant supports nan and inf * add ut --- paddle/fluid/operators/fill_constant_op.h | 25 +++++++++++++------ .../tests/unittests/test_fill_constant_op.py | 10 ++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h index 41fcf375087..239083f88d9 100644 --- a/paddle/fluid/operators/fill_constant_op.h +++ b/paddle/fluid/operators/fill_constant_op.h @@ -14,9 +14,11 @@ limitations under the License. */ #pragma once +#include #include #include #include + #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" @@ -45,15 +47,22 @@ class FillConstantKernel : public framework::OpKernel { if (str_value.empty()) { value = static_cast(float_value); } else { - std::stringstream convert_stream(str_value); - if (std::is_same::value) { - int64_t tmp_value; - convert_stream >> tmp_value; - value = static_cast(tmp_value); + // handle NaN/Inf first, which cannot be read from stream. + if (str_value == "inf") { + value = static_cast(std::numeric_limits::infinity()); + } else if (str_value == "nan") { + value = static_cast(std::numeric_limits::quiet_NaN()); } else { - double tmp_value; - convert_stream >> tmp_value; - value = static_cast(tmp_value); + std::stringstream convert_stream(str_value); + if (std::is_same::value) { + int64_t tmp_value; + convert_stream >> tmp_value; + value = static_cast(tmp_value); + } else { + double tmp_value; + convert_stream >> tmp_value; + value = static_cast(tmp_value); + } } } if (ctx.HasInput("ValueTensor")) { diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index 43069470680..babfcdb9040 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -330,6 +330,16 @@ class TestFillConstantImperative(unittest.TestCase): res4.numpy(), np.full( [1, 2], 88, dtype="int32")) + def test_nan(self): + with fluid.dygraph.guard(): + res = fluid.layers.fill_constant([1], 'float32', np.nan) + self.assertTrue(np.isnan(res.numpy().item(0))) + + def test_inf(self): + with fluid.dygraph.guard(): + res = fluid.layers.fill_constant([1], 'float32', np.inf) + self.assertTrue(np.isinf(res.numpy().item(0))) + class TestFillConstantOpError(unittest.TestCase): def test_errors(self): -- GitLab