diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 011e7d1183db67dd129a82051e42a160082295ae..0bf0d7ee68d9e67db64dd352baeb420c93e3ff87 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -23,6 +23,15 @@ from .tensor import Tensor def _elwise(*args, mode): op = builtin.Elemwise(mode=mode) + if mode in ("TRUE_DIV", "POW"): + args = tuple( + map( + lambda x: x.astype("float32") + if hasattr(x, "dtype") and x.dtype != np.float32 + else x, + args, + ) + ) args = utils.convert_inputs(*args) (result,) = apply(op, *args) return result diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 3b8ac1f08f1230151d07be35bc1b14e270964386..f3a43733785c82a07e22a5ce976afbb46a341d60 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -76,6 +76,10 @@ __all__ = [ def _elwise(*args, mode): op = builtin.Elemwise(mode=mode) + if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): + args = tuple( + map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args) + ) args = utils.convert_inputs(*args) (result,) = apply(op, *args) return result