提交 034c7787 编写于 作者: M Megvii Engine Team

fix(mge/functional): support non-float32 input when call isxxx

GitOrigin-RevId: ea8f394958f2789b142767b065ba541053dbab66
上级 e9cc5237
...@@ -3,6 +3,8 @@ import collections ...@@ -3,6 +3,8 @@ import collections
import math import math
from typing import Iterable, Optional, Sequence, Tuple, Union from typing import Iterable, Optional, Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt.core2 import Const, apply from ..core._imperative_rt.core2 import Const, apply
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin from ..core.ops import builtin
...@@ -11,7 +13,7 @@ from ..core.tensor.utils import _normalize_axis ...@@ -11,7 +13,7 @@ from ..core.tensor.utils import _normalize_axis
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default from ..utils.deprecation import deprecated_kwargs_default
from .elemwise import _elemwise_multi_type, clip from .elemwise import _elemwise_multi_type, clip
from .tensor import expand_dims, squeeze from .tensor import broadcast_to, expand_dims, squeeze
__all__ = [ __all__ = [
"argmax", "argmax",
...@@ -56,10 +58,21 @@ def isnan(inp: Tensor) -> Tensor: ...@@ -56,10 +58,21 @@ def isnan(inp: Tensor) -> Tensor:
Examples: Examples:
>>> F.isnan(Tensor(1))
Tensor(False, dtype=bool, device=xpux:0)
.. TODO: Remove these comments when _elemwise_multi_type support scalar input
.. >>> F.isnan(Tensor(float("nan")))
.. Tensor(True, dtype=bool, device=xpux:0)
Element-wise isnan:
>>> x = Tensor([1, float("nan"), 0]) >>> x = Tensor([1, float("nan"), 0])
>>> F.isnan(x) >>> F.isnan(x)
Tensor([False True False], dtype=bool, device=xpux:0) Tensor([False True False], dtype=bool, device=xpux:0)
""" """
if not np.issubdtype(inp.dtype, np.floating):
return broadcast_to(Tensor(False), inp.shape)
return _elemwise_multi_type(inp, mode="isnan", dtype="bool") return _elemwise_multi_type(inp, mode="isnan", dtype="bool")
...@@ -79,10 +92,21 @@ def isinf(inp: Tensor) -> Tensor: ...@@ -79,10 +92,21 @@ def isinf(inp: Tensor) -> Tensor:
Examples: Examples:
>>> F.isinf(Tensor(1))
Tensor(False, dtype=bool, device=xpux:0)
.. TODO: Remove these comments when _elemwise_multi_type support scalar input
.. >>> F.isinf(Tensor(float("inf")))
.. Tensor(True, dtype=bool, device=xpux:0)
Element-wise isinf:
>>> x = Tensor([1, float("inf"), 0]) >>> x = Tensor([1, float("inf"), 0])
>>> F.isinf(x) >>> F.isinf(x)
Tensor([False True False], dtype=bool, device=xpux:0) Tensor([False True False], dtype=bool, device=xpux:0)
""" """
if not np.issubdtype(inp.dtype, np.floating):
return broadcast_to(Tensor(False), inp.shape)
return _elemwise_multi_type(inp, mode="isinf", dtype="bool") return _elemwise_multi_type(inp, mode="isinf", dtype="bool")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册