Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
a26546593
dive-into-dl-pytorch
提交
9f8cd357
D
dive-into-dl-pytorch
项目概览
a26546593
/
dive-into-dl-pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
dive-into-dl-pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9f8cd357
编写于
11月 10, 2019
作者:
S
ShusenTang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add more info about ModuleList
上级
aa75893a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
152 addition
and
20 deletion
+152
-20
code/chapter04_DL_computation/4.1_model-construction.ipynb
code/chapter04_DL_computation/4.1_model-construction.ipynb
+99
-20
docs/chapter04_DL_computation/4.1_model-construction.md
docs/chapter04_DL_computation/4.1_model-construction.md
+53
-0
未找到文件。
code/chapter04_DL_computation/4.1_model-construction.ipynb
浏览文件 @
9f8cd357
...
...
@@ -16,7 +16,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"
0.4.1
\n"
"
1.2.0
\n"
]
}
],
...
...
@@ -78,10 +78,10 @@
{
"data": {
"text/plain": [
"tensor([[ 0.
1351, -0.0034, 0.0948, -0.1652, 0.1512, 0.0887, -0.0032, 0.0692
,\n",
" 0.
0942, 0.0956
],\n",
" [ 0.1
624, -0.0383, 0.1557, -0.0735, 0.1931, 0.1699, -0.0067, 0.0353
,\n",
" 0.1
712, 0.1568]], grad_fn=<Th
AddmmBackward>)"
"tensor([[ 0.
0234, -0.2646, -0.1168, -0.2127, 0.0884, -0.0456, 0.0811, 0.0297
,\n",
" 0.
2032, 0.1364
],\n",
" [ 0.1
479, -0.1545, -0.0265, -0.2119, -0.0543, -0.0086, 0.0902, -0.1017
,\n",
" 0.1
504, 0.1144]], grad_fn=<
AddmmBackward>)"
]
},
"execution_count": 3,
...
...
@@ -107,7 +107,9 @@
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class MySequential(nn.Module):\n",
...
...
@@ -146,10 +148,10 @@
{
"data": {
"text/plain": [
"tensor([[ 0.1
883, -0.1269, -0.1886, 0.0638, -0.1004, -0.0600, 0.0760, -0.17
88,\n",
"
-0.1844, -0.2131
],\n",
" [ 0.
1319, -0.0490, -0.1365, 0.0133, -0.0483, -0.0861, 0.0369, -0.0830
,\n",
"
-0.0462, -0.2066]], grad_fn=<Th
AddmmBackward>)"
"tensor([[ 0.1
273, 0.1642, -0.1060, 0.1401, 0.0609, -0.0199, -0.0140, -0.05
88,\n",
"
0.1765, -0.1296
],\n",
" [ 0.
0267, 0.1670, -0.0626, 0.0744, 0.0574, 0.0413, 0.1313, -0.1479
,\n",
"
0.0932, -0.0615]], grad_fn=<
AddmmBackward>)"
]
},
"execution_count": 5,
...
...
@@ -199,6 +201,74 @@
"print(net)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# net(torch.zeros(1, 784)) # 会报NotImplementedError"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class MyModule(nn.Module):\n",
" def __init__(self):\n",
" super(MyModule, self).__init__()\n",
" self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])\n",
"\n",
" def forward(self, x):\n",
" # ModuleList can act as an iterable, or be indexed using ints\n",
" for i, l in enumerate(self.linears):\n",
" x = self.linears[i // 2](x) + l(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"net1:\n",
"torch.Size([10, 10])\n",
"torch.Size([10])\n",
"net2:\n"
]
}
],
"source": [
"class Module_ModuleList(nn.Module):\n",
" def __init__(self):\n",
" super(Module_ModuleList, self).__init__()\n",
" self.linears = nn.ModuleList([nn.Linear(10, 10)])\n",
" \n",
"class Module_List(nn.Module):\n",
" def __init__(self):\n",
" super(Module_List, self).__init__()\n",
" self.linears = [nn.Linear(10, 10)]\n",
"\n",
"net1 = Module_ModuleList()\n",
"net2 = Module_List()\n",
"\n",
"print(\"net1:\")\n",
"for p in net1.parameters():\n",
" print(p.size())\n",
"\n",
"print(\"net2:\")\n",
"for p in net2.parameters():\n",
" print(p)"
]
},
{
"cell_type": "markdown",
"metadata": {},
...
...
@@ -208,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count":
7
,
"execution_count":
10
,
"metadata": {},
"outputs": [
{
...
...
@@ -236,6 +306,15 @@
"print(net)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# net(torch.zeros(1, 784)) # 会报NotImplementedError"
]
},
{
"cell_type": "markdown",
"metadata": {},
...
...
@@ -245,7 +324,7 @@
},
{
"cell_type": "code",
"execution_count":
8
,
"execution_count":
12
,
"metadata": {
"collapsed": true
},
...
...
@@ -275,7 +354,7 @@
},
{
"cell_type": "code",
"execution_count":
9
,
"execution_count":
13
,
"metadata": {},
"outputs": [
{
...
...
@@ -290,10 +369,10 @@
{
"data": {
"text/plain": [
"tensor(
12.1594
, grad_fn=<SumBackward0>)"
"tensor(
0.8907
, grad_fn=<SumBackward0>)"
]
},
"execution_count":
9
,
"execution_count":
13
,
"metadata": {},
"output_type": "execute_result"
}
...
...
@@ -307,7 +386,7 @@
},
{
"cell_type": "code",
"execution_count": 1
0
,
"execution_count": 1
4
,
"metadata": {},
"outputs": [
{
...
...
@@ -331,10 +410,10 @@
{
"data": {
"text/plain": [
"tensor(
0.1509
, grad_fn=<SumBackward0>)"
"tensor(
-0.4605
, grad_fn=<SumBackward0>)"
]
},
"execution_count": 1
0
,
"execution_count": 1
4
,
"metadata": {},
"output_type": "execute_result"
}
...
...
@@ -367,7 +446,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python
[default]
",
"display_name": "Python
3
",
"language": "python",
"name": "python3"
},
...
...
@@ -381,7 +460,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.
3
"
"version": "3.6.
2
"
}
},
"nbformat": 4,
...
...
docs/chapter04_DL_computation/4.1_model-construction.md
浏览文件 @
9f8cd357
...
...
@@ -114,6 +114,7 @@ net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net
.
append
(
nn
.
Linear
(
256
,
10
))
# # 类似List的append操作
print
(
net
[
-
1
])
# 类似List的索引访问
print
(
net
)
# net(torch.zeros(1, 784)) # 会报NotImplementedError
```
输出:
```
...
...
@@ -125,6 +126,55 @@ ModuleList(
)
```
既然
`Sequential`
和
`ModuleList`
都可以进行列表化构造网络,那二者区别是什么呢。
`ModuleList`
仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度匹配),而且没有实现
`forward`
功能需要自己实现,所以上面执行
`net(torch.zeros(1, 784))`
会报
`NotImplementedError`
;而
`Sequential`
内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部
`forward`
功能已经实现。
`ModuleList`
的出现只是让网络定义前向传播时更加灵活,见下面官网的例子。
```
python
class
MyModule
(
nn
.
Module
):
def
__init__
(
self
):
super
(
MyModule
,
self
).
__init__
()
self
.
linears
=
nn
.
ModuleList
([
nn
.
Linear
(
10
,
10
)
for
i
in
range
(
10
)])
def
forward
(
self
,
x
):
# ModuleList can act as an iterable, or be indexed using ints
for
i
,
l
in
enumerate
(
self
.
linears
):
x
=
self
.
linears
[
i
//
2
](
x
)
+
l
(
x
)
return
x
```
另外,
`ModuleList`
不同于一般的Python的
`list`
,加入到
`ModuleList`
里面的所有模块的参数会被自动添加到整个网络中,下面看一个例子对比一下。
```
python
class
Module_ModuleList
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Module_ModuleList
,
self
).
__init__
()
self
.
linears
=
nn
.
ModuleList
([
nn
.
Linear
(
10
,
10
)])
class
Module_List
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Module_List
,
self
).
__init__
()
self
.
linears
=
[
nn
.
Linear
(
10
,
10
)]
net1
=
Module_ModuleList
()
net2
=
Module_List
()
print
(
"net1:"
)
for
p
in
net1
.
parameters
():
print
(
p
.
size
())
print
(
"net2:"
)
for
p
in
net2
.
parameters
():
print
(
p
)
```
输出:
```
net1:
torch.Size([10, 10])
torch.Size([10])
net2:
```
### 4.1.2.3 `ModuleDict`类
`ModuleDict`
接收一个子模块的字典作为输入, 然后也可以类似字典那样进行添加访问操作:
```
python
...
...
@@ -136,6 +186,7 @@ net['output'] = nn.Linear(256, 10) # 添加
print
(
net
[
'linear'
])
# 访问
print
(
net
.
output
)
print
(
net
)
# net(torch.zeros(1, 784)) # 会报NotImplementedError
```
输出:
```
...
...
@@ -148,6 +199,7 @@ ModuleDict(
)
```
和
`ModuleList`
一样,
`ModuleDict`
实例仅仅是存放了一些模块的字典,并没有定义
`forward`
函数需要自己定义。同样,
`ModuleDict`
也与Python的
`Dict`
有所不同,
`ModuleDict`
里的所有模块的参数会被自动添加到整个网络中。
## 4.1.3 构造复杂的模型
...
...
@@ -230,6 +282,7 @@ tensor(14.4908, grad_fn=<SumBackward0>)
*
可以通过继承
`Module`
类来构造模型。
*
`Sequential`
、
`ModuleList`
、
`ModuleDict`
类都继承自
`Module`
类。
*
与
`Sequential`
不同,
`ModuleList`
和
`ModuleDict`
并没有定义一个完整的网络,它们只是将不同的模块存放在一起,需要自己定义
`forward`
函数。
*
虽然
`Sequential`
等类可以使模型构造更加简单,但直接继承
`Module`
类可以极大地拓展模型构造的灵活性。
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录