提交 56cb5d6a 编写于 作者: M Megvii Engine Team

fix(mge/functional): int operations(div, exp, pow)

GitOrigin-RevId: dc434fa7ec04fa23baf3b8e8927860c50c14188a
上级 3840c1f4
......@@ -23,6 +23,15 @@ from .tensor import Tensor
def _elwise(*args, mode):
op = builtin.Elemwise(mode=mode)
if mode in ("TRUE_DIV", "POW"):
args = tuple(
map(
lambda x: x.astype("float32")
if hasattr(x, "dtype") and x.dtype != np.float32
else x,
args,
)
)
args = utils.convert_inputs(*args)
(result,) = apply(op, *args)
return result
......
......@@ -76,6 +76,10 @@ __all__ = [
def _elwise(*args, mode):
op = builtin.Elemwise(mode=mode)
if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
args = tuple(
map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args)
)
args = utils.convert_inputs(*args)
(result,) = apply(op, *args)
return result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册