diff --git a/python_module/megengine/functional/__init__.py b/python_module/megengine/functional/__init__.py index 651037eeba2d28b1842e5ecbae3fbd5139516b93..6b262bfd4612d77132cec74e48f1c73c611cd609 100644 --- a/python_module/megengine/functional/__init__.py +++ b/python_module/megengine/functional/__init__.py @@ -21,6 +21,8 @@ from .elemwise import ( floor, greater, greater_equal, + isinf, + isnan, less, less_equal, log, diff --git a/python_module/megengine/functional/elemwise.py b/python_module/megengine/functional/elemwise.py index 6bed2d3d833ffa54ecb13336e610c0fed088c30e..2bb59255adcecada7f8ece1191f9cd919a6a82a2 100644 --- a/python_module/megengine/functional/elemwise.py +++ b/python_module/megengine/functional/elemwise.py @@ -27,6 +27,8 @@ __all__ = [ "greater", "greater_equal", "floor", + "isinf", + "isnan", "less", "less_equal", "log", @@ -244,3 +246,49 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: return maximum(inp, lower) else: return minimum(inp, upper) + + +def isnan(inp: Tensor) -> Tensor: + r"""Returns a new tensor representing if each element is NaN or not. + + :param: inp + :return: a new tensor representing if each element in :attr:`inp` is NaN or not. + + Examples: + + .. testcode:: + from megengine import tensor + import megengine.functional as F + + x = tensor([1, float("nan"), 0]) + + print(F.isnan(x)) + + .. testoutput:: + Tensor([0 1 0], dtype=uint8) + + """ + return (inp != inp).astype("uint8") + + +def isinf(inp: Tensor) -> Tensor: + r"""Returns a new tensor representing if each element is Inf or not. + + :param: inp + :return: a new tensor representing if each element in :attr:`inp` is Inf or not. + + Examples: + + .. testcode:: + from megengine import tensor + import megengine.functional as F + + x = tensor([1, float("inf"), 0]) + + print(F.isinf(x)) + + .. testoutput:: + Tensor([0 1 0], dtype=uint8) + + """ + return (abs(inp) == float("inf")).astype("uint8") diff --git a/python_module/test/unit/functional/test_elemwise.py b/python_module/test/unit/functional/test_elemwise.py index ef9cf6fad34d1bbcc52856eafaf5a3f00c17f7a6..c02bd58b608039a28b9b84ff8cf531dc43657974 100644 --- a/python_module/test/unit/functional/test_elemwise.py +++ b/python_module/test/unit/functional/test_elemwise.py @@ -53,3 +53,13 @@ def test_clamp(): x = np.linspace(-6, 6, dtype="float32") assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6)) assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0)) + + +def test_isnan(): + for case in [[1, float("nan"), 0]]: + assertTensorClose(F.isnan(tensor(case)), np.isnan(case).astype("uint8")) + + +def test_isinf(): + for case in [[1, float("inf"), 0]]: + assertTensorClose(F.isinf(tensor(case)), np.isinf(case).astype("uint8"))