diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 6a61989c6e25e80bb6709268825a9b07cd80d29c..35a49e25616efe62f24da1c1a272a7b5c88c8679 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -3,6 +3,8 @@ import collections import math from typing import Iterable, Optional, Sequence, Tuple, Union +import numpy as np + from ..core._imperative_rt.core2 import Const, apply from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core.ops import builtin @@ -11,7 +13,7 @@ from ..core.tensor.utils import _normalize_axis from ..tensor import Tensor from ..utils.deprecation import deprecated_kwargs_default from .elemwise import _elemwise_multi_type, clip -from .tensor import expand_dims, squeeze +from .tensor import broadcast_to, expand_dims, squeeze __all__ = [ "argmax", @@ -55,11 +57,22 @@ def isnan(inp: Tensor) -> Tensor: The returned array should have a data type of bool. Examples: + + >>> F.isnan(Tensor(1)) + Tensor(False, dtype=bool, device=xpux:0) + + .. TODO: Remove these comments when _elemwise_multi_type support scalar input + .. >>> F.isnan(Tensor(float("nan"))) + .. Tensor(True, dtype=bool, device=xpux:0) + + Element-wise isnan: >>> x = Tensor([1, float("nan"), 0]) >>> F.isnan(x) Tensor([False True False], dtype=bool, device=xpux:0) """ + if not np.issubdtype(inp.dtype, np.floating): + return broadcast_to(Tensor(False), inp.shape) return _elemwise_multi_type(inp, mode="isnan", dtype="bool") @@ -79,10 +92,21 @@ def isinf(inp: Tensor) -> Tensor: Examples: + >>> F.isinf(Tensor(1)) + Tensor(False, dtype=bool, device=xpux:0) + + .. TODO: Remove these comments when _elemwise_multi_type support scalar input + .. >>> F.isinf(Tensor(float("inf"))) + .. Tensor(True, dtype=bool, device=xpux:0) + + Element-wise isinf: + >>> x = Tensor([1, float("inf"), 0]) >>> F.isinf(x) Tensor([False True False], dtype=bool, device=xpux:0) """ + if not np.issubdtype(inp.dtype, np.floating): + return broadcast_to(Tensor(False), inp.shape) return _elemwise_multi_type(inp, mode="isinf", dtype="bool")