Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
82692159
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
82692159
编写于
7月 20, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
8月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/module): fix named_children of Sequential
GitOrigin-RevId: d3220fb361f018042f5c5a8d085e037397e7ecef
上级
eed54081
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
36 addition
and
14 deletion
+36
-14
python_module/megengine/module/sequential.py
python_module/megengine/module/sequential.py
+15
-13
python_module/test/unit/module/test_module.py
python_module/test/unit/module/test_module.py
+21
-1
未找到文件。
python_module/megengine/module/sequential.py
浏览文件 @
82692159
...
...
@@ -19,7 +19,7 @@ class Sequential(Module):
To make it easier to understand, here is a small example:
.. testcode::
from collections import OrderedDict
import numpy as np
import megengine.nn as nn
import megengine.nn.functional as F
...
...
@@ -29,34 +29,35 @@ class Sequential(Module):
label = nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,))
data = data.reshape(batch_size, -1)
net = nn.Sequential(
net0 = nn.Sequential(
nn.Linear(28 * 28, 320),
nn.Linear(320, 500),
nn.Linear(500, 320),
nn.Linear(320, 10)
)
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label
)
pred0 = net0(data
)
modules = OrderedDict()
modules["fc0"] = nn.Linear(28 * 28, 320)
modules["fc1"] = nn.Linear(320, 10)
net1 = nn.Sequential(modules)
pred1 = net1(data)
"""
def
__init__
(
self
,
*
args
):
super
().
__init__
()
self
.
layer_keys
=
[]
self
.
layer_values
=
[]
if
len
(
args
)
==
1
and
isinstance
(
args
[
0
],
OrderedDict
):
for
key
,
module
in
args
[
0
].
items
():
# self.add_module(key, module)
setattr
(
self
,
key
,
module
)
self
.
layer_keys
.
append
(
key
)
self
.
layer_values
.
append
(
module
)
else
:
for
idx
,
module
in
enumerate
(
args
):
# self.add_module(str(idx), module)
setattr
(
self
,
str
(
idx
),
module
)
self
.
layer_keys
.
append
(
str
(
idx
))
self
.
layer_values
.
append
(
module
)
def
__getitem__
(
self
,
idx
):
if
isinstance
(
idx
,
slice
):
...
...
@@ -64,11 +65,10 @@ class Sequential(Module):
OrderedDict
(
zip
(
self
.
layer_keys
[
idx
],
self
.
layer_values
[
idx
]))
)
else
:
return
self
.
layer_values
[
idx
]
return
getattr
(
self
,
self
.
layer_keys
[
idx
])
def
__setitem__
(
self
,
idx
,
module
):
key
=
self
.
layer_keys
[
idx
]
self
.
layer_values
[
idx
]
=
module
return
setattr
(
self
,
key
,
module
)
def
__delitem__
(
self
,
idx
):
...
...
@@ -76,11 +76,9 @@ class Sequential(Module):
for
key
in
self
.
layer_keys
[
idx
]:
delattr
(
self
,
key
)
del
self
.
layer_keys
[
idx
]
del
self
.
layer_values
[
idx
]
else
:
delattr
(
self
,
self
.
layer_keys
[
idx
])
del
self
.
layer_keys
[
idx
]
del
self
.
layer_values
[
idx
]
def
__len__
(
self
):
return
len
(
self
.
layer_keys
)
...
...
@@ -88,6 +86,10 @@ class Sequential(Module):
def
__iter__
(
self
):
return
iter
(
self
.
layer_values
)
@
property
def
layer_values
(
self
):
return
[
getattr
(
self
,
key
)
for
key
in
self
.
layer_keys
]
def
forward
(
self
,
inp
):
for
layer
in
self
.
layer_values
:
inp
=
layer
(
inp
)
...
...
python_module/test/unit/module/test_module.py
浏览文件 @
82692159
...
...
@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
tempfile
from
collections
import
OrderedDict
from
io
import
BytesIO
import
numpy
as
np
...
...
@@ -16,7 +17,14 @@ from helpers import MLP
import
megengine
as
mge
import
megengine._internal
as
mgb
from
megengine.core
import
Buffer
,
Parameter
,
Tensor
,
tensor
from
megengine.module
import
BatchNorm1d
,
BatchNorm2d
,
Conv2d
,
Module
,
Sequential
from
megengine.module
import
(
BatchNorm1d
,
BatchNorm2d
,
Conv2d
,
Linear
,
Module
,
Sequential
,
)
from
megengine.quantization.quantize
import
quantize
,
quantize_qat
from
megengine.test
import
assertTensorClose
...
...
@@ -238,6 +246,18 @@ def test_module_api_with_sequential():
]
def
test_sequential_named_children
():
modules
=
OrderedDict
()
modules
[
"name0"
]
=
Linear
(
20
,
10
)
modules
[
"name1"
]
=
Linear
(
10
,
5
)
modules
[
"name2"
]
=
Linear
(
5
,
1
)
m
=
Sequential
(
modules
)
l
=
list
(
m
.
named_children
())
assert
l
[
0
][
0
]
==
"name0"
assert
l
[
1
][
0
]
==
"name1"
assert
l
[
2
][
0
]
==
"name2"
def
test_state_dict
():
data_shape
=
(
2
,
28
)
data
=
tensor
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录