From aa953c3bd637ccd4dda030af6143b5d7324525ed Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 25 Nov 2020 11:23:21 +0800 Subject: [PATCH] fix(mge/module): fix missing import GitOrigin-RevId: 7cdf6ac81400a14d7ad585092a94ade5c16d5ca0 --- .../python/megengine/module/__init__.py | 1 + .../python/megengine/module/normalization.py | 22 ++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py index 46493ee7b..f05e619b2 100644 --- a/imperative/python/megengine/module/__init__.py +++ b/imperative/python/megengine/module/__init__.py @@ -20,6 +20,7 @@ from .embedding import Embedding from .identity import Identity from .linear import Linear from .module import Module +from .normalization import GroupNorm, InstanceNorm, LayerNorm from .pooling import AvgPool2d, MaxPool2d from .quant_dequant import DequantStub, QuantStub from .sequential import Sequential diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py index 8b28eba85..b3a293660 100644 --- a/imperative/python/megengine/module/normalization.py +++ b/imperative/python/megengine/module/normalization.py @@ -8,11 +8,13 @@ import numpy as np import megengine.functional as F -import megengine.module as M from megengine import Parameter +from .init import ones_, zeros_ +from .module import Module -class GroupNorm(M.Module): + +class GroupNorm(Module): """ Simple implementation of GroupNorm. Only support 4d tensor now. Reference: https://arxiv.org/pdf/1803.08494.pdf. @@ -35,8 +37,8 @@ class GroupNorm(M.Module): def reset_parameters(self): if self.affine: - M.init.ones_(self.weight) - M.init.zeros_(self.bias) + ones_(self.weight) + zeros_(self.bias) def forward(self, x): N, C, H, W = x.shape @@ -61,7 +63,7 @@ class GroupNorm(M.Module): return s.format(**self.__dict__) -class InstanceNorm(M.Module): +class InstanceNorm(Module): """ Simple implementation of InstanceNorm. Only support 4d tensor now. Reference: https://arxiv.org/abs/1607.08022. @@ -83,8 +85,8 @@ class InstanceNorm(M.Module): def reset_parameters(self): if self.affine: - M.init.ones_(self.weight) - M.init.zeros_(self.bias) + ones_(self.weight) + zeros_(self.bias) def forward(self, x): N, C, H, W = x.shape @@ -105,7 +107,7 @@ class InstanceNorm(M.Module): return s.format(**self.__dict__) -class LayerNorm(M.Module): +class LayerNorm(Module): """ Simple implementation of LayerNorm. Only support 4d tensor now. Reference: https://arxiv.org/pdf/1803.08494.pdf. @@ -127,8 +129,8 @@ class LayerNorm(M.Module): def reset_parameters(self): if self.affine: - M.init.ones_(self.weight) - M.init.zeros_(self.bias) + ones_(self.weight) + zeros_(self.bias) def forward(self, x): N, C, H, W = x.shape -- GitLab