未验证 提交 7fcb32dd 编写于 作者: L Leo Chen 提交者: GitHub

fill_constant op supports NINF (#28270)

上级 495a9ceb
...@@ -50,6 +50,8 @@ class FillConstantKernel : public framework::OpKernel<T> { ...@@ -50,6 +50,8 @@ class FillConstantKernel : public framework::OpKernel<T> {
// handle NaN/Inf first, which cannot be read from stream. // handle NaN/Inf first, which cannot be read from stream.
if (str_value == "inf") { if (str_value == "inf") {
value = static_cast<T>(std::numeric_limits<double>::infinity()); value = static_cast<T>(std::numeric_limits<double>::infinity());
} else if (str_value == "-inf") {
value = static_cast<T>(-std::numeric_limits<double>::infinity());
} else if (str_value == "nan") { } else if (str_value == "nan") {
value = static_cast<T>(std::numeric_limits<double>::quiet_NaN()); value = static_cast<T>(std::numeric_limits<double>::quiet_NaN());
} else { } else {
......
...@@ -340,6 +340,12 @@ class TestFillConstantImperative(unittest.TestCase): ...@@ -340,6 +340,12 @@ class TestFillConstantImperative(unittest.TestCase):
res = fluid.layers.fill_constant([1], 'float32', np.inf) res = fluid.layers.fill_constant([1], 'float32', np.inf)
self.assertTrue(np.isinf(res.numpy().item(0))) self.assertTrue(np.isinf(res.numpy().item(0)))
def test_ninf(self):
with fluid.dygraph.guard():
res = fluid.layers.fill_constant([1], 'float32', np.NINF)
self.assertTrue(np.isinf(res.numpy().item(0)))
self.assertEqual(np.NINF, res.numpy().item(0))
class TestFillConstantOpError(unittest.TestCase): class TestFillConstantOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册