From 2763846168b5bce0249d9fbcf2ba0c23756a9fb4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 4 Mar 2021 14:36:28 +0800 Subject: [PATCH] fix(mge/module): support list/dict/tuple in module __repr__ GitOrigin-RevId: b70193fd79576e37b75372fd084f82f038816f85 --- imperative/python/megengine/module/module.py | 22 ++++++--- .../python/test/unit/module/test_module.py | 45 ++++++++++++------- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 5293ad297..9fa969b9d 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -73,6 +73,7 @@ class Module(metaclass=ABCMeta): :param name: module's name, can be initialized by the ``kwargs`` parameter of child class. """ + self._modules = [] if name is not None: assert ( @@ -89,8 +90,6 @@ class Module(metaclass=ABCMeta): self._forward_pre_hooks = OrderedDict() self._forward_hooks = OrderedDict() - self._modules = [] - # used for profiler and automatic naming self._name = "{anonymous}" @@ -595,7 +594,9 @@ class Module(metaclass=ABCMeta): return value def __setattr__(self, name: str, value): - if _is_module(value): + if _is_module(value) or ( + isinstance(value, (list, tuple, dict)) and name != "_modules" + ): modules = self.__dict__.get("_modules") if modules is None: raise AttributeError( @@ -633,10 +634,17 @@ class Module(metaclass=ABCMeta): extra_repr = self._module_info_string() if extra_repr: extra_lines = extra_repr.split("\n") - child_lines = [ - "(" + name + "): " + add_indent(repr(self.__dict__[name]), 2) - for name in self._modules - ] + child_lines = [] + for name in self._modules: + if _is_module(self.__dict__[name]): + child_lines.append( + "(" + name + "): " + add_indent(repr(self.__dict__[name]), 2) + ) + else: + for k, v in _expand_structure(name, self.__dict__[name]): + if _is_module(v): + child_lines.append("(" + k + "): " + add_indent(repr(v), 2)) + lines = extra_lines + child_lines main_str = self.__class__.__name__ + "(" if lines: diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index bff7d9ed6..b28219be8 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -656,15 +656,23 @@ def test_repr_basic(): class ConvModel(Module): def __init__(self): super().__init__() - self.conv1 = Conv2d(3, 128, 3, stride=2, bias=False) - self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False) - self.conv3 = Conv2d(3, 128, 3, dilation=2, bias=False) - self.bn1 = BatchNorm2d(128) - self.bn2 = BatchNorm1d(128) - self.dropout = Dropout(drop_prob=0.1) - self.softmax = Softmax(axis=100) + self.conv1 = Conv2d(3, 128, 3, padding=1, bias=False) + self.conv2 = Conv2d(3, 128, 3, dilation=2, bias=False) + self.bn1 = BatchNorm1d(128) + self.bn2 = BatchNorm2d(128) self.pooling = MaxPool2d(kernel_size=2, padding=0) - self.submodule1 = Sequential(Dropout(drop_prob=0.1), Softmax(axis=100),) + modules = OrderedDict() + modules["depthwise"] = Conv2d(256, 256, 3, 1, 0, groups=256, bias=False,) + modules["pointwise"] = Conv2d( + 256, 256, kernel_size=1, stride=1, padding=0, bias=True, + ) + self.submodule1 = Sequential(modules) + self.list1 = [Dropout(drop_prob=0.1), [Softmax(axis=100)]] + self.tuple1 = ( + Dropout(drop_prob=0.1), + (Softmax(axis=100), Dropout(drop_prob=0.2)), + ) + self.dict1 = {"Dropout": Dropout(drop_prob=0.1)} self.fc1 = Linear(512, 1024) def forward(self, inputs): @@ -672,16 +680,21 @@ def test_repr_basic(): ground_truth = ( "ConvModel(\n" - " (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)\n" - " (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n" - " (conv3): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n" - " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" - " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" - " (dropout): Dropout(drop_prob=0.1)\n (softmax): Softmax(axis=100)\n" + " (conv1): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n" + " (conv2): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n" + " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" " (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n" " (submodule1): Sequential(\n" - " (0): Dropout(drop_prob=0.1)\n" - " (1): Softmax(axis=100)\n )\n" + " (depthwise): Conv2d(256, 256, kernel_size=(3, 3), groups=256, bias=False)\n" + " (pointwise): Conv2d(256, 256, kernel_size=(1, 1))\n" + " )\n" + " (list1.0): Dropout(drop_prob=0.1)\n" + " (list1.1.0): Softmax(axis=100)\n" + " (tuple1.0): Dropout(drop_prob=0.1)\n" + " (tuple1.1.0): Softmax(axis=100)\n" + " (tuple1.1.1): Dropout(drop_prob=0.2)\n" + " (dict1.Dropout): Dropout(drop_prob=0.1)\n" " (fc1): Linear(in_features=512, out_features=1024, bias=True)\n" ")" ) -- GitLab