diff --git a/imperative/python/megengine/module/activation.py b/imperative/python/megengine/module/activation.py index 817fc5cb8b22c1b8e373872a15c4cb1400b37e48..be232b0a7206325c8543282d3b862db3e418cd08 100644 --- a/imperative/python/megengine/module/activation.py +++ b/imperative/python/megengine/module/activation.py @@ -55,6 +55,9 @@ class Softmax(Module): def forward(self, inputs): return softmax(inputs, self.axis) + def _module_info_string(self) -> str: + return "axis={axis}".format(axis=self.axis) + class Sigmoid(Module): r""" diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index 8651546b72e61e4388d2bb31935b39368220d0b4..c9abfa9a9a5b76ac0175d7323c3ed803fffa2a40 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -113,6 +113,13 @@ class _BatchNorm(Module): return output + def _module_info_string(self) -> str: + s = ( + "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " + "track_running_stats={track_running_stats}" + ) + return s.format(**self.__dict__) + class SyncBatchNorm(_BatchNorm): r""" diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py index 8d07505c53e84c7f5813abb3468bbee3a7fac78d..c287dff8addf71c7f9ab0eb145286e5790c78533 100644 --- a/imperative/python/megengine/module/conv.py +++ b/imperative/python/megengine/module/conv.py @@ -70,6 +70,21 @@ class _ConvNd(Module): def _infer_bias_shape(self): pass + def _module_info_string(self): + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}" + + if self.stride != (1,) * len(self.stride): + s += ", stride={stride}" + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias is None: + s += ", bias=False" + return s.format(**self.__dict__) + class Conv2d(_ConvNd): r"""Applies a 2D convolution over an input tensor. diff --git a/imperative/python/megengine/module/dropout.py b/imperative/python/megengine/module/dropout.py index 0aac97129f531333ba55752a41e1512421dc545c..afd057d08153a74ba8e37d14582a162174e80c21 100644 --- a/imperative/python/megengine/module/dropout.py +++ b/imperative/python/megengine/module/dropout.py @@ -28,3 +28,6 @@ class Dropout(Module): return dropout(inputs, self.drop_prob, training=True) else: return inputs + + def _module_info_string(self) -> str: + return "drop_prob={drop_prob}".format(drop_prob=self.drop_prob) diff --git a/imperative/python/megengine/module/linear.py b/imperative/python/megengine/module/linear.py index ba5c81aac5d3ff15772dce8b6231bfdab829ec69..06a4f91c643adedc0f36458803bcec8fd533a191 100644 --- a/imperative/python/megengine/module/linear.py +++ b/imperative/python/megengine/module/linear.py @@ -78,3 +78,8 @@ class Linear(Module): def forward(self, x): return self._calc_linear(x, self.weight, self.bias) + + def _module_info_string(self) -> str: + return "in_features={}, out_features={}, bias={}".format( + self.in_features, self.out_features, self.bias is not None + ) diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index bf87be9dbbf6177805882f6d5a53217e04f84cc0..10dba3515c4314ea3185fe01823bb44ad52c9290 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -69,6 +69,8 @@ class Module(metaclass=ABCMeta): self._forward_pre_hooks = OrderedDict() self._forward_hooks = OrderedDict() + self._modules = [] + @abstractmethod def forward(self, inputs): pass @@ -518,3 +520,57 @@ class Module(metaclass=ABCMeta): loaded.append(k) return set(loaded), set(skipped) + + def __setattr__(self, name: str, value): + if _is_module(value): + modules = self.__dict__.get("_modules") + if modules is None: + raise AttributeError( + "cannot assign module before Module.__init__() call" + ) + if name not in self.__dict__: + modules.append(name) + super().__setattr__(name, value) + + def __delattr__(self, name: str): + if name in self.__dict__ and _is_module(self.__dict__[name]): + modules = self.__dict__.get("_modules") + modules.remove(name) + super().__delattr__(name) + + def _module_info_string(self) -> str: + r"""Set the extra representation of the module. + """ + return "" + + def __repr__(self): + def add_indent(repr_str, num_spaces): + s = repr_str.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return repr_str + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + extra_lines = [] + extra_repr = self._module_info_string() + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [ + "(" + name + "): " + add_indent(repr(self.__dict__[name]), 2) + for name in self._modules + ] + lines = extra_lines + child_lines + main_str = self.__class__.__name__ + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str diff --git a/imperative/python/megengine/module/pooling.py b/imperative/python/megengine/module/pooling.py index b5c10a0958fc3964b1263ff994334f23a14edfe5..10dfc1400b2f1b972b6c3bdd0dadd6fd428225d2 100644 --- a/imperative/python/megengine/module/pooling.py +++ b/imperative/python/megengine/module/pooling.py @@ -29,6 +29,11 @@ class _PoolNd(Module): def forward(self, inp): pass + def _module_info_string(self) -> str: + return "kernel_size={kernel_size}, stride={stride}, padding={padding}".format( + **self.__dict__ + ) + class MaxPool2d(_PoolNd): r"""Applies a 2D max pooling over an input. diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index d4a5f30479aa85d10d73e17755cce18bcb1d15ed..d3a492414f8a21bbc9373b397b58e081fb585b2b 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -21,9 +21,12 @@ from megengine.module import ( BatchNorm1d, BatchNorm2d, Conv2d, + Dropout, Linear, + MaxPool2d, Module, Sequential, + Softmax, ) from megengine.quantization.quantize import quantize, quantize_qat from megengine.test import assertTensorClose @@ -609,3 +612,111 @@ def test_load_quantized(): assertTensorClose( pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6 ) + + +def test_repr_basic(): + # test whether __repr__ can output correct information + class ConvModel(Module): + def __init__(self): + super().__init__() + self.conv1 = Conv2d(3, 128, 3, stride=2, bias=False) + self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False) + self.conv3 = Conv2d(3, 128, 3, dilation=2, bias=False) + self.bn1 = BatchNorm2d(128) + self.bn2 = BatchNorm1d(128) + self.dropout = Dropout(drop_prob=0.1) + self.softmax = Softmax(axis=100) + self.pooling = MaxPool2d(kernel_size=2, padding=0) + self.submodule1 = Sequential(Dropout(drop_prob=0.1), Softmax(axis=100),) + self.fc1 = Linear(512, 1024) + + def forward(self, inputs): + pass + + ground_truth = ( + "ConvModel(\n" + " (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)\n" + " (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n" + " (conv3): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n" + " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" + " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" + " (dropout): Dropout(drop_prob=0.1)\n (softmax): Softmax(axis=100)\n" + " (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n" + " (submodule1): Sequential(\n" + " (0): Dropout(drop_prob=0.1)\n" + " (1): Softmax(axis=100)\n )\n" + " (fc1): Linear(in_features=512, out_features=1024, bias=True)\n" + ")" + ) + net = ConvModel() + output = net.__repr__() + assert output == ground_truth + + +def test_repr_module_reassign(): + # test whether __repr__ can deal with module reassign + class ConvModel1(Module): + def __init__(self): + super().__init__() + self.conv1 = Conv2d(3, 128, 3, bias=False) + self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False) + self.conv1 = Conv2d(3, 256, 3, dilation=2, bias=False) + + def forward(self, inputs): + pass + + ground_truth = ( + "ConvModel1(\n" + " (conv1): Conv2d(3, 256, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n" + " (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n" + ")" + ) + net = ConvModel1() + output = net.__repr__() + assert output == ground_truth + + +def test_repr_module_rereference(): + # test whether __repr__ can deal with module re-reference + class ConvModel2(Module): + def __init__(self): + super().__init__() + self.conv1 = Conv2d(3, 128, 3, bias=False) + self.conv2 = self.conv1 + self.conv3 = self.conv1 + + def forward(self, inputs): + pass + + ground_truth = ( + "ConvModel2(\n" + " (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" + " (conv2): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" + " (conv3): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" + ")" + ) + net = ConvModel2() + output = net.__repr__() + assert output == ground_truth + + +def test_repr_module_delete(): + # test whether __repr__ can deal with module delete + class ConvModel3(Module): + def __init__(self): + super().__init__() + self.conv1 = Conv2d(3, 128, 3, bias=False) + self.softmax = Softmax(100) + + def forward(self, inputs): + pass + + ground_truth = ( + "ConvModel3(\n" + " (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n" + ")" + ) + net = ConvModel3() + del net.softmax + output = net.__repr__() + assert output == ground_truth