diff --git a/imperative/python/megengine/module/activation.py b/imperative/python/megengine/module/activation.py index c65f2f25f14e377a1ba697ba1b0da9ea4e567e6e..072d7be0c3f0f0c1fb80390eb170e26ee54de586 100644 --- a/imperative/python/megengine/module/activation.py +++ b/imperative/python/megengine/module/activation.py @@ -48,8 +48,8 @@ class Softmax(Module): """ - def __init__(self, axis=None): - super().__init__() + def __init__(self, axis=None, **kwargs): + super().__init__(**kwargs) self.axis = axis def forward(self, inputs): @@ -167,8 +167,8 @@ class PReLU(Module): """ - def __init__(self, num_parameters: int = 1, init: float = 0.25): - super().__init__() + def __init__(self, num_parameters: int = 1, init: float = 0.25, **kwargs): + super().__init__(**kwargs) self.num_parameters = num_parameters if num_parameters > 1: # Assume format is NCHW @@ -225,8 +225,8 @@ class LeakyReLU(Module): """ - def __init__(self, negative_slope: float = 0.01): - super().__init__() + def __init__(self, negative_slope: float = 0.01, **kwargs): + super().__init__(**kwargs) self.negative_slope = negative_slope def forward(self, inputs): diff --git a/imperative/python/megengine/module/adaptive_pooling.py b/imperative/python/megengine/module/adaptive_pooling.py index 8061f21bbab6e37330f4e45ad1e085a061086e3f..44e33f43e58f7704e893a74b8ad07cb34a439a21 100644 --- a/imperative/python/megengine/module/adaptive_pooling.py +++ b/imperative/python/megengine/module/adaptive_pooling.py @@ -15,10 +15,8 @@ from .module import Module class _AdaptivePoolNd(Module): - def __init__( - self, oshp: Union[Tuple[int, int], int, Tensor], - ): - super(_AdaptivePoolNd, self).__init__() + def __init__(self, oshp: Union[Tuple[int, int], int, Tensor], **kwargs): + super(_AdaptivePoolNd, self).__init__(**kwargs) self.oshp = oshp @abstractmethod diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index f6d313d2a9011a207580c4001882c33cbd112ab4..1bc3fd95557c1c787cb56b0a38a41af7120ef999 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -26,8 +26,9 @@ class _BatchNorm(Module): affine=True, track_running_stats=True, freeze=False, + **kwargs ): - super(_BatchNorm, self).__init__() + super(_BatchNorm, self).__init__(**kwargs) self.num_features = num_features self.eps = eps self.momentum = momentum @@ -151,9 +152,10 @@ class SyncBatchNorm(_BatchNorm): track_running_stats=True, freeze=False, group: Optional[Group] = WORLD, + **kwargs ) -> None: super().__init__( - num_features, eps, momentum, affine, track_running_stats, freeze + num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs ) self.group = group diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py index 77b7a2bb8cf99f059e84e0b97bc1b9011ca9be64..1d25ca731f93ef81b6e721cfa6957bb41da90690 100644 --- a/imperative/python/megengine/module/conv.py +++ b/imperative/python/megengine/module/conv.py @@ -37,8 +37,9 @@ class _ConvNd(Module): dilation: Union[int, Tuple[int, int]], groups: int, bias: bool = True, + **kwargs ): - super().__init__() + super().__init__(**kwargs) if in_channels % groups != 0: raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: @@ -176,6 +177,7 @@ class Conv1d(_ConvNd): bias: bool = True, conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", + **kwargs ): kernel_size = kernel_size stride = stride @@ -192,6 +194,7 @@ class Conv1d(_ConvNd): dilation, groups, bias, + **kwargs, ) def _get_fanin(self): @@ -334,6 +337,7 @@ class Conv2d(_ConvNd): bias: bool = True, conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", + **kwargs ): kernel_size = _pair_nonzero(kernel_size) stride = _pair_nonzero(stride) @@ -350,6 +354,7 @@ class Conv2d(_ConvNd): dilation, groups, bias, + **kwargs, ) def _get_fanin(self): @@ -444,6 +449,7 @@ class ConvTranspose2d(_ConvNd): bias: bool = True, conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", + **kwargs ): kernel_size = _pair_nonzero(kernel_size) stride = _pair_nonzero(stride) @@ -460,6 +466,7 @@ class ConvTranspose2d(_ConvNd): dilation, groups, bias, + **kwargs, ) def _get_fanin(self): @@ -536,6 +543,7 @@ class LocalConv2d(Conv2d): dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, conv_mode: str = "CROSS_CORRELATION", + **kwargs ): self.input_height = input_height self.input_width = input_width @@ -548,6 +556,7 @@ class LocalConv2d(Conv2d): dilation, groups, bias=False, + **kwargs, ) def _infer_weight_shape(self): diff --git a/imperative/python/megengine/module/conv_bn.py b/imperative/python/megengine/module/conv_bn.py index 354156c83eefe6e91c344794bb5aa82d0b8713c1..2616c6ec1bc040c7bf1b10c613c18f2adaf19e18 100644 --- a/imperative/python/megengine/module/conv_bn.py +++ b/imperative/python/megengine/module/conv_bn.py @@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module): momentum=0.9, affine=True, track_running_stats=True, + **kwargs ): super().__init__() self.conv = Conv2d( @@ -43,6 +44,7 @@ class _ConvBnActivation2d(Module): bias, conv_mode, compute_mode, + **kwargs, ) self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) diff --git a/imperative/python/megengine/module/dropout.py b/imperative/python/megengine/module/dropout.py index f7e783176f2f55d1e11aaf195ebea2a027855692..08587e91062aa38ae56d7d0a31d9fbe996904f72 100644 --- a/imperative/python/megengine/module/dropout.py +++ b/imperative/python/megengine/module/dropout.py @@ -20,8 +20,8 @@ class Dropout(Module): :param drop_prob: The probability to drop (set to zero) each single element """ - def __init__(self, drop_prob=0.0): - super().__init__() + def __init__(self, drop_prob=0.0, **kwargs): + super().__init__(**kwargs) self.drop_prob = drop_prob def forward(self, inputs): diff --git a/imperative/python/megengine/module/elemwise.py b/imperative/python/megengine/module/elemwise.py index 9713557ccdaede1fa757459620a752769d1cdce4..d600d3de73748b2c69a9e783650c8b6942457bf9 100644 --- a/imperative/python/megengine/module/elemwise.py +++ b/imperative/python/megengine/module/elemwise.py @@ -72,8 +72,8 @@ class Elemwise(Module): * "NOT": bool unary: ~x """ - def __init__(self, method): - super().__init__() + def __init__(self, method, **kwargs): + super().__init__(**kwargs) self.method = method def forward(self, *inps): diff --git a/imperative/python/megengine/module/embedding.py b/imperative/python/megengine/module/embedding.py index d5f93afd569ed6fd55c608fd128a4e5a36d47f0c..1f12f2494de6f3be780792c81e3d6f8cd7b19c02 100644 --- a/imperative/python/megengine/module/embedding.py +++ b/imperative/python/megengine/module/embedding.py @@ -64,8 +64,9 @@ class Embedding(Module): norm_type: Optional[float] = None, initial_weight: Parameter = None, freeze: bool = False, + **kwargs ): - super().__init__() + super().__init__(**kwargs) if padding_idx is not None: raise ValueError("Not support padding index now.") if max_norm is not None or norm_type is not None: diff --git a/imperative/python/megengine/module/external.py b/imperative/python/megengine/module/external.py index ae98c8c4fe65ce41dac7cb02ba5b0a936a969334..c28595d9491288911b3707c59820124b7a242554 100644 --- a/imperative/python/megengine/module/external.py +++ b/imperative/python/megengine/module/external.py @@ -19,10 +19,8 @@ class TensorrtRuntimeSubgraph(Module): See :func:`~.tensorrt_runtime_opr` for more details. """ - def __init__( - self, data, - ): - super(TensorrtRuntimeSubgraph, self).__init__() + def __init__(self, data, **kwargs): + super(TensorrtRuntimeSubgraph, self).__init__(**kwargs) self._data = data @property diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py index 8c1a68490a30f79734702682cc90526fc9efdd13..899f51eca5994b03c25e7e7199175ca985b45628 100644 --- a/imperative/python/megengine/module/normalization.py +++ b/imperative/python/megengine/module/normalization.py @@ -20,8 +20,8 @@ class GroupNorm(Module): Reference: https://arxiv.org/pdf/1803.08494.pdf. """ - def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): - super().__init__() + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs): + super().__init__(**kwargs) assert num_channels % num_groups == 0 self.num_groups = num_groups self.num_channels = num_channels @@ -70,8 +70,8 @@ class InstanceNorm(Module): Note that InstanceNorm equals using GroupNome with num_groups=num_channels. """ - def __init__(self, num_channels, eps=1e-05, affine=True): - super().__init__() + def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs): + super().__init__(**kwargs) self.num_channels = num_channels self.eps = eps self.affine = affine @@ -114,8 +114,8 @@ class LayerNorm(Module): Note that LayerNorm equals using GroupNorm with num_groups=1. """ - def __init__(self, num_channels, eps=1e-05, affine=True): - super().__init__() + def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs): + super().__init__(**kwargs) self.num_channels = num_channels self.eps = eps self.affine = affine diff --git a/imperative/python/megengine/module/pooling.py b/imperative/python/megengine/module/pooling.py index 56e0948cf661164f2e1a803366325a5e1d6266b1..c8ff81bea02b330d040d9ced073f3313bfba09a2 100644 --- a/imperative/python/megengine/module/pooling.py +++ b/imperative/python/megengine/module/pooling.py @@ -19,8 +19,9 @@ class _PoolNd(Module): kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = None, padding: Union[int, Tuple[int, int]] = 0, + **kwargs ): - super(_PoolNd, self).__init__() + super(_PoolNd, self).__init__(**kwargs) self.kernel_size = kernel_size self.stride = stride or kernel_size self.padding = padding diff --git a/imperative/python/megengine/module/sequential.py b/imperative/python/megengine/module/sequential.py index 1cce1d4c8aa9f200d973ea68f54d1805ec9d941c..e484110c3e66bc7e4a3a5d7aee25cee9fa3ed3fd 100644 --- a/imperative/python/megengine/module/sequential.py +++ b/imperative/python/megengine/module/sequential.py @@ -46,8 +46,8 @@ class Sequential(Module): pred1 = net1(data) """ - def __init__(self, *args): - super().__init__() + def __init__(self, *args, **kwargs): + super().__init__(**kwargs) self.layer_keys = [] if len(args) == 1 and isinstance(args[0], OrderedDict): for key, module in args[0].items():