未验证 提交 a911c19e 编写于 作者: L Leo Chen 提交者: GitHub

fill_constant op supports NaN and Inf (#28109)

* fill_constant supports nan and inf

* add ut
上级 74c8a811
...@@ -14,9 +14,11 @@ limitations under the License. */ ...@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once #pragma once
#include <limits>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -44,6 +46,12 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -44,6 +46,12 @@ class FillConstantKernel : public framework::OpKernel<T> {
T value; T value;
if (str_value.empty()) { if (str_value.empty()) {
value = static_cast<T>(float_value); value = static_cast<T>(float_value);
} else {
// handle NaN/Inf first, which cannot be read from stream.
if (str_value == "inf") {
value = static_cast<T>(std::numeric_limits<double>::infinity());
} else if (str_value == "nan") {
value = static_cast<T>(std::numeric_limits<double>::quiet_NaN());
} else { } else {
std::stringstream convert_stream(str_value); std::stringstream convert_stream(str_value);
if (std::is_same<int64_t, T>::value) { if (std::is_same<int64_t, T>::value) {
...@@ -56,6 +64,7 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -56,6 +64,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
value = static_cast<T>(tmp_value); value = static_cast<T>(tmp_value);
} }
} }
}
if (ctx.HasInput("ValueTensor")) { if (ctx.HasInput("ValueTensor")) {
auto *value_tensor = ctx.Input<framework::Tensor>("ValueTensor"); auto *value_tensor = ctx.Input<framework::Tensor>("ValueTensor");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -330,6 +330,16 @@ class TestFillConstantImperative(unittest.TestCase): ...@@ -330,6 +330,16 @@ class TestFillConstantImperative(unittest.TestCase):
res4.numpy(), np.full( res4.numpy(), np.full(
[1, 2], 88, dtype="int32")) [1, 2], 88, dtype="int32"))
def test_nan(self):
with fluid.dygraph.guard():
res = fluid.layers.fill_constant([1], 'float32', np.nan)
self.assertTrue(np.isnan(res.numpy().item(0)))
def test_inf(self):
with fluid.dygraph.guard():
res = fluid.layers.fill_constant([1], 'float32', np.inf)
self.assertTrue(np.isinf(res.numpy().item(0)))
class TestFillConstantOpError(unittest.TestCase): class TestFillConstantOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册