From ca4c93dee702ab109be31debfe0842a7bc460a01 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 28 Sep 2021 19:24:07 +0800 Subject: [PATCH] fix(mge): fix layer norm amp bug GitOrigin-RevId: dba691fcbfc991508bd2fa409b294c6e39448dd1 --- imperative/python/megengine/functional/nn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 46cf27711..ebcdbbcbf 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1090,6 +1090,9 @@ def layer_norm( eps_mode ) + if amp._enabled: + inp, weight, bias = cast_tensors(inp, weight, bias, promote=True) + _device = inp.device _dtype = inp.dtype _dim = len(inp.shape) - len(normalized_shape) -- GitLab