Commit 47db29aa authored by Megvii Engine Team

fix(mge/module): add kwargs param for all modules

GitOrigin-RevId: 7245e669a7d5bcf718d448d9a59e7b31e8ec52d2
Parent: 6fb19b66
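
For context, a minimal usage sketch of what forwarding ``**kwargs`` enables
from the caller's side. The ``name`` keyword is an assumption about the base
Module.__init__ of the targeted MegEngine version, not something shown in
this diff:

    import megengine.module as M

    # Keyword options understood by the base Module (``name`` is assumed
    # here) can now be passed through any built-in module's constructor:
    act = M.LeakyReLU(negative_slope=0.02, name="act0")
    drop = M.Dropout(drop_prob=0.5, name="drop0")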
@@ -48,8 +48,8 @@ class Softmax(Module):
     """

-    def __init__(self, axis=None):
-        super().__init__()
+    def __init__(self, axis=None, **kwargs):
+        super().__init__(**kwargs)
         self.axis = axis

     def forward(self, inputs):
@@ -167,8 +167,8 @@ class PReLU(Module):
     """

-    def __init__(self, num_parameters: int = 1, init: float = 0.25):
-        super().__init__()
+    def __init__(self, num_parameters: int = 1, init: float = 0.25, **kwargs):
+        super().__init__(**kwargs)
         self.num_parameters = num_parameters
         if num_parameters > 1:
             # Assume format is NCHW
@@ -225,8 +225,8 @@ class LeakyReLU(Module):
     """

-    def __init__(self, negative_slope: float = 0.01):
-        super().__init__()
+    def __init__(self, negative_slope: float = 0.01, **kwargs):
+        super().__init__(**kwargs)
         self.negative_slope = negative_slope

     def forward(self, inputs):
...
@@ -15,10 +15,8 @@ from .module import Module

 class _AdaptivePoolNd(Module):
-    def __init__(
-        self, oshp: Union[Tuple[int, int], int, Tensor],
-    ):
-        super(_AdaptivePoolNd, self).__init__()
+    def __init__(self, oshp: Union[Tuple[int, int], int, Tensor], **kwargs):
+        super(_AdaptivePoolNd, self).__init__(**kwargs)
         self.oshp = oshp

     @abstractmethod
...
@@ -26,8 +26,9 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
+        **kwargs
     ):
-        super(_BatchNorm, self).__init__()
+        super(_BatchNorm, self).__init__(**kwargs)
         self.num_features = num_features
         self.eps = eps
         self.momentum = momentum
@@ -151,9 +152,10 @@ class SyncBatchNorm(_BatchNorm):
         track_running_stats=True,
         freeze=False,
         group: Optional[Group] = WORLD,
+        **kwargs
     ) -> None:
         super().__init__(
-            num_features, eps, momentum, affine, track_running_stats, freeze
+            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
         )
         self.group = group
...
@@ -37,8 +37,9 @@ class _ConvNd(Module):
         dilation: Union[int, Tuple[int, int]],
         groups: int,
         bias: bool = True,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if in_channels % groups != 0:
             raise ValueError("in_channels must be divisible by groups")
         if out_channels % groups != 0:
@@ -176,6 +177,7 @@ class Conv1d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = kernel_size
         stride = stride
@@ -192,6 +194,7 @@ class Conv1d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -334,6 +337,7 @@ class Conv2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -350,6 +354,7 @@ class Conv2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -444,6 +449,7 @@ class ConvTranspose2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -460,6 +466,7 @@ class ConvTranspose2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -536,6 +543,7 @@ class LocalConv2d(Conv2d):
         dilation: Union[int, Tuple[int, int]] = 1,
         groups: int = 1,
         conv_mode: str = "CROSS_CORRELATION",
+        **kwargs
     ):
         self.input_height = input_height
         self.input_width = input_width
@@ -548,6 +556,7 @@ class LocalConv2d(Conv2d):
             dilation,
             groups,
             bias=False,
+            **kwargs,
         )

     def _infer_weight_shape(self):
...
@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
+        **kwargs
     ):
         super().__init__()
         self.conv = Conv2d(
@@ -43,6 +44,7 @@ class _ConvBnActivation2d(Module):
             bias,
             conv_mode,
             compute_mode,
+            **kwargs,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
...
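
Note one divergence from the common pattern: _ConvBnActivation2d forwards its
**kwargs to the inner Conv2d instead of the base class, and its own
super().__init__() call is left unchanged.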
@@ -20,8 +20,8 @@ class Dropout(Module):
     :param drop_prob: The probability to drop (set to zero) each single element
     """

-    def __init__(self, drop_prob=0.0):
-        super().__init__()
+    def __init__(self, drop_prob=0.0, **kwargs):
+        super().__init__(**kwargs)
         self.drop_prob = drop_prob

     def forward(self, inputs):
...
@@ -72,8 +72,8 @@ class Elemwise(Module):
     * "NOT": bool unary: ~x
     """

-    def __init__(self, method):
-        super().__init__()
+    def __init__(self, method, **kwargs):
+        super().__init__(**kwargs)
         self.method = method

     def forward(self, *inps):
...
@@ -64,8 +64,9 @@ class Embedding(Module):
         norm_type: Optional[float] = None,
         initial_weight: Parameter = None,
         freeze: bool = False,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if padding_idx is not None:
             raise ValueError("Not support padding index now.")
         if max_norm is not None or norm_type is not None:
...
@@ -19,10 +19,8 @@ class TensorrtRuntimeSubgraph(Module):
     See :func:`~.tensorrt_runtime_opr` for more details.
     """

-    def __init__(
-        self, data,
-    ):
-        super(TensorrtRuntimeSubgraph, self).__init__()
+    def __init__(self, data, **kwargs):
+        super(TensorrtRuntimeSubgraph, self).__init__(**kwargs)
         self._data = data

     @property
...
@@ -20,8 +20,8 @@ class GroupNorm(Module):
     Reference: https://arxiv.org/pdf/1803.08494.pdf.
     """

-    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
-        super().__init__()
+    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
+        super().__init__(**kwargs)
         assert num_channels % num_groups == 0
         self.num_groups = num_groups
         self.num_channels = num_channels
@@ -70,8 +70,8 @@ class InstanceNorm(Module):
     Note that InstanceNorm equals using GroupNorm with num_groups=num_channels.
     """

-    def __init__(self, num_channels, eps=1e-05, affine=True):
-        super().__init__()
+    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+        super().__init__(**kwargs)
         self.num_channels = num_channels
         self.eps = eps
         self.affine = affine
@@ -114,8 +114,8 @@ class LayerNorm(Module):
     Note that LayerNorm equals using GroupNorm with num_groups=1.
     """

-    def __init__(self, num_channels, eps=1e-05, affine=True):
-        super().__init__()
+    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+        super().__init__(**kwargs)
         self.num_channels = num_channels
         self.eps = eps
         self.affine = affine
...
@@ -19,8 +19,9 @@ class _PoolNd(Module):
         kernel_size: Union[int, Tuple[int, int]],
         stride: Union[int, Tuple[int, int]] = None,
         padding: Union[int, Tuple[int, int]] = 0,
+        **kwargs
     ):
-        super(_PoolNd, self).__init__()
+        super(_PoolNd, self).__init__(**kwargs)
         self.kernel_size = kernel_size
         self.stride = stride or kernel_size
         self.padding = padding
...
@@ -46,8 +46,8 @@ class Sequential(Module):
         pred1 = net1(data)
     """

-    def __init__(self, *args):
-        super().__init__()
+    def __init__(self, *args, **kwargs):
+        super().__init__(**kwargs)
         self.layer_keys = []
         if len(args) == 1 and isinstance(args[0], OrderedDict):
             for key, module in args[0].items():
...
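
Every hunk above applies the same mechanical pattern: accept **kwargs in the
subclass constructor and forward it verbatim to the parent. A self-contained
schematic of that pattern (the ``name`` parameter of the base class is an
assumption for illustration, not part of this diff):

    class Module:
        # Hypothetical stand-in for MegEngine's base Module; the real
        # __init__ may accept more or different keyword options.
        def __init__(self, name=None):
            self.name = name

    class Dropout(Module):
        def __init__(self, drop_prob=0.0, **kwargs):
            # Unknown keywords are not swallowed here: they flow on to
            # Module, so Dropout(0.5, name="drop0") works, and a typo'd
            # keyword still raises TypeError at the base class.
            super().__init__(**kwargs)
            self.drop_prob = drop_prob

    drop = Dropout(0.5, name="drop0")
    assert drop.name == "drop0"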