提交 27638461 编写于 作者: M Megvii Engine Team

fix(mge/module): support list/dict/tuple in module __repr__

GitOrigin-RevId: b70193fd79576e37b75372fd084f82f038816f85
上级 07826f5e
...@@ -73,6 +73,7 @@ class Module(metaclass=ABCMeta): ...@@ -73,6 +73,7 @@ class Module(metaclass=ABCMeta):
:param name: module's name, can be initialized by the ``kwargs`` parameter :param name: module's name, can be initialized by the ``kwargs`` parameter
of child class. of child class.
""" """
self._modules = []
if name is not None: if name is not None:
assert ( assert (
...@@ -89,8 +90,6 @@ class Module(metaclass=ABCMeta): ...@@ -89,8 +90,6 @@ class Module(metaclass=ABCMeta):
self._forward_pre_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict()
self._forward_hooks = OrderedDict() self._forward_hooks = OrderedDict()
self._modules = []
# used for profiler and automatic naming # used for profiler and automatic naming
self._name = "{anonymous}" self._name = "{anonymous}"
...@@ -595,7 +594,9 @@ class Module(metaclass=ABCMeta): ...@@ -595,7 +594,9 @@ class Module(metaclass=ABCMeta):
return value return value
def __setattr__(self, name: str, value): def __setattr__(self, name: str, value):
if _is_module(value): if _is_module(value) or (
isinstance(value, (list, tuple, dict)) and name != "_modules"
):
modules = self.__dict__.get("_modules") modules = self.__dict__.get("_modules")
if modules is None: if modules is None:
raise AttributeError( raise AttributeError(
...@@ -633,10 +634,17 @@ class Module(metaclass=ABCMeta): ...@@ -633,10 +634,17 @@ class Module(metaclass=ABCMeta):
extra_repr = self._module_info_string() extra_repr = self._module_info_string()
if extra_repr: if extra_repr:
extra_lines = extra_repr.split("\n") extra_lines = extra_repr.split("\n")
child_lines = [ child_lines = []
"(" + name + "): " + add_indent(repr(self.__dict__[name]), 2) for name in self._modules:
for name in self._modules if _is_module(self.__dict__[name]):
] child_lines.append(
"(" + name + "): " + add_indent(repr(self.__dict__[name]), 2)
)
else:
for k, v in _expand_structure(name, self.__dict__[name]):
if _is_module(v):
child_lines.append("(" + k + "): " + add_indent(repr(v), 2))
lines = extra_lines + child_lines lines = extra_lines + child_lines
main_str = self.__class__.__name__ + "(" main_str = self.__class__.__name__ + "("
if lines: if lines:
......
...@@ -656,15 +656,23 @@ def test_repr_basic(): ...@@ -656,15 +656,23 @@ def test_repr_basic():
class ConvModel(Module): class ConvModel(Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv1 = Conv2d(3, 128, 3, stride=2, bias=False) self.conv1 = Conv2d(3, 128, 3, padding=1, bias=False)
self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False) self.conv2 = Conv2d(3, 128, 3, dilation=2, bias=False)
self.conv3 = Conv2d(3, 128, 3, dilation=2, bias=False) self.bn1 = BatchNorm1d(128)
self.bn1 = BatchNorm2d(128) self.bn2 = BatchNorm2d(128)
self.bn2 = BatchNorm1d(128)
self.dropout = Dropout(drop_prob=0.1)
self.softmax = Softmax(axis=100)
self.pooling = MaxPool2d(kernel_size=2, padding=0) self.pooling = MaxPool2d(kernel_size=2, padding=0)
self.submodule1 = Sequential(Dropout(drop_prob=0.1), Softmax(axis=100),) modules = OrderedDict()
modules["depthwise"] = Conv2d(256, 256, 3, 1, 0, groups=256, bias=False,)
modules["pointwise"] = Conv2d(
256, 256, kernel_size=1, stride=1, padding=0, bias=True,
)
self.submodule1 = Sequential(modules)
self.list1 = [Dropout(drop_prob=0.1), [Softmax(axis=100)]]
self.tuple1 = (
Dropout(drop_prob=0.1),
(Softmax(axis=100), Dropout(drop_prob=0.2)),
)
self.dict1 = {"Dropout": Dropout(drop_prob=0.1)}
self.fc1 = Linear(512, 1024) self.fc1 = Linear(512, 1024)
def forward(self, inputs): def forward(self, inputs):
...@@ -672,16 +680,21 @@ def test_repr_basic(): ...@@ -672,16 +680,21 @@ def test_repr_basic():
ground_truth = ( ground_truth = (
"ConvModel(\n" "ConvModel(\n"
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)\n" " (conv1): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n" " (conv2): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
" (conv3): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n" " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n" " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
" (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
" (dropout): Dropout(drop_prob=0.1)\n (softmax): Softmax(axis=100)\n"
" (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n" " (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n"
" (submodule1): Sequential(\n" " (submodule1): Sequential(\n"
" (0): Dropout(drop_prob=0.1)\n" " (depthwise): Conv2d(256, 256, kernel_size=(3, 3), groups=256, bias=False)\n"
" (1): Softmax(axis=100)\n )\n" " (pointwise): Conv2d(256, 256, kernel_size=(1, 1))\n"
" )\n"
" (list1.0): Dropout(drop_prob=0.1)\n"
" (list1.1.0): Softmax(axis=100)\n"
" (tuple1.0): Dropout(drop_prob=0.1)\n"
" (tuple1.1.0): Softmax(axis=100)\n"
" (tuple1.1.1): Dropout(drop_prob=0.2)\n"
" (dict1.Dropout): Dropout(drop_prob=0.1)\n"
" (fc1): Linear(in_features=512, out_features=1024, bias=True)\n" " (fc1): Linear(in_features=512, out_features=1024, bias=True)\n"
")" ")"
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册