提交 ca4c93de 编写于 作者: M Megvii Engine Team

fix(mge): fix layer norm amp bug

GitOrigin-RevId: dba691fcbfc991508bd2fa409b294c6e39448dd1
上级 c674bf0e
......@@ -1090,6 +1090,9 @@ def layer_norm(
eps_mode
)
if amp._enabled:
inp, weight, bias = cast_tensors(inp, weight, bias, promote=True)
_device = inp.device
_dtype = inp.dtype
_dim = len(inp.shape) - len(normalized_shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册