Commit 66f2dbd7 authored by Megvii Engine Team

feat(mge/imperative): add module __repr__ for imperative

GitOrigin-RevId: ac13cc46599a39c398d129bdc5328f59bdbbb359
Parent b8ddca4c
@@ -55,6 +55,9 @@ class Softmax(Module):
def forward(self, inputs):
return softmax(inputs, self.axis)
def _module_info_string(self) -> str:
return "axis={axis}".format(axis=self.axis)
class Sigmoid(Module):
r"""
......
@@ -113,6 +113,13 @@ class _BatchNorm(Module):
return output
def _module_info_string(self) -> str:
s = (
"{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
"track_running_stats={track_running_stats}"
)
return s.format(**self.__dict__)
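
Every placeholder above is filled straight from self.__dict__, so a default BatchNorm2d(128) should print exactly the line that test_repr_basic below checks for. A minimal sketch:

from megengine.module import BatchNorm2d
print(repr(BatchNorm2d(128)))
# BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
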
class SyncBatchNorm(_BatchNorm):
r"""
......
@@ -70,6 +70,21 @@ class _ConvNd(Module):
def _infer_bias_shape(self):
pass
def _module_info_string(self):
s = "{in_channels}, {out_channels}, kernel_size={kernel_size}"
if self.stride != (1,) * len(self.stride):
s += ", stride={stride}"
if self.padding != (0,) * len(self.padding):
s += ", padding={padding}"
if self.dilation != (1,) * len(self.dilation):
s += ", dilation={dilation}"
if self.groups != 1:
s += ", groups={groups}"
if self.bias is None:
s += ", bias=False"
return s.format(**self.__dict__)
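
The conditionals keep the repr compact: only arguments that differ from their defaults appear. A minimal sketch of the expected output, taken from test_repr_basic below:

from megengine.module import Conv2d
conv = Conv2d(3, 128, 3, stride=2, bias=False)
print(repr(conv))
# Conv2d(3, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
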
class Conv2d(_ConvNd):
r"""Applies a 2D convolution over an input tensor.
......
@@ -28,3 +28,6 @@ class Dropout(Module):
return dropout(inputs, self.drop_prob, training=True)
else:
return inputs
def _module_info_string(self) -> str:
return "drop_prob={drop_prob}".format(drop_prob=self.drop_prob)
@@ -78,3 +78,8 @@ class Linear(Module):
def forward(self, x):
return self._calc_linear(x, self.weight, self.bias)
def _module_info_string(self) -> str:
return "in_features={}, out_features={}, bias={}".format(
self.in_features, self.out_features, self.bias is not None
)
@@ -69,6 +69,8 @@ class Module(metaclass=ABCMeta):
self._forward_pre_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._modules = []
@abstractmethod
def forward(self, inputs):
pass
@@ -518,3 +520,57 @@ class Module(metaclass=ABCMeta):
loaded.append(k)
return set(loaded), set(skipped)
def __setattr__(self, name: str, value):
if _is_module(value):
modules = self.__dict__.get("_modules")
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call"
)
if name not in self.__dict__:
modules.append(name)
super().__setattr__(name, value)
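
Because _modules is an ordered list and a name is appended only on its first assignment, reassigning an attribute keeps its original position in the repr instead of adding a duplicate entry (exercised by test_repr_module_reassign below). A minimal sketch, where Net is a hypothetical class using the same imports as the tests:

from megengine.module import Conv2d, Module

class Net(Module):
    def __init__(self):
        super().__init__()
        self.conv = Conv2d(3, 128, 3)
        self.conv = Conv2d(3, 256, 3)  # same name: position kept, not appended again

    def forward(self, inputs):
        pass

assert Net()._modules == ["conv"]  # one entry; repr shows the latest Conv2d
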
def __delattr__(self, name: str):
if name in self.__dict__ and _is_module(self.__dict__[name]):
modules = self.__dict__.get("_modules")
modules.remove(name)
super().__delattr__(name)
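
Conversely, deleting a module attribute also unregisters its name, so it disappears from the repr (exercised by test_repr_module_delete below). Continuing the hypothetical Net sketch from above:

net = Net()
del net.conv
assert net._modules == []  # repr(net) now prints "Net()"
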
def _module_info_string(self) -> str:
r"""Set the extra representation of the module.
"""
return ""
def __repr__(self):
def add_indent(repr_str, num_spaces):
s = repr_str.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return repr_str
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
extra_lines = []
extra_repr = self._module_info_string()
if extra_repr:
extra_lines = extra_repr.split("\n")
child_lines = [
"(" + name + "): " + add_indent(repr(self.__dict__[name]), 2)
for name in self._modules
]
lines = extra_lines + child_lines
main_str = self.__class__.__name__ + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
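
add_indent is what makes nested containers render hierarchically: each child repr is shifted right by two spaces per nesting level. A minimal sketch for a container, matching the nested Sequential block in test_repr_basic below:

from megengine.module import Dropout, Sequential, Softmax
seq = Sequential(Dropout(drop_prob=0.1), Softmax(axis=100))
print(repr(seq))
# Sequential(
#   (0): Dropout(drop_prob=0.1)
#   (1): Softmax(axis=100)
# )
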
@@ -29,6 +29,11 @@ class _PoolNd(Module):
def forward(self, inp):
pass
def _module_info_string(self) -> str:
return "kernel_size={kernel_size}, stride={stride}, padding={padding}".format(
**self.__dict__
)
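
Note that stride is printed unconditionally here, even when it merely defaults to kernel_size: MaxPool2d(kernel_size=2, padding=0) therefore shows stride=2, matching test_repr_basic below. A minimal sketch:

from megengine.module import MaxPool2d
print(repr(MaxPool2d(kernel_size=2, padding=0)))
# MaxPool2d(kernel_size=2, stride=2, padding=0)
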
class MaxPool2d(_PoolNd):
r"""Applies a 2D max pooling over an input.
......
@@ -21,9 +21,12 @@ from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv2d,
Dropout,
Linear,
MaxPool2d,
Module,
Sequential,
Softmax,
)
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.test import assertTensorClose
@@ -609,3 +612,111 @@ def test_load_quantized():
assertTensorClose(
pred0.astype("float32").numpy(), pred1.astype("float32").numpy(), max_err=5e-6
)
def test_repr_basic():
# test whether __repr__ outputs correct information
class ConvModel(Module):
def __init__(self):
super().__init__()
self.conv1 = Conv2d(3, 128, 3, stride=2, bias=False)
self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False)
self.conv3 = Conv2d(3, 128, 3, dilation=2, bias=False)
self.bn1 = BatchNorm2d(128)
self.bn2 = BatchNorm1d(128)
self.dropout = Dropout(drop_prob=0.1)
self.softmax = Softmax(axis=100)
self.pooling = MaxPool2d(kernel_size=2, padding=0)
self.submodule1 = Sequential(Dropout(drop_prob=0.1), Softmax(axis=100))
self.fc1 = Linear(512, 1024)
def forward(self, inputs):
pass
ground_truth = (
"ConvModel(\n"
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)\n"
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
" (conv3): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
" (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
" (dropout): Dropout(drop_prob=0.1)\n (softmax): Softmax(axis=100)\n"
" (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n"
" (submodule1): Sequential(\n"
" (0): Dropout(drop_prob=0.1)\n"
" (1): Softmax(axis=100)\n )\n"
" (fc1): Linear(in_features=512, out_features=1024, bias=True)\n"
")"
)
net = ConvModel()
output = repr(net)
assert output == ground_truth
def test_repr_module_reassign():
# test whether __repr__ handles module reassignment
class ConvModel1(Module):
def __init__(self):
super().__init__()
self.conv1 = Conv2d(3, 128, 3, bias=False)
self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False)
self.conv1 = Conv2d(3, 256, 3, dilation=2, bias=False)
def forward(self, inputs):
pass
ground_truth = (
"ConvModel1(\n"
" (conv1): Conv2d(3, 256, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
")"
)
net = ConvModel1()
output = repr(net)
assert output == ground_truth
def test_repr_module_rereference():
# test whether __repr__ handles re-referencing the same module
class ConvModel2(Module):
def __init__(self):
super().__init__()
self.conv1 = Conv2d(3, 128, 3, bias=False)
self.conv2 = self.conv1
self.conv3 = self.conv1
def forward(self, inputs):
pass
ground_truth = (
"ConvModel2(\n"
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
" (conv3): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
")"
)
net = ConvModel2()
output = repr(net)
assert output == ground_truth
def test_repr_module_delete():
# test whether __repr__ handles module deletion
class ConvModel3(Module):
def __init__(self):
super().__init__()
self.conv1 = Conv2d(3, 128, 3, bias=False)
self.softmax = Softmax(axis=100)
def forward(self, inputs):
pass
ground_truth = (
"ConvModel3(\n"
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
")"
)
net = ConvModel3()
del net.softmax
output = repr(net)
assert output == ground_truth