Commit 42711308 authored by Megvii Engine Team, committed by huangxinda

fix(module/normalization): fix bug of LayerNorm and support input of any shape

GitOrigin-RevId: fd643addb5da4fd812df4461e67dbd3550674b32
Parent: a95f6d4f
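
Before the diff itself, a minimal usage sketch of the behavior this commit enables. The import path mirrors the test file further down; the eager-mode `tuple(out.shape)` check is my own assumption, not part of the patch:

```python
import numpy as np
from megengine import tensor
import megengine.module.normalization as norm

# One LayerNorm instance now handles inputs of any rank, as long as the
# trailing dims of the input match `normalized_shape`.
ln = norm.LayerNorm((10, 10))
for shape in [(2, 3, 10, 10), (2, 2, 3, 10, 10)]:
    out = ln(tensor(np.random.rand(*shape).astype("float32")))
    assert tuple(out.shape) == shape
```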
@@ -109,22 +109,24 @@ class InstanceNorm(Module):
 class LayerNorm(Module):
     """
-    Simple implementation of LayerNorm. Only support 4d tensor now.
+    Simple implementation of LayerNorm. Support tensor of any shape as input.
     Reference: https://arxiv.org/pdf/1803.08494.pdf.
     Note that LayerNorm equals using GroupNorm with num_groups=1.
     """

-    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+    def __init__(self, normalized_shape, eps=1e-05, affine=True, **kwargs):
         super().__init__(**kwargs)
-        self.num_channels = num_channels
+        if isinstance(normalized_shape, int):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = tuple(normalized_shape)
         self.eps = eps
         self.affine = affine
         if self.affine:
-            self.weight = Parameter(np.ones(num_channels, dtype="float32"))
-            self.bias = Parameter(np.zeros(num_channels, dtype="float32"))
+            self.weight = Parameter(np.ones(self.normalized_shape, dtype="float32"))
+            self.bias = Parameter(np.zeros(self.normalized_shape, dtype="float32"))
         else:
             self.weight = None
             self.bias = None
         self.reset_parameters()

     def reset_parameters(self):
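
The parameter change above is what lets the affine step later in the diff drop its `reshape(1, -1, 1, 1)`: `weight` and `bias` now carry the trailing `normalized_shape`, so they broadcast over all leading dims. A hedged NumPy illustration of that broadcasting:

```python
import numpy as np

w = np.ones((10, 10), dtype="float32")   # weight with shape == normalized_shape
x = np.zeros((2, 3, 10, 10), dtype="float32")
assert (w * x).shape == x.shape          # broadcasts over the leading (2, 3) dims
```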
@@ -133,20 +135,21 @@ class LayerNorm(Module):
         zeros_(self.bias)

     def forward(self, x):
-        N, C, H, W = x.shape
-        assert C == self.num_channels
-        x = x.reshape(x.shape[0], -1)
-        # NOTE mean will keepdims in next two lines.
-        mean = x.mean(axis=1, keepdims=1)
-        var = (x ** 2).mean(axis=1, keepdims=1) - mean * mean
+        x_shape = x.shape
+        assert x_shape[-len(self.normalized_shape) :] == self.normalized_shape
+        dim_delta = len(x_shape) - len(self.normalized_shape)
+        non_flatten_shape = x_shape[:dim_delta]
+        x = x.reshape(*non_flatten_shape, -1)
+        mean = x.mean(axis=-1, keepdims=True)
+        var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean
         x = (x - mean) / F.sqrt(var + self.eps)
-        x = x.reshape(N, C, H, W)
+        x = x.reshape(x_shape)
         if self.affine:
-            x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
+            x = self.weight * x + self.bias
         return x

     def _module_info_string(self) -> str:
-        s = "channels={num_channels}, eps={eps}, affine={affine}"
+        s = "normalized_shape={normalized_shape}, eps={eps}, affine={affine}"
         return s.format(**self.__dict__)
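
For readers who want the math outside MegEngine: the patched forward flattens the trailing `normalized_shape` dims, computes the variance via the identity Var[x] = E[x²] − E[x]², normalizes, restores the original shape, and applies the elementwise affine. A standalone NumPy sketch of the same computation (the function name and signature are mine, for illustration only):

```python
import numpy as np

def layer_norm_ref(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    """NumPy re-implementation of the patched forward, for illustration only."""
    normalized_shape = tuple(normalized_shape)
    assert x.shape[-len(normalized_shape):] == normalized_shape
    lead = x.shape[: x.ndim - len(normalized_shape)]
    flat = x.reshape(*lead, -1)                    # flatten the normalized dims
    mean = flat.mean(axis=-1, keepdims=True)
    var = (flat ** 2).mean(axis=-1, keepdims=True) - mean * mean  # E[x^2] - E[x]^2
    out = ((flat - mean) / np.sqrt(var + eps)).reshape(x.shape)
    if weight is not None:
        out = weight * out + bias                  # broadcasts over leading dims
    return out

x = np.random.rand(2, 3, 10, 10).astype("float32")
y = layer_norm_ref(x, (10, 10))
assert y.shape == x.shape
```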
@@ -30,12 +30,12 @@ def test_group_norm():
 def test_layer_norm():
-    input_shape = (2, 100, 128, 128)
-    channels = input_shape[1]
-    x = tensor(np.random.rand(*input_shape))
-    ln = norm.LayerNorm(channels)
-    out = ln(x)
-    assert shape_to_tuple(out.shape) == input_shape
+    input_shape_list = [(2, 3, 10, 10), (2, 2, 3, 10, 10)]
+    ln = norm.LayerNorm((10, 10))
+    for input_shape in input_shape_list:
+        x = tensor(np.random.rand(*input_shape))
+        out = ln(x)
+        assert shape_to_tuple(out.shape) == input_shape


 def test_instance_norm():