diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py index 46493ee7bf738831b92c23dd36fa5b6d815c8d66..f05e619b2383b6c5ec39011b1c0e34c9c8080dd8 100644 --- a/imperative/python/megengine/module/__init__.py +++ b/imperative/python/megengine/module/__init__.py @@ -20,6 +20,7 @@ from .embedding import Embedding from .identity import Identity from .linear import Linear from .module import Module +from .normalization import GroupNorm, InstanceNorm, LayerNorm from .pooling import AvgPool2d, MaxPool2d from .quant_dequant import DequantStub, QuantStub from .sequential import Sequential diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py index 8b28eba853dbf33d20db46c4a4907e5543ef59df..b3a293660349dd6f6081b3fe25dceab7dd945af2 100644 --- a/imperative/python/megengine/module/normalization.py +++ b/imperative/python/megengine/module/normalization.py @@ -8,11 +8,13 @@ import numpy as np import megengine.functional as F -import megengine.module as M from megengine import Parameter +from .init import ones_, zeros_ +from .module import Module -class GroupNorm(M.Module): + +class GroupNorm(Module): """ Simple implementation of GroupNorm. Only support 4d tensor now. Reference: https://arxiv.org/pdf/1803.08494.pdf. @@ -35,8 +37,8 @@ class GroupNorm(M.Module): def reset_parameters(self): if self.affine: - M.init.ones_(self.weight) - M.init.zeros_(self.bias) + ones_(self.weight) + zeros_(self.bias) def forward(self, x): N, C, H, W = x.shape @@ -61,7 +63,7 @@ class GroupNorm(M.Module): return s.format(**self.__dict__) -class InstanceNorm(M.Module): +class InstanceNorm(Module): """ Simple implementation of InstanceNorm. Only support 4d tensor now. Reference: https://arxiv.org/abs/1607.08022. @@ -83,8 +85,8 @@ class InstanceNorm(M.Module): def reset_parameters(self): if self.affine: - M.init.ones_(self.weight) - M.init.zeros_(self.bias) + ones_(self.weight) + zeros_(self.bias) def forward(self, x): N, C, H, W = x.shape @@ -105,7 +107,7 @@ class InstanceNorm(M.Module): return s.format(**self.__dict__) -class LayerNorm(M.Module): +class LayerNorm(Module): """ Simple implementation of LayerNorm. Only support 4d tensor now. Reference: https://arxiv.org/pdf/1803.08494.pdf. @@ -127,8 +129,8 @@ class LayerNorm(M.Module): def reset_parameters(self): if self.affine: - M.init.ones_(self.weight) - M.init.zeros_(self.bias) + ones_(self.weight) + zeros_(self.bias) def forward(self, x): N, C, H, W = x.shape