diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 46cf27711d37261d0ffaed30e04bac8385b2ff13..ebcdbbcbf865e5f4bc353ac0712dc053ec6bba9f 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1090,6 +1090,9 @@ def layer_norm( eps_mode ) + if amp._enabled: + inp, weight, bias = cast_tensors(inp, weight, bias, promote=True) + _device = inp.device _dtype = inp.dtype _dim = len(inp.shape) - len(normalized_shape)