From 938152af635595da70b9646a5699056fc6d37149 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 13 Jan 2021 14:32:20 +0800 Subject: [PATCH] fix(mge/functional): convert input type to float32 for more elemwise op GitOrigin-RevId: cf3bf8cb805a3229700dd2939393a3994bc59f35 --- .../megengine/core/tensor/array_method.py | 37 +++++++++++------ .../python/megengine/functional/elemwise.py | 40 ++++++++++++++----- .../test/unit/functional/test_elemwise.py | 18 +++++++++ 3 files changed, 72 insertions(+), 23 deletions(-) diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index ed032ae15..689e76082 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -27,9 +27,31 @@ from .utils import setscalar _ElwMod = Elemwise.Mode -def _elwise(*args, mode): +def _elwise_apply(args, mode): op = builtin.Elemwise(mode) - if mode in (_ElwMod.TRUE_DIV, _ElwMod.POW): + _isscalar = True + for i in args: + if isscalar(i) == False: + _isscalar = False + break + (result,) = apply(op, *args) + if _isscalar: + setscalar(result) + return result + + +def _elwise(*args, mode): + if mode in ( + _ElwMod.TRUE_DIV, + _ElwMod.POW, + _ElwMod.CEIL, + _ElwMod.FLOOR, + _ElwMod.ROUND, + ): + if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype( + args[0].dtype, np.integer + ): + return args[0] args = tuple( map( lambda x: x.astype("float32") @@ -39,16 +61,7 @@ def _elwise(*args, mode): ) ) args = utils.convert_inputs(*args) - (result,) = apply(op, *args) - - _isscalar = True - for i in args: - if isscalar(i) == False: - _isscalar = False - break - if _isscalar: - setscalar(result) - return result + return _elwise_apply(args, mode) def _matmul(inp1, inp2): diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 7f858d236..158e826aa 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -9,10 +9,13 @@ # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order import functools +import numpy as np + from ..core._imperative_rt.core2 import apply from ..core.ops import builtin from ..core.ops.builtin import Elemwise from ..core.tensor import megbrain_graph, utils +from ..core.tensor.array_method import _elwise_apply from ..core.tensor.utils import isscalar, setscalar from ..device import get_default_device from ..jit.tracing import is_tracing @@ -74,7 +77,6 @@ __all__ = [ def _elwise(*args, mode): - op = builtin.Elemwise(mode) tensor_args = list( filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) ) @@ -84,17 +86,33 @@ def _elwise(*args, mode): args = utils.convert_inputs(first_arg, *args[1:]) else: args = utils.convert_inputs(*args) - if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): + if mode in ( + Elemwise.Mode.TRUE_DIV, + Elemwise.Mode.EXP, + Elemwise.Mode.POW, + Elemwise.Mode.LOG, + Elemwise.Mode.EXPM1, + Elemwise.Mode.LOG1P, + Elemwise.Mode.TANH, + Elemwise.Mode.ACOS, + Elemwise.Mode.ASIN, + Elemwise.Mode.ATAN2, + Elemwise.Mode.CEIL, + Elemwise.Mode.COS, + Elemwise.Mode.FLOOR, + Elemwise.Mode.H_SWISH, + Elemwise.Mode.ROUND, + Elemwise.Mode.SIGMOID, + Elemwise.Mode.SIN, + ): + if mode in ( + Elemwise.Mode.CEIL, + Elemwise.Mode.FLOOR, + Elemwise.Mode.ROUND, + ) and np.issubdtype(args[0].dtype, np.integer): + return args[0] args = tuple(map(lambda x: x.astype("float32"), args)) - _isscalar = True - for i in args: - if isscalar(i) == False: - _isscalar = False - break - (result,) = apply(op, *args) - if _isscalar: - setscalar(result) - return result + return _elwise_apply(args, mode) def _elemwise_multi_type(*args, mode, **kwargs): diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 88ed25382..7adfdfc97 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -9,6 +9,7 @@ import numpy as np import megengine.functional as F +import megengine.functional.elemwise as elemwise from megengine import tensor from megengine.core.tensor import dtype from megengine.functional.elemwise import _elwise @@ -166,3 +167,20 @@ def test_qadd(): result_mge = result_mge.astype("float32").numpy() result_expect = x.astype("float32").numpy() + y.astype("float32").numpy() np.testing.assert_almost_equal(result_mge, result_expect, decimal=6) + + +def test_int32_input(): + x = tensor(np.array([1, 2, 3, 4, 5]), dtype="int32") + for op_name in elemwise.__all__: + op = getattr(elemwise, op_name) + nargs = op.__code__.co_argcount + if op_name == "clip": + inp = (x, 0, 1) + elif op_name.endswith("_shift"): + inp = (x, 1) + elif op_name.startswith("logical_"): + continue + else: + inp = (x,) * nargs + y = op(*inp) + y.numpy() -- GitLab