提交 9fd2e663 编写于 作者: M Megvii Engine Team

feat(mge/elwise): removed back to fp32 mode

GitOrigin-RevId: a665a279a684c41a5ff746c95e7a5246125035c1
上级 e2427f30
......@@ -47,24 +47,27 @@ def _elwise_apply(args, mode):
def _elwise(*args, mode):
args = convert_inputs(*args)
if mode in (
_ElwMod.TRUE_DIV,
_ElwMod.EXP,
_ElwMod.POW,
_ElwMod.LOG,
_ElwMod.EXPM1,
_ElwMod.LOG1P,
_ElwMod.TANH,
_ElwMod.ACOS,
_ElwMod.ASIN,
_ElwMod.ATAN2,
_ElwMod.COS,
_ElwMod.H_SWISH,
_ElwMod.SIGMOID,
_ElwMod.SIN,
_ElwMod.LOG_SUM_EXP,
) and (
amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
if (
mode
in (
_ElwMod.EXP,
_ElwMod.POW,
_ElwMod.LOG,
_ElwMod.EXPM1,
_ElwMod.LOG1P,
_ElwMod.ACOS,
_ElwMod.ASIN,
_ElwMod.ATAN2,
_ElwMod.COS,
_ElwMod.SIN,
_ElwMod.LOG_SUM_EXP,
)
and (
amp._enabled
or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
)
or mode in (_ElwMod.TRUE_DIV, _ElwMod.TANH,)
and np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
):
# autocast to FP32 to maintain precision
# or to avoid op's not supporting all int args
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册