diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py
index 899f51eca5994b03c25e7e7199175ca985b45628..f524d03eb0f6119e7faa07bdc988cfff70f42586 100644
--- a/imperative/python/megengine/module/normalization.py
+++ b/imperative/python/megengine/module/normalization.py
@@ -109,22 +109,24 @@ class InstanceNorm(Module):
 
 class LayerNorm(Module):
     """
-    Simple implementation of LayerNorm. Only support 4d tensor now.
+    Simple implementation of LayerNorm. Support tensor of any shape as input.
     Reference: https://arxiv.org/pdf/1803.08494.pdf.
-    Note that LayerNorm equals using GroupNorm with num_groups=1.
     """
 
-    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+    def __init__(self, normalized_shape, eps=1e-05, affine=True, **kwargs):
         super().__init__(**kwargs)
-        self.num_channels = num_channels
+        if isinstance(normalized_shape, int):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = tuple(normalized_shape)
         self.eps = eps
         self.affine = affine
         if self.affine:
-            self.weight = Parameter(np.ones(num_channels, dtype="float32"))
-            self.bias = Parameter(np.zeros(num_channels, dtype="float32"))
+            self.weight = Parameter(np.ones(self.normalized_shape, dtype="float32"))
+            self.bias = Parameter(np.zeros(self.normalized_shape, dtype="float32"))
         else:
             self.weight = None
             self.bias = None
+
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -133,20 +135,21 @@ class LayerNorm(Module):
             zeros_(self.bias)
 
     def forward(self, x):
-        N, C, H, W = x.shape
-        assert C == self.num_channels
-        x = x.reshape(x.shape[0], -1)
-        # NOTE mean will keepdims in next two lines.
-        mean = x.mean(axis=1, keepdims=1)
-        var = (x ** 2).mean(axis=1, keepdims=1) - mean * mean
+        x_shape = x.shape
+        assert x_shape[-len(self.normalized_shape) :] == self.normalized_shape
+        dim_delta = len(x_shape) - len(self.normalized_shape)
+        non_flatten_shape = x_shape[:dim_delta]
+        x = x.reshape(*non_flatten_shape, -1)
+
+        mean = x.mean(axis=-1, keepdims=True)
+        var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean
         x = (x - mean) / F.sqrt(var + self.eps)
-        x = x.reshape(N, C, H, W)
+        x = x.reshape(x_shape)
         if self.affine:
-            x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
-
+            x = self.weight * x + self.bias
         return x
 
     def _module_info_string(self) -> str:
-        s = "channels={num_channels}, eps={eps}, affine={affine}"
+        s = "normalized_shape={normalized_shape}, eps={eps}, affine={affine}"
         return s.format(**self.__dict__)
 
diff --git a/imperative/python/test/unit/module/test_normalization.py b/imperative/python/test/unit/module/test_normalization.py
index f314f23d1c16c8b3263ee2a792dbc5fd4a76d924..130ed97537c8f88139db287ad129c1fa1e86f72d 100644
--- a/imperative/python/test/unit/module/test_normalization.py
+++ b/imperative/python/test/unit/module/test_normalization.py
@@ -30,12 +30,12 @@ def test_group_norm():
 
 
 def test_layer_norm():
-    input_shape = (2, 100, 128, 128)
-    channels = input_shape[1]
-    x = tensor(np.random.rand(*input_shape))
-    ln = norm.LayerNorm(channels)
-    out = ln(x)
-    assert shape_to_tuple(out.shape) == input_shape
+    input_shape_list = [(2, 3, 10, 10), (2, 2, 3, 10, 10)]
+    ln = norm.LayerNorm((10, 10))
+    for input_shape in input_shape_list:
+        x = tensor(np.random.rand(*input_shape))
+        out = ln(x)
+        assert shape_to_tuple(out.shape) == input_shape
 
 
 def test_instance_norm():
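
For reference, the new forward pass flattens the trailing normalized_shape dims into a single axis, normalizes along that axis using E[x^2] - E[x]^2 for the variance, restores the original shape, and, when affine is set, broadcasts weight/bias over the leading dims. Below is a minimal, self-contained NumPy sketch of that logic; the helper name layer_norm_ref is hypothetical and not part of this patch, it only mirrors the steps in LayerNorm.forward above.

import numpy as np

def layer_norm_ref(x, normalized_shape, eps=1e-05, weight=None, bias=None):
    # Hypothetical reference helper (not in the patch); mirrors LayerNorm.forward:
    # flatten trailing dims, normalize along the flattened axis, restore shape.
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape,)
    normalized_shape = tuple(normalized_shape)
    assert x.shape[-len(normalized_shape):] == normalized_shape
    lead = x.shape[: x.ndim - len(normalized_shape)]
    flat = x.reshape(*lead, -1)
    mean = flat.mean(axis=-1, keepdims=True)
    var = (flat ** 2).mean(axis=-1, keepdims=True) - mean * mean  # E[x^2] - E[x]^2
    out = ((flat - mean) / np.sqrt(var + eps)).reshape(x.shape)
    if weight is not None:
        out = weight * out + bias  # weight/bias broadcast over the leading dims
    return out

x = np.random.rand(2, 3, 10, 10)
y = layer_norm_ref(x, (10, 10))
assert y.shape == x.shape
# Each (10, 10) slice comes out with ~zero mean and ~unit variance.
assert np.allclose(y.reshape(2, 3, -1).mean(axis=-1), 0.0, atol=1e-6)
assert np.allclose(y.reshape(2, 3, -1).var(axis=-1), 1.0, atol=1e-3)

One detail worth noting: computing the variance as E[x^2] - E[x]^2 can round slightly negative in low precision for near-constant inputs, so the eps added inside the sqrt is what keeps the normalization well defined in that case.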