From 47db29aaa24d4d1247e28ea590d507c6df111997 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 7 Feb 2021 18:09:24 +0800 Subject: [PATCH] fix(mge/module): add kwargs param for all modules GitOrigin-RevId: 7245e669a7d5bcf718d448d9a59e7b31e8ec52d2 --- imperative/python/megengine/module/activation.py | 12 ++++++------ .../python/megengine/module/adaptive_pooling.py | 6 ++---- imperative/python/megengine/module/batchnorm.py | 6 ++++-- imperative/python/megengine/module/conv.py | 11 ++++++++++- imperative/python/megengine/module/conv_bn.py | 2 ++ imperative/python/megengine/module/dropout.py | 4 ++-- imperative/python/megengine/module/elemwise.py | 4 ++-- imperative/python/megengine/module/embedding.py | 3 ++- imperative/python/megengine/module/external.py | 6 ++---- imperative/python/megengine/module/normalization.py | 12 ++++++------ imperative/python/megengine/module/pooling.py | 3 ++- imperative/python/megengine/module/sequential.py | 4 ++-- 12 files changed, 42 insertions(+), 31 deletions(-) diff --git a/imperative/python/megengine/module/activation.py b/imperative/python/megengine/module/activation.py index c65f2f25f..072d7be0c 100644 --- a/imperative/python/megengine/module/activation.py +++ b/imperative/python/megengine/module/activation.py @@ -48,8 +48,8 @@ class Softmax(Module): """ - def __init__(self, axis=None): - super().__init__() + def __init__(self, axis=None, **kwargs): + super().__init__(**kwargs) self.axis = axis def forward(self, inputs): @@ -167,8 +167,8 @@ class PReLU(Module): """ - def __init__(self, num_parameters: int = 1, init: float = 0.25): - super().__init__() + def __init__(self, num_parameters: int = 1, init: float = 0.25, **kwargs): + super().__init__(**kwargs) self.num_parameters = num_parameters if num_parameters > 1: # Assume format is NCHW @@ -225,8 +225,8 @@ class LeakyReLU(Module): """ - def __init__(self, negative_slope: float = 0.01): - super().__init__() + def __init__(self, negative_slope: float = 0.01, **kwargs): + super().__init__(**kwargs) self.negative_slope = negative_slope def forward(self, inputs): diff --git a/imperative/python/megengine/module/adaptive_pooling.py b/imperative/python/megengine/module/adaptive_pooling.py index 8061f21bb..44e33f43e 100644 --- a/imperative/python/megengine/module/adaptive_pooling.py +++ b/imperative/python/megengine/module/adaptive_pooling.py @@ -15,10 +15,8 @@ from .module import Module class _AdaptivePoolNd(Module): - def __init__( - self, oshp: Union[Tuple[int, int], int, Tensor], - ): - super(_AdaptivePoolNd, self).__init__() + def __init__(self, oshp: Union[Tuple[int, int], int, Tensor], **kwargs): + super(_AdaptivePoolNd, self).__init__(**kwargs) self.oshp = oshp @abstractmethod diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index f6d313d2a..1bc3fd955 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -26,8 +26,9 @@ class _BatchNorm(Module): affine=True, track_running_stats=True, freeze=False, + **kwargs ): - super(_BatchNorm, self).__init__() + super(_BatchNorm, self).__init__(**kwargs) self.num_features = num_features self.eps = eps self.momentum = momentum @@ -151,9 +152,10 @@ class SyncBatchNorm(_BatchNorm): track_running_stats=True, freeze=False, group: Optional[Group] = WORLD, + **kwargs ) -> None: super().__init__( - num_features, eps, momentum, affine, track_running_stats, freeze + num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs ) self.group = group diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py index 77b7a2bb8..1d25ca731 100644 --- a/imperative/python/megengine/module/conv.py +++ b/imperative/python/megengine/module/conv.py @@ -37,8 +37,9 @@ class _ConvNd(Module): dilation: Union[int, Tuple[int, int]], groups: int, bias: bool = True, + **kwargs ): - super().__init__() + super().__init__(**kwargs) if in_channels % groups != 0: raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: @@ -176,6 +177,7 @@ class Conv1d(_ConvNd): bias: bool = True, conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", + **kwargs ): kernel_size = kernel_size stride = stride @@ -192,6 +194,7 @@ class Conv1d(_ConvNd): dilation, groups, bias, + **kwargs, ) def _get_fanin(self): @@ -334,6 +337,7 @@ class Conv2d(_ConvNd): bias: bool = True, conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", + **kwargs ): kernel_size = _pair_nonzero(kernel_size) stride = _pair_nonzero(stride) @@ -350,6 +354,7 @@ class Conv2d(_ConvNd): dilation, groups, bias, + **kwargs, ) def _get_fanin(self): @@ -444,6 +449,7 @@ class ConvTranspose2d(_ConvNd): bias: bool = True, conv_mode: str = "CROSS_CORRELATION", compute_mode: str = "DEFAULT", + **kwargs ): kernel_size = _pair_nonzero(kernel_size) stride = _pair_nonzero(stride) @@ -460,6 +466,7 @@ class ConvTranspose2d(_ConvNd): dilation, groups, bias, + **kwargs, ) def _get_fanin(self): @@ -536,6 +543,7 @@ class LocalConv2d(Conv2d): dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, conv_mode: str = "CROSS_CORRELATION", + **kwargs ): self.input_height = input_height self.input_width = input_width @@ -548,6 +556,7 @@ class LocalConv2d(Conv2d): dilation, groups, bias=False, + **kwargs, ) def _infer_weight_shape(self): diff --git a/imperative/python/megengine/module/conv_bn.py b/imperative/python/megengine/module/conv_bn.py index 354156c83..2616c6ec1 100644 --- a/imperative/python/megengine/module/conv_bn.py +++ b/imperative/python/megengine/module/conv_bn.py @@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module): momentum=0.9, affine=True, track_running_stats=True, + **kwargs ): super().__init__() self.conv = Conv2d( @@ -43,6 +44,7 @@ class _ConvBnActivation2d(Module): bias, conv_mode, compute_mode, + **kwargs, ) self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) diff --git a/imperative/python/megengine/module/dropout.py b/imperative/python/megengine/module/dropout.py index f7e783176..08587e910 100644 --- a/imperative/python/megengine/module/dropout.py +++ b/imperative/python/megengine/module/dropout.py @@ -20,8 +20,8 @@ class Dropout(Module): :param drop_prob: The probability to drop (set to zero) each single element """ - def __init__(self, drop_prob=0.0): - super().__init__() + def __init__(self, drop_prob=0.0, **kwargs): + super().__init__(**kwargs) self.drop_prob = drop_prob def forward(self, inputs): diff --git a/imperative/python/megengine/module/elemwise.py b/imperative/python/megengine/module/elemwise.py index 9713557cc..d600d3de7 100644 --- a/imperative/python/megengine/module/elemwise.py +++ b/imperative/python/megengine/module/elemwise.py @@ -72,8 +72,8 @@ class Elemwise(Module): * "NOT": bool unary: ~x """ - def __init__(self, method): - super().__init__() + def __init__(self, method, **kwargs): + super().__init__(**kwargs) self.method = method def forward(self, *inps): diff --git a/imperative/python/megengine/module/embedding.py b/imperative/python/megengine/module/embedding.py index d5f93afd5..1f12f2494 100644 --- a/imperative/python/megengine/module/embedding.py +++ b/imperative/python/megengine/module/embedding.py @@ -64,8 +64,9 @@ class Embedding(Module): norm_type: Optional[float] = None, initial_weight: Parameter = None, freeze: bool = False, + **kwargs ): - super().__init__() + super().__init__(**kwargs) if padding_idx is not None: raise ValueError("Not support padding index now.") if max_norm is not None or norm_type is not None: diff --git a/imperative/python/megengine/module/external.py b/imperative/python/megengine/module/external.py index ae98c8c4f..c28595d94 100644 --- a/imperative/python/megengine/module/external.py +++ b/imperative/python/megengine/module/external.py @@ -19,10 +19,8 @@ class TensorrtRuntimeSubgraph(Module): See :func:`~.tensorrt_runtime_opr` for more details. """ - def __init__( - self, data, - ): - super(TensorrtRuntimeSubgraph, self).__init__() + def __init__(self, data, **kwargs): + super(TensorrtRuntimeSubgraph, self).__init__(**kwargs) self._data = data @property diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py index 8c1a68490..899f51eca 100644 --- a/imperative/python/megengine/module/normalization.py +++ b/imperative/python/megengine/module/normalization.py @@ -20,8 +20,8 @@ class GroupNorm(Module): Reference: https://arxiv.org/pdf/1803.08494.pdf. """ - def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): - super().__init__() + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs): + super().__init__(**kwargs) assert num_channels % num_groups == 0 self.num_groups = num_groups self.num_channels = num_channels @@ -70,8 +70,8 @@ class InstanceNorm(Module): Note that InstanceNorm equals using GroupNome with num_groups=num_channels. """ - def __init__(self, num_channels, eps=1e-05, affine=True): - super().__init__() + def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs): + super().__init__(**kwargs) self.num_channels = num_channels self.eps = eps self.affine = affine @@ -114,8 +114,8 @@ class LayerNorm(Module): Note that LayerNorm equals using GroupNorm with num_groups=1. """ - def __init__(self, num_channels, eps=1e-05, affine=True): - super().__init__() + def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs): + super().__init__(**kwargs) self.num_channels = num_channels self.eps = eps self.affine = affine diff --git a/imperative/python/megengine/module/pooling.py b/imperative/python/megengine/module/pooling.py index 56e0948cf..c8ff81bea 100644 --- a/imperative/python/megengine/module/pooling.py +++ b/imperative/python/megengine/module/pooling.py @@ -19,8 +19,9 @@ class _PoolNd(Module): kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = None, padding: Union[int, Tuple[int, int]] = 0, + **kwargs ): - super(_PoolNd, self).__init__() + super(_PoolNd, self).__init__(**kwargs) self.kernel_size = kernel_size self.stride = stride or kernel_size self.padding = padding diff --git a/imperative/python/megengine/module/sequential.py b/imperative/python/megengine/module/sequential.py index 1cce1d4c8..e484110c3 100644 --- a/imperative/python/megengine/module/sequential.py +++ b/imperative/python/megengine/module/sequential.py @@ -46,8 +46,8 @@ class Sequential(Module): pred1 = net1(data) """ - def __init__(self, *args): - super().__init__() + def __init__(self, *args, **kwargs): + super().__init__(**kwargs) self.layer_keys = [] if len(args) == 1 and isinstance(args[0], OrderedDict): for key, module in args[0].items(): -- GitLab