Commit 7167fdbd authored by Megvii Engine Team

feat(mge/module): add normalization module, which includes group_norm, instance_norm and layer_norm

GitOrigin-RevId: 9f253e32501ab34ea5ee0bce43070a196aeec494
Parent 94dba16f
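The new modules follow the usual MegEngine Module interface. A minimal usage sketch (our illustration, not part of the diff; the import path is the one added by this commit):

import numpy as np
from megengine import tensor
from megengine.module.normalization import GroupNorm

# 32 channels split into 4 groups of 8; affine parameters default to identity.
gn = GroupNorm(num_groups=4, num_channels=32)
y = gn(tensor(np.random.rand(2, 32, 16, 16).astype("float32")))
print(y.shape)  # (2, 32, 16, 16) -- normalization preserves the input shape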
......
@@ -57,10 +57,10 @@ def _is_module(obj):
 def _get_XNorm_typeclass():
     from .batchnorm import _BatchNorm
-    XNorm_types = []
-    XNorm_types.append(_BatchNorm)
-    return tuple(XNorm_types)
+    from .normalization import GroupNorm, LayerNorm, InstanceNorm
+    XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
+    return XNorm_types
 class Module(metaclass=ABCMeta):
......
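Returning the classes as a tuple means a single isinstance() call covers every normalization flavor, which is what lets call sites special-case norm layers. A minimal sketch (our illustration, assuming the import paths above):

from megengine.module.batchnorm import _BatchNorm
from megengine.module.normalization import GroupNorm, InstanceNorm, LayerNorm

# One isinstance() check for all norm flavors.
norm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
assert isinstance(GroupNorm(2, 4), norm_types)
assert isinstance(LayerNorm(4), norm_types)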
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine import Parameter
class GroupNorm(M.Module):
    """
    Simple implementation of GroupNorm.
    Reference: https://arxiv.org/pdf/1803.08494.pdf.
    """

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super().__init__()
        assert num_channels % num_groups == 0
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(np.ones(num_channels, dtype=np.float32))
            self.bias = Parameter(np.zeros(num_channels, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.affine:
            M.init.ones_(self.weight)
            M.init.zeros_(self.bias)

    def forward(self, x):
        N, C, H, W = x.shape
        assert C == self.num_channels
        x = x.reshape(N, self.num_groups, -1)
        mean = x.mean(axis=2, keepdims=True)
        var = (x * x).mean(axis=2, keepdims=True) - mean * mean
        x = (x - mean) / F.sqrt(var + self.eps)
        x = x.reshape(N, C, H, W)
        if self.affine:
            x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
        return x

    def _module_info_string(self) -> str:
        s = (
            "groups={num_groups}, channels={num_channels}, "
            "eps={eps}, affine={affine}"
        )
        return s.format(**self.__dict__)
class InstanceNorm(M.Module):
    """
    Simple implementation of InstanceNorm.
    Reference: https://arxiv.org/abs/1607.08022.
    Note that InstanceNorm is equivalent to GroupNorm with num_groups=num_channels.
    """

    def __init__(self, num_channels, eps=1e-05, affine=True):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(np.ones(num_channels, dtype="float32"))
            self.bias = Parameter(np.zeros(num_channels, dtype="float32"))
        else:
            self.weight = None
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.affine:
            M.init.ones_(self.weight)
            M.init.zeros_(self.bias)

    def forward(self, x):
        N, C, H, W = x.shape
        assert C == self.num_channels
        x = x.reshape(N, C, -1)
        mean = x.mean(axis=2, keepdims=True)
        var = (x ** 2).mean(axis=2, keepdims=True) - mean * mean
        x = (x - mean) / F.sqrt(var + self.eps)
        x = x.reshape(N, C, H, W)
        if self.affine:
            x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
        return x

    def _module_info_string(self) -> str:
        s = "channels={num_channels}, eps={eps}, affine={affine}"
        return s.format(**self.__dict__)
class LayerNorm(M.Module):
    """
    Simple implementation of LayerNorm.
    Reference: https://arxiv.org/pdf/1803.08494.pdf.
    Note that LayerNorm is equivalent to GroupNorm with num_groups=1.
    """

    def __init__(self, num_channels, eps=1e-05, affine=True):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(np.ones(num_channels, dtype="float32"))
            self.bias = Parameter(np.zeros(num_channels, dtype="float32"))
        else:
            self.weight = None
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.affine:
            M.init.ones_(self.weight)
            M.init.zeros_(self.bias)

    def forward(self, x):
        N, C, H, W = x.shape
        assert C == self.num_channels
        x = x.reshape(N, -1)
        # Normalize over all of (C, H, W) per sample; keepdims preserves the
        # leading axis so mean/var broadcast against x.
        mean = x.mean(axis=1, keepdims=True)
        var = (x ** 2).mean(axis=1, keepdims=True) - mean * mean
        x = (x - mean) / F.sqrt(var + self.eps)
        x = x.reshape(N, C, H, W)
        if self.affine:
            x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
        return x

    def _module_info_string(self) -> str:
        s = "channels={num_channels}, eps={eps}, affine={affine}"
        return s.format(**self.__dict__)
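The two equivalences stated in the docstrings can be checked numerically. A minimal sketch (our addition, not part of the commit), assuming default affine parameters so the learned scale/shift are still the identity:

import numpy as np
from megengine import tensor
from megengine.module.normalization import GroupNorm, InstanceNorm, LayerNorm

x = tensor(np.random.rand(2, 4, 8, 8).astype("float32"))

# InstanceNorm == GroupNorm with one group per channel.
np.testing.assert_allclose(
    InstanceNorm(4)(x).numpy(), GroupNorm(4, 4)(x).numpy(), atol=1e-5
)
# LayerNorm == GroupNorm with a single group spanning all channels.
np.testing.assert_allclose(
    LayerNorm(4)(x).numpy(), GroupNorm(1, 4)(x).numpy(), atol=1e-5
)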
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine.module.normalization as norm
from megengine import tensor
def shape_to_tuple(shape):
    if isinstance(shape, tensor):
        shape = tuple(shape.tolist())
    return shape


def test_group_norm():
    input_shape = (2, 100, 128, 128)
    channels = input_shape[1]
    groups = [2, 5, 10, 50]
    x = tensor(np.random.rand(*input_shape))
    for group in groups:
        gn = norm.GroupNorm(group, channels)
        out = gn(x)
        assert shape_to_tuple(out.shape) == input_shape


def test_layer_norm():
    input_shape = (2, 100, 128, 128)
    channels = input_shape[1]
    x = tensor(np.random.rand(*input_shape))
    ln = norm.LayerNorm(channels)
    out = ln(x)
    assert shape_to_tuple(out.shape) == input_shape


def test_instance_norm():
    input_shape = (2, 100, 128, 128)
    channels = input_shape[1]
    x = tensor(np.random.rand(*input_shape))
    inst_norm = norm.InstanceNorm(channels)
    out = inst_norm(x)
    assert shape_to_tuple(out.shape) == input_shape
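The shipped tests only assert output shapes. A value-level sanity check against a NumPy re-implementation of the same formula could look like the sketch below (our addition, not part of this commit; np_group_norm and test_group_norm_value are hypothetical names):

import numpy as np
from megengine import tensor
import megengine.module.normalization as norm

def np_group_norm(x, num_groups, eps=1e-5):
    # NumPy reference using the same E[x^2] - E[x]^2 variance formula.
    N, C, H, W = x.shape
    g = x.reshape(N, num_groups, -1)
    mean = g.mean(axis=2, keepdims=True)
    var = (g * g).mean(axis=2, keepdims=True) - mean * mean
    return ((g - mean) / np.sqrt(var + eps)).reshape(N, C, H, W)

def test_group_norm_value():  # hypothetical extra test
    x_np = np.random.rand(2, 10, 16, 16).astype("float32")
    gn = norm.GroupNorm(5, 10)  # default affine params are identity
    np.testing.assert_allclose(
        gn(tensor(x_np)).numpy(), np_group_norm(x_np, 5), atol=1e-5
    )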
......
@@ -72,7 +72,7 @@ def test_TQT():
     c1, c2 = f.backward(c)
     c1_np, c2_np = nf.backward(c_np)
     np.testing.assert_allclose(c1.numpy(), c1_np.astype("float32"), rtol=1e-6)
-    np.testing.assert_allclose(c2.numpy(), c2_np.astype("float32"), rtol=1e-6)
+    np.testing.assert_allclose(c2.numpy(), c2_np.astype("float32"), rtol=5e-5)
     a_np = np.random.random((4, 3)).astype("float32")
     b_np = np.random.random((1)).astype("float32")
......