提交 aa953c3b 编写于 作者: M Megvii Engine Team

fix(mge/module): fix missing import

GitOrigin-RevId: 7cdf6ac81400a14d7ad585092a94ade5c16d5ca0
上级 5a01de78
...@@ -20,6 +20,7 @@ from .embedding import Embedding ...@@ -20,6 +20,7 @@ from .embedding import Embedding
from .identity import Identity from .identity import Identity
from .linear import Linear from .linear import Linear
from .module import Module from .module import Module
from .normalization import GroupNorm, InstanceNorm, LayerNorm
from .pooling import AvgPool2d, MaxPool2d from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub from .quant_dequant import DequantStub, QuantStub
from .sequential import Sequential from .sequential import Sequential
...@@ -8,11 +8,13 @@ ...@@ -8,11 +8,13 @@
import numpy as np import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.module as M
from megengine import Parameter from megengine import Parameter
from .init import ones_, zeros_
from .module import Module
class GroupNorm(M.Module):
class GroupNorm(Module):
""" """
Simple implementation of GroupNorm. Only support 4d tensor now. Simple implementation of GroupNorm. Only support 4d tensor now.
Reference: https://arxiv.org/pdf/1803.08494.pdf. Reference: https://arxiv.org/pdf/1803.08494.pdf.
...@@ -35,8 +37,8 @@ class GroupNorm(M.Module): ...@@ -35,8 +37,8 @@ class GroupNorm(M.Module):
def reset_parameters(self): def reset_parameters(self):
if self.affine: if self.affine:
M.init.ones_(self.weight) ones_(self.weight)
M.init.zeros_(self.bias) zeros_(self.bias)
def forward(self, x): def forward(self, x):
N, C, H, W = x.shape N, C, H, W = x.shape
...@@ -61,7 +63,7 @@ class GroupNorm(M.Module): ...@@ -61,7 +63,7 @@ class GroupNorm(M.Module):
return s.format(**self.__dict__) return s.format(**self.__dict__)
class InstanceNorm(M.Module): class InstanceNorm(Module):
""" """
Simple implementation of InstanceNorm. Only support 4d tensor now. Simple implementation of InstanceNorm. Only support 4d tensor now.
Reference: https://arxiv.org/abs/1607.08022. Reference: https://arxiv.org/abs/1607.08022.
...@@ -83,8 +85,8 @@ class InstanceNorm(M.Module): ...@@ -83,8 +85,8 @@ class InstanceNorm(M.Module):
def reset_parameters(self): def reset_parameters(self):
if self.affine: if self.affine:
M.init.ones_(self.weight) ones_(self.weight)
M.init.zeros_(self.bias) zeros_(self.bias)
def forward(self, x): def forward(self, x):
N, C, H, W = x.shape N, C, H, W = x.shape
...@@ -105,7 +107,7 @@ class InstanceNorm(M.Module): ...@@ -105,7 +107,7 @@ class InstanceNorm(M.Module):
return s.format(**self.__dict__) return s.format(**self.__dict__)
class LayerNorm(M.Module): class LayerNorm(Module):
""" """
Simple implementation of LayerNorm. Only support 4d tensor now. Simple implementation of LayerNorm. Only support 4d tensor now.
Reference: https://arxiv.org/pdf/1803.08494.pdf. Reference: https://arxiv.org/pdf/1803.08494.pdf.
...@@ -127,8 +129,8 @@ class LayerNorm(M.Module): ...@@ -127,8 +129,8 @@ class LayerNorm(M.Module):
def reset_parameters(self): def reset_parameters(self):
if self.affine: if self.affine:
M.init.ones_(self.weight) ones_(self.weight)
M.init.zeros_(self.bias) zeros_(self.bias)
def forward(self, x): def forward(self, x):
N, C, H, W = x.shape N, C, H, W = x.shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册