From 034c7787fa3e237a72c5ea5081ef2397e85b0e96 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 27 Sep 2022 20:18:55 +0800 Subject: [PATCH] fix(mge/functional): support non-float32 input when call isxxx GitOrigin-RevId: ea8f394958f2789b142767b065ba541053dbab66 --- .../python/megengine/functional/math.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 6a61989c6..35a49e256 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -3,6 +3,8 @@ import collections import math from typing import Iterable, Optional, Sequence, Tuple, Union +import numpy as np + from ..core._imperative_rt.core2 import Const, apply from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core.ops import builtin @@ -11,7 +13,7 @@ from ..core.tensor.utils import _normalize_axis from ..tensor import Tensor from ..utils.deprecation import deprecated_kwargs_default from .elemwise import _elemwise_multi_type, clip -from .tensor import expand_dims, squeeze +from .tensor import broadcast_to, expand_dims, squeeze __all__ = [ "argmax", @@ -55,11 +57,22 @@ def isnan(inp: Tensor) -> Tensor: The returned array should have a data type of bool. Examples: + + >>> F.isnan(Tensor(1)) + Tensor(False, dtype=bool, device=xpux:0) + + .. TODO: Remove these comments when _elemwise_multi_type support scalar input + .. >>> F.isnan(Tensor(float("nan"))) + .. Tensor(True, dtype=bool, device=xpux:0) + + Element-wise isnan: >>> x = Tensor([1, float("nan"), 0]) >>> F.isnan(x) Tensor([False True False], dtype=bool, device=xpux:0) """ + if not np.issubdtype(inp.dtype, np.floating): + return broadcast_to(Tensor(False), inp.shape) return _elemwise_multi_type(inp, mode="isnan", dtype="bool") @@ -79,10 +92,21 @@ def isinf(inp: Tensor) -> Tensor: Examples: + >>> F.isinf(Tensor(1)) + Tensor(False, dtype=bool, device=xpux:0) + + .. TODO: Remove these comments when _elemwise_multi_type support scalar input + .. >>> F.isinf(Tensor(float("inf"))) + .. Tensor(True, dtype=bool, device=xpux:0) + + Element-wise isinf: + >>> x = Tensor([1, float("inf"), 0]) >>> F.isinf(x) Tensor([False True False], dtype=bool, device=xpux:0) """ + if not np.issubdtype(inp.dtype, np.floating): + return broadcast_to(Tensor(False), inp.shape) return _elemwise_multi_type(inp, mode="isinf", dtype="bool") -- GitLab