提交 000663c3 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/functional): add isnan and isinf oprs

GitOrigin-RevId: b4a347751c2022a2b0b2eca5a7e62c625b5e8c27
上级 ccc95ad2
......@@ -21,6 +21,8 @@ from .elemwise import (
floor,
greater,
greater_equal,
isinf,
isnan,
less,
less_equal,
log,
......
......@@ -27,6 +27,8 @@ __all__ = [
"greater",
"greater_equal",
"floor",
"isinf",
"isnan",
"less",
"less_equal",
"log",
......@@ -244,3 +246,49 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
return maximum(inp, lower)
else:
return minimum(inp, upper)
def isnan(inp: Tensor) -> Tensor:
r"""Returns a new tensor representing if each element is NaN or not.
:param: inp
:return: a new tensor representing if each element in :attr:`inp` is NaN or not.
Examples:
.. testcode::
from megengine import tensor
import megengine.functional as F
x = tensor([1, float("nan"), 0])
print(F.isnan(x))
.. testoutput::
Tensor([0 1 0], dtype=uint8)
"""
return (inp != inp).astype("uint8")
def isinf(inp: Tensor) -> Tensor:
r"""Returns a new tensor representing if each element is Inf or not.
:param: inp
:return: a new tensor representing if each element in :attr:`inp` is Inf or not.
Examples:
.. testcode::
from megengine import tensor
import megengine.functional as F
x = tensor([1, float("inf"), 0])
print(F.isinf(x))
.. testoutput::
Tensor([0 1 0], dtype=uint8)
"""
return (abs(inp) == float("inf")).astype("uint8")
......@@ -53,3 +53,13 @@ def test_clamp():
x = np.linspace(-6, 6, dtype="float32")
assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6))
assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0))
def test_isnan():
for case in [[1, float("nan"), 0]]:
assertTensorClose(F.isnan(tensor(case)), np.isnan(case).astype("uint8"))
def test_isinf():
for case in [[1, float("inf"), 0]]:
assertTensorClose(F.isinf(tensor(case)), np.isinf(case).astype("uint8"))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册