Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
27638461
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
27638461
编写于
3月 04, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/module): support list/dict/tuple in module __repr__
GitOrigin-RevId: b70193fd79576e37b75372fd084f82f038816f85
上级
07826f5e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
44 addition
and
23 deletion
+44
-23
imperative/python/megengine/module/module.py
imperative/python/megengine/module/module.py
+15
-7
imperative/python/test/unit/module/test_module.py
imperative/python/test/unit/module/test_module.py
+29
-16
未找到文件。
imperative/python/megengine/module/module.py
浏览文件 @
27638461
...
@@ -73,6 +73,7 @@ class Module(metaclass=ABCMeta):
...
@@ -73,6 +73,7 @@ class Module(metaclass=ABCMeta):
:param name: module's name, can be initialized by the ``kwargs`` parameter
:param name: module's name, can be initialized by the ``kwargs`` parameter
of child class.
of child class.
"""
"""
self
.
_modules
=
[]
if
name
is
not
None
:
if
name
is
not
None
:
assert
(
assert
(
...
@@ -89,8 +90,6 @@ class Module(metaclass=ABCMeta):
...
@@ -89,8 +90,6 @@ class Module(metaclass=ABCMeta):
self
.
_forward_pre_hooks
=
OrderedDict
()
self
.
_forward_pre_hooks
=
OrderedDict
()
self
.
_forward_hooks
=
OrderedDict
()
self
.
_forward_hooks
=
OrderedDict
()
self
.
_modules
=
[]
# used for profiler and automatic naming
# used for profiler and automatic naming
self
.
_name
=
"{anonymous}"
self
.
_name
=
"{anonymous}"
...
@@ -595,7 +594,9 @@ class Module(metaclass=ABCMeta):
...
@@ -595,7 +594,9 @@ class Module(metaclass=ABCMeta):
return
value
return
value
def
__setattr__
(
self
,
name
:
str
,
value
):
def
__setattr__
(
self
,
name
:
str
,
value
):
if
_is_module
(
value
):
if
_is_module
(
value
)
or
(
isinstance
(
value
,
(
list
,
tuple
,
dict
))
and
name
!=
"_modules"
):
modules
=
self
.
__dict__
.
get
(
"_modules"
)
modules
=
self
.
__dict__
.
get
(
"_modules"
)
if
modules
is
None
:
if
modules
is
None
:
raise
AttributeError
(
raise
AttributeError
(
...
@@ -633,10 +634,17 @@ class Module(metaclass=ABCMeta):
...
@@ -633,10 +634,17 @@ class Module(metaclass=ABCMeta):
extra_repr
=
self
.
_module_info_string
()
extra_repr
=
self
.
_module_info_string
()
if
extra_repr
:
if
extra_repr
:
extra_lines
=
extra_repr
.
split
(
"
\n
"
)
extra_lines
=
extra_repr
.
split
(
"
\n
"
)
child_lines
=
[
child_lines
=
[]
"("
+
name
+
"): "
+
add_indent
(
repr
(
self
.
__dict__
[
name
]),
2
)
for
name
in
self
.
_modules
:
for
name
in
self
.
_modules
if
_is_module
(
self
.
__dict__
[
name
]):
]
child_lines
.
append
(
"("
+
name
+
"): "
+
add_indent
(
repr
(
self
.
__dict__
[
name
]),
2
)
)
else
:
for
k
,
v
in
_expand_structure
(
name
,
self
.
__dict__
[
name
]):
if
_is_module
(
v
):
child_lines
.
append
(
"("
+
k
+
"): "
+
add_indent
(
repr
(
v
),
2
))
lines
=
extra_lines
+
child_lines
lines
=
extra_lines
+
child_lines
main_str
=
self
.
__class__
.
__name__
+
"("
main_str
=
self
.
__class__
.
__name__
+
"("
if
lines
:
if
lines
:
...
...
imperative/python/test/unit/module/test_module.py
浏览文件 @
27638461
...
@@ -656,15 +656,23 @@ def test_repr_basic():
...
@@ -656,15 +656,23 @@ def test_repr_basic():
class
ConvModel
(
Module
):
class
ConvModel
(
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
conv1
=
Conv2d
(
3
,
128
,
3
,
stride
=
2
,
bias
=
False
)
self
.
conv1
=
Conv2d
(
3
,
128
,
3
,
padding
=
1
,
bias
=
False
)
self
.
conv2
=
Conv2d
(
3
,
128
,
3
,
padding
=
1
,
bias
=
False
)
self
.
conv2
=
Conv2d
(
3
,
128
,
3
,
dilation
=
2
,
bias
=
False
)
self
.
conv3
=
Conv2d
(
3
,
128
,
3
,
dilation
=
2
,
bias
=
False
)
self
.
bn1
=
BatchNorm1d
(
128
)
self
.
bn1
=
BatchNorm2d
(
128
)
self
.
bn2
=
BatchNorm2d
(
128
)
self
.
bn2
=
BatchNorm1d
(
128
)
self
.
dropout
=
Dropout
(
drop_prob
=
0.1
)
self
.
softmax
=
Softmax
(
axis
=
100
)
self
.
pooling
=
MaxPool2d
(
kernel_size
=
2
,
padding
=
0
)
self
.
pooling
=
MaxPool2d
(
kernel_size
=
2
,
padding
=
0
)
self
.
submodule1
=
Sequential
(
Dropout
(
drop_prob
=
0.1
),
Softmax
(
axis
=
100
),)
modules
=
OrderedDict
()
modules
[
"depthwise"
]
=
Conv2d
(
256
,
256
,
3
,
1
,
0
,
groups
=
256
,
bias
=
False
,)
modules
[
"pointwise"
]
=
Conv2d
(
256
,
256
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
True
,
)
self
.
submodule1
=
Sequential
(
modules
)
self
.
list1
=
[
Dropout
(
drop_prob
=
0.1
),
[
Softmax
(
axis
=
100
)]]
self
.
tuple1
=
(
Dropout
(
drop_prob
=
0.1
),
(
Softmax
(
axis
=
100
),
Dropout
(
drop_prob
=
0.2
)),
)
self
.
dict1
=
{
"Dropout"
:
Dropout
(
drop_prob
=
0.1
)}
self
.
fc1
=
Linear
(
512
,
1024
)
self
.
fc1
=
Linear
(
512
,
1024
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
...
@@ -672,16 +680,21 @@ def test_repr_basic():
...
@@ -672,16 +680,21 @@ def test_repr_basic():
ground_truth
=
(
ground_truth
=
(
"ConvModel(
\n
"
"ConvModel(
\n
"
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
\n
"
" (conv1): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)
\n
"
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)
\n
"
" (conv2): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)
\n
"
" (conv3): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)
\n
"
" (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
\n
"
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
\n
"
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
\n
"
" (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
\n
"
" (dropout): Dropout(drop_prob=0.1)
\n
(softmax): Softmax(axis=100)
\n
"
" (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)
\n
"
" (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)
\n
"
" (submodule1): Sequential(
\n
"
" (submodule1): Sequential(
\n
"
" (0): Dropout(drop_prob=0.1)
\n
"
" (depthwise): Conv2d(256, 256, kernel_size=(3, 3), groups=256, bias=False)
\n
"
" (1): Softmax(axis=100)
\n
)
\n
"
" (pointwise): Conv2d(256, 256, kernel_size=(1, 1))
\n
"
" )
\n
"
" (list1.0): Dropout(drop_prob=0.1)
\n
"
" (list1.1.0): Softmax(axis=100)
\n
"
" (tuple1.0): Dropout(drop_prob=0.1)
\n
"
" (tuple1.1.0): Softmax(axis=100)
\n
"
" (tuple1.1.1): Dropout(drop_prob=0.2)
\n
"
" (dict1.Dropout): Dropout(drop_prob=0.1)
\n
"
" (fc1): Linear(in_features=512, out_features=1024, bias=True)
\n
"
" (fc1): Linear(in_features=512, out_features=1024, bias=True)
\n
"
")"
")"
)
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录