From 7167fdbd499ada76831e2008b1e5d7410ec31b26 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 11 Aug 2020 14:58:57 +0800
Subject: [PATCH] feat(mge/module): add normalization module

includes group_norm, instance_norm and layer_norm

GitOrigin-RevId: 9f253e32501ab34ea5ee0bce43070a196aeec494
---
 imperative/python/megengine/module/module.py  |   6 +-
 .../python/megengine/module/normalization.py  | 150 ++++++++++++++++++
 .../test/unit/module/test_normalization.py    |  47 ++++++
 .../test/unit/quantization/test_fake_quant.py |   2 +-
 4 files changed, 201 insertions(+), 4 deletions(-)
 create mode 100644 imperative/python/megengine/module/normalization.py
 create mode 100644 imperative/python/test/unit/module/test_normalization.py

diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py
index e400a6c4..41bb6720 100644
--- a/imperative/python/megengine/module/module.py
+++ b/imperative/python/megengine/module/module.py
@@ -57,10 +57,10 @@ def _is_module(obj):
 
 def _get_XNorm_typeclass():
     from .batchnorm import _BatchNorm
+    from .normalization import GroupNorm, LayerNorm, InstanceNorm
 
-    XNorm_types = []
-    XNorm_types.append(_BatchNorm)
-    return tuple(XNorm_types)
+    XNorm_types = (_BatchNorm, GroupNorm, LayerNorm, InstanceNorm)
+    return XNorm_types
 
 
 class Module(metaclass=ABCMeta):
diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py
new file mode 100644
index 00000000..bdd90f7a
--- /dev/null
+++ b/imperative/python/megengine/module/normalization.py
@@ -0,0 +1,150 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+import megengine.functional as F
+import megengine.module as M
+from megengine import Parameter
+
+
+class GroupNorm(M.Module):
+    """
+    Simple implementation of GroupNorm.
+    Reference: https://arxiv.org/pdf/1803.08494.pdf.
+    """
+
+    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
+        super().__init__()
+        assert num_channels % num_groups == 0
+        self.num_groups = num_groups
+        self.num_channels = num_channels
+        self.eps = eps
+        self.affine = affine
+        if self.affine:
+            self.weight = Parameter(np.ones(num_channels, dtype=np.float32))
+            self.bias = Parameter(np.zeros(num_channels, dtype=np.float32))
+        else:
+            self.weight = None
+            self.bias = None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.affine:
+            M.init.ones_(self.weight)
+            M.init.zeros_(self.bias)
+
+    def forward(self, x):
+        N, C, H, W = x.shape
+        assert C == self.num_channels
+
+        x = x.reshape(N, self.num_groups, -1)
+        mean = x.mean(axis=2, keepdims=True)
+        var = (x * x).mean(axis=2, keepdims=True) - mean * mean
+
+        x = (x - mean) / F.sqrt(var + self.eps)
+        x = x.reshape(N, C, H, W)
+        if self.affine:
+            x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
+
+        return x
+
+    def _module_info_string(self) -> str:
+        s = (
+            "groups={num_groups}, channels={num_channels}, "
+            "eps={eps}, affine={affine}"
+        )
+        return s.format(**self.__dict__)
+
+
+class InstanceNorm(M.Module):
+    """
+    Simple implementation of InstanceNorm.
+    Reference: https://arxiv.org/abs/1607.08022.
+    Note that InstanceNorm is equivalent to GroupNorm with num_groups=num_channels.
+ """ + + def __init__(self, num_channels, eps=1e-05, affine=True): + super().__init__() + self.num_channels = num_channels + self.eps = eps + self.affine = affine + if self.affine: + self.weight = Parameter(np.ones(num_channels, dtype="float32")) + self.bias = Parameter(np.zeros(num_channels, dtype="float32")) + else: + self.weight = None + self.bias = None + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + M.init.ones_(self.weight) + M.init.zeros_(self.bias) + + def forward(self, x): + N, C, H, W = x.shape + assert C == self.num_channels + x = x.reshape(N, C, -1) + mean = x.mean(axis=2, keepdims=True) + var = (x ** 2).mean(axis=2, keepdims=True) - mean * mean + + x = (x - mean) / F.sqrt(var + self.eps) + x = x.reshape(N, C, H, W) + if self.affine: + x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1) + + return x + + def _module_info_string(self) -> str: + s = "channels={num_channels}, eps={eps}, affine={affine}" + return s.format(**self.__dict__) + + +class LayerNorm(M.Module): + """ + simple implementation of LayerNorm. + Reference: https://arxiv.org/pdf/1803.08494.pdf. + Note that LayerNorm equals using GroupNorm with num_groups=1. + """ + + def __init__(self, num_channels, eps=1e-05, affine=True): + super().__init__() + self.num_channels = num_channels + self.eps = eps + self.affine = affine + if self.affine: + self.weight = Parameter(np.ones(num_channels, dtype="float32")) + self.bias = Parameter(np.zeros(num_channels, dtype="float32")) + else: + self.weight = None + self.bias = None + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + M.init.ones_(self.weight) + M.init.zeros_(self.bias) + + def forward(self, x): + N, C, H, W = x.shape + assert C == self.num_channels + x = x.reshape(x.shape[0], -1) + # NOTE mean will keepdims in next two lines. + mean = x.mean(axis=1, keepdims=1) + var = (x ** 2).mean(axis=1, keepdims=1) - mean * mean + + x = (x - mean) / F.sqrt(var + self.eps) + x = x.reshape(N, C, H, W) + if self.affine: + x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1) + + return x + + def _module_info_string(self) -> str: + s = "channels={num_channels}, eps={eps}, affine={affine}" + return s.format(**self.__dict__) diff --git a/imperative/python/test/unit/module/test_normalization.py b/imperative/python/test/unit/module/test_normalization.py new file mode 100644 index 00000000..49e02803 --- /dev/null +++ b/imperative/python/test/unit/module/test_normalization.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+import numpy as np
+
+import megengine.module.normalization as norm
+from megengine import tensor
+
+
+def shape_to_tuple(shape):
+    if isinstance(shape, tensor):
+        shape = tuple(shape.tolist())
+    return shape
+
+
+def test_group_norm():
+    input_shape = (2, 100, 128, 128)
+    channels = input_shape[1]
+    groups = [2, 5, 10, 50]
+    x = tensor(np.random.rand(*input_shape))
+    for group in groups:
+        gn = norm.GroupNorm(group, channels)
+        out = gn(x)
+        assert shape_to_tuple(out.shape) == input_shape
+
+
+def test_layer_norm():
+    input_shape = (2, 100, 128, 128)
+    channels = input_shape[1]
+    x = tensor(np.random.rand(*input_shape))
+    ln = norm.LayerNorm(channels)
+    out = ln(x)
+    assert shape_to_tuple(out.shape) == input_shape
+
+
+def test_instance_norm():
+    input_shape = (2, 100, 128, 128)
+    channels = input_shape[1]
+    x = tensor(np.random.rand(*input_shape))
+    inst_norm = norm.InstanceNorm(channels)
+    out = inst_norm(x)
+    assert shape_to_tuple(out.shape) == input_shape
diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py
index 59b366d5..60cda8c4 100644
--- a/imperative/python/test/unit/quantization/test_fake_quant.py
+++ b/imperative/python/test/unit/quantization/test_fake_quant.py
@@ -72,7 +72,7 @@ def test_TQT():
     c1, c2 = f.backward(c)
     c1_np, c2_np = nf.backward(c_np)
     np.testing.assert_allclose(c1.numpy(), c1_np.astype("float32"), rtol=1e-6)
-    np.testing.assert_allclose(c2.numpy(), c2_np.astype("float32"), rtol=1e-6)
+    np.testing.assert_allclose(c2.numpy(), c2_np.astype("float32"), rtol=5e-5)
 
     a_np = np.random.random((4, 3)).astype("float32")
     b_np = np.random.random((1)).astype("float32")
-- 
GitLab
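
Usage sketch (reviewer note, not part of the patch): the snippet below exercises the three new modules the same way the tests above do, assuming this patch has been applied to a MegEngine checkout. The input shape and channel/group counts are arbitrary choices for illustration; only the channel count has to match num_channels.

    import numpy as np

    import megengine.module.normalization as norm
    from megengine import tensor

    # NCHW input; the channel count (16) must match num_channels below.
    x = tensor(np.random.rand(2, 16, 32, 32).astype("float32"))

    gn = norm.GroupNorm(num_groups=4, num_channels=16)  # 4 groups of 4 channels
    ln = norm.LayerNorm(num_channels=16)                # normalizes over C, H, W
    inst = norm.InstanceNorm(num_channels=16)           # normalizes over H, W per channel

    for m in (gn, ln, inst):
        out = m(x)
        assert out.numpy().shape == (2, 16, 32, 32)  # all three preserve the input shape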
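
The docstrings claim that InstanceNorm is equivalent to GroupNorm with num_groups=num_channels and that LayerNorm is equivalent to GroupNorm with num_groups=1. A quick numerical check of both claims, under the same assumptions as above; the tolerance values are my choice to absorb float32 rounding, not taken from the patch:

    import numpy as np

    import megengine.module.normalization as norm
    from megengine import tensor

    x = tensor(np.random.rand(2, 8, 16, 16).astype("float32"))

    # LayerNorm(C) should match GroupNorm(1, C): both normalize over C*H*W per sample.
    np.testing.assert_allclose(
        norm.LayerNorm(8)(x).numpy(),
        norm.GroupNorm(1, 8)(x).numpy(),
        rtol=1e-5, atol=1e-6,
    )

    # InstanceNorm(C) should match GroupNorm(C, C): both normalize over H*W per channel.
    np.testing.assert_allclose(
        norm.InstanceNorm(8)(x).numpy(),
        norm.GroupNorm(8, 8)(x).numpy(),
        rtol=1e-5, atol=1e-6,
    )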