From 56cb5d6a16e67dc0966a5902c66386e7a3085c93 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 26 Aug 2020 15:19:06 +0800 Subject: [PATCH] fix(mge/functional): int operations(div, exp, pow) GitOrigin-RevId: dc434fa7ec04fa23baf3b8e8927860c50c14188a --- .../python/megengine/core/tensor/tensor_wrapper.py | 9 +++++++++ imperative/python/megengine/functional/elemwise.py | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 011e7d11..0bf0d7ee 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -23,6 +23,15 @@ from .tensor import Tensor def _elwise(*args, mode): op = builtin.Elemwise(mode=mode) + if mode in ("TRUE_DIV", "POW"): + args = tuple( + map( + lambda x: x.astype("float32") + if hasattr(x, "dtype") and x.dtype != np.float32 + else x, + args, + ) + ) args = utils.convert_inputs(*args) (result,) = apply(op, *args) return result diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 3b8ac1f0..f3a43733 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -76,6 +76,10 @@ __all__ = [ def _elwise(*args, mode): op = builtin.Elemwise(mode=mode) + if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): + args = tuple( + map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args) + ) args = utils.convert_inputs(*args) (result,) = apply(op, *args) return result -- GitLab