diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index db467791f5aba28330dbcb483cd9f6abffe0ed18..2efe5d42c15668336caed596ce319c56e5eeb88c 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -19,6 +19,7 @@ from ..core.ops.builtin import (
     GetVarShape,
     Identity,
     Reduce,
+    Reshape,
     TypeCvt,
 )
 from ..core.ops.special import Const
@@ -1022,6 +1023,92 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     return cached / down
 
 
+@lru_cache(maxsize=None)
+def _get_layerNorm(device, dtype, dim, gopt_level=2):
+    @subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level)
+    def layerNormAffine(inputs, f, c):
+        inp, eps, _flatten_shape, weight, bias = inputs
+        inp_shape = f(GetVarShape(), inp)
+
+        inp = f(Reshape(axis=dim), inp, _flatten_shape)
+        mean = f(Reduce(mode="mean", axis=-1), inp)
+        x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
+        reduce_shape = f(GetVarShape(), x2s)
+        reduce_size = f(
+            "//",
+            f(Reduce(mode="product", axis=0), inp_shape),
+            f(Reduce(mode="product", axis=0), reduce_shape),
+        )
+        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
+        var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
+        inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
+        oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
+        affine_oup = f(Reshape(), oup, inp_shape)
+        affine_oup = f("fma3", affine_oup, weight, bias)
+
+        # NOTE: returning oup makes backward faster but takes more memory
+        return (affine_oup, oup, mean, x2s), (True, False, False, False)
+
+    @subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level)
+    def layerNorm(inputs, f, c):
+        inp, eps, _flatten_shape = inputs
+        inp_shape = f(GetVarShape(), inp)
+
+        inp = f(Reshape(axis=dim), inp, _flatten_shape)
+        mean = f(Reduce(mode="mean", axis=-1), inp)
+        x2s = f(Reduce(mode="sum_sqr", axis=-1), inp)
+        reduce_shape = f(GetVarShape(), x2s)
+        reduce_size = f(
+            "//",
+            f(Reduce(mode="product", axis=0), inp_shape),
+            f(Reduce(mode="product", axis=0), reduce_shape),
+        )
+        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
+        var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2)))
+        inv_sqrt_var = f("**", f("+", var, eps), c(-0.5))
+        oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var))
+        oup = f(Reshape(), oup, inp_shape)
+
+        return (oup,), (True,)
+
+    return (layerNorm, layerNormAffine)
+
+
+def layer_norm(
+    inp: Tensor,
+    normalized_shape: tuple,
+    affine: bool,
+    weight: Optional[Tensor] = None,
+    bias: Optional[Tensor] = None,
+    eps: float = 1e-5,
+    eps_mode="additive",
+):
+    assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
+        eps_mode
+    )
+
+    _device = inp.device
+    _dtype = inp.dtype
+    _dim = len(inp.shape) - len(normalized_shape)
+
+    _flatten_shape = concat(
+        (
+            convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device),
+            convert_single_value(-1, dtype="int32", device=inp.device),
+        )
+    )
+    (layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim)
+
+    eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
+    if affine:
+        outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias)
+    else:
+        outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape)
+
+    return outvar
+
+
 def batch_norm(
     inp: Tensor,
     running_mean: Tensor = None,
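For readers who do not work with the `subgraph` builder, the arithmetic encoded above may be easier to follow in plain NumPy. The sketch below is illustrative only (the name `layer_norm_sketch` is made up here, and the real subgraph fuses the elementwise steps into `fma3` kernels); it walks through the same flatten / normalize / restore-shape sequence and the one-pass variance identity var = E[x²] − mean².

```python
import numpy as np

def layer_norm_sketch(inp, normalized_shape, eps=1e-5, weight=None, bias=None):
    # Flatten the trailing normalized axes into one, as Reshape(axis=dim) does.
    dim = inp.ndim - len(normalized_shape)
    orig_shape = inp.shape
    x = inp.reshape(*orig_shape[:dim], -1)

    # One pass over the data: Reduce(mode="mean") and Reduce(mode="sum_sqr").
    mean = x.mean(axis=-1, keepdims=True)
    x2s = (x ** 2).sum(axis=-1, keepdims=True)

    # The subgraph recovers this as prod(inp_shape) // prod(reduce_shape).
    reduce_size = x.shape[-1]

    # var = E[x^2] - mean^2, then inv_sqrt_var = (var + eps) ** -0.5.
    var = x2s / reduce_size - mean ** 2
    inv_sqrt_var = (var + eps) ** -0.5

    # fma3(a, b, c) == a * b + c, so this is (x - mean) * inv_sqrt_var.
    out = x * inv_sqrt_var + (-mean) * inv_sqrt_var
    out = out.reshape(orig_shape)

    if weight is not None:
        # Second fma3 in the affine variant: out * weight + bias.
        out = out * weight + bias
    return out
```

The `sum_sqr` reduction plus the shape-product division is what lets mean and variance come out of a single pass over the data instead of two.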
diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py
index f2028dbfeba570c285af784204e5ccc1f5a731a2..1868727e0b698325fd5556fef4e0ff8fa314db18 100644
--- a/imperative/python/megengine/module/normalization.py
+++ b/imperative/python/megengine/module/normalization.py
@@ -132,18 +132,9 @@ class LayerNorm(Module):
         zeros_(self.bias)
 
     def forward(self, x):
-        x_shape = x.shape
-        dim_delta = len(x_shape) - len(self.normalized_shape)
-        non_flatten_shape = x_shape[:dim_delta]
-        x = x.reshape(*non_flatten_shape, -1)
-
-        mean = x.mean(axis=-1, keepdims=True)
-        var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean
-
-        x = (x - mean) / F.sqrt(var + self.eps)
-        x = x.reshape(x_shape)
-        if self.affine:
-            x = self.weight * x + self.bias
+        x = F.nn.layer_norm(
+            x, self.normalized_shape, self.affine, self.weight, self.bias, self.eps
+        )
         return x
 
     def _module_info_string(self) -> str:
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index d21363a50868fb7031c280890dfacb9fbd824326..27a8bd07df9bd05aa9be2a39d99663a9ca22b7fe 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -24,6 +24,7 @@ from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad
 from megengine.core.tensor.utils import make_shape_tuple
 from megengine.device import get_device_count
+from megengine.module import LayerNorm
 
 
 def test_where():
@@ -862,6 +863,61 @@ def test_conv1d():
     )
 
 
+def test_layer_norm():
+    def _layer_norm(x, normalized_shape, affine, weight=None, bias=None, eps=1e-5):
+        __layer_norm = LayerNorm(normalized_shape=normalized_shape, affine=affine)
+        __layer_norm.weight = weight
+        __layer_norm.bias = bias
+        return __layer_norm(x)
+
+    def _layer_norm_numpy(
+        x, normalized_shape, affine, weight=None, bias=None, eps=1e-5
+    ):
+        x_shape = x.shape
+        dim_delta = len(x_shape) - len(normalized_shape)
+        non_flatten_shape = x_shape[:dim_delta]
+        x = x.reshape(*non_flatten_shape, -1)
+
+        mean = x.mean(axis=-1, keepdims=True)
+        var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean
+
+        x = (x - mean) / F.sqrt(var + eps)
+        x = x.reshape(x_shape)
+        if affine:
+            x = weight * x + bias
+
+        return x
+
+    # random input, affine False
+    normalized_shape = (28, 28)
+    inp_feat = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")
+    weight = Tensor(np.random.randn(28, 28), dtype="float32")
+    bias = Tensor(np.random.randn(28, 28), dtype="float32")
+
+    inp_feat = inp_feat + 1
+    weight = weight + 1
+
+    affine = False
+
+    outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
+    targetvar = _layer_norm_numpy(inp_feat, normalized_shape, affine, weight, bias)
+
+    assert abs(outvar - targetvar).mean() < 1e-7
+
+    # no random, affine True
+    normalized_shape = (28, 28)
+    inp_feat = Tensor(np.ones((32, 64, 28, 28)), dtype="float32")
+    weight = Tensor(np.ones((28, 28)), dtype="float32")
+    bias = Tensor(np.zeros((28, 28)), dtype="float32")
+
+    affine = True
+
+    outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, bias)
+    targetvar = _layer_norm(inp_feat, normalized_shape, affine, weight, bias)
+    assert abs((outvar - targetvar).mean()) < 1e-7
+    assert abs(outvar.mean()) < 1e-7
+
+
 def test_batchnorm2d_io16c32():
     amp.enabled = True
     inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
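As a quick end-to-end sanity check, here is how the new API introduced by this patch can be exercised. This is a sketch assuming a MegEngine build that contains this change; it mirrors the shapes used in `test_layer_norm` above and assumes the module's `affine` flag defaults to True.

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor
from megengine.module import LayerNorm

x = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32")

# Functional form: normalize over the trailing (28, 28) axes without the
# learned affine transform.
y = F.nn.layer_norm(x, (28, 28), affine=False)

# Module form: LayerNorm owns weight/bias and now delegates to F.nn.layer_norm.
m = LayerNorm(normalized_shape=(28, 28))
z = m(x)

# Each (28, 28) slice of y is normalized, so its mean should be ~0.
np.testing.assert_allclose(y.numpy().mean(axis=(2, 3)), 0.0, atol=1e-5)
```

Because `_get_layerNorm` is wrapped in `lru_cache` and keyed on (device, dtype, dim), the subgraph is built once per configuration and reused across calls, so repeated invocations like the ones above pay the graph-construction cost only on the first call.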