提交 aab87e89 编写于 作者: S ShusenTang

fix bug on train_pytorch_ch7

上级 23ffe9e8
......@@ -145,7 +145,7 @@ loss: 0.242805, 0.078792 sec per epoch
在PyTorch里可以通过创建`optimizer`实例来调用优化算法。这能让实现更简洁。下面实现一个通用的训练函数,它通过优化算法的函数`optimizer_fn`和超参数`optimizer_hyperparams`来创建`optimizer`实例。
``` python
# 本函数已保存在d2lzh_pytorch包中方便以后使用 (与原书不同的是这里第一个参数优化器函数而不是优化器的名字)
# 本函数与原书不同的是这里第一个参数优化器函数而不是优化器的名字
# 例如: optimizer_fn=torch.optim.SGD, optimizer_hyperparams={"lr": 0.05}
def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,
batch_size=10, num_epochs=2):
......@@ -157,7 +157,7 @@ def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,
optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)
def eval_loss():
return loss(net(features), labels).mean().item()
return loss(net(features).view(-1), labels).item() / 2
ls = [eval_loss()]
data_iter = torch.utils.data.DataLoader(
......@@ -166,7 +166,8 @@ def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,
for _ in range(num_epochs):
start = time.time()
for batch_i, (X, y) in enumerate(data_iter):
l = loss(net(X), y)
# 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2
l = loss(net(X).view(-1), y) / 2
optimizer.zero_grad()
l.backward()
......@@ -186,7 +187,7 @@ train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, 10)
```
输出:
```
loss: 0.242819, 0.192941 sec per epoch
loss: 0.245491, 0.044150 sec per epoch
```
<div align=center>
<img width="300" src="../../img/chapter07/7.3_output4.png"/>
......
......@@ -206,7 +206,7 @@ d2l.train_pytorch_ch7(torch.optim.SGD, {'lr': 0.004, 'momentum': 0.9},
```
输出:
```
loss: 0.994468, 0.083361 sec per epoch
loss: 0.253280, 0.060247 sec per epoch
```
<div align=center>
<img width="300" src="../../img/chapter07/7.4_output8.png"/>
......
......@@ -116,7 +116,7 @@ d2l.train_pytorch_ch7(torch.optim.Adagrad, {'lr': 0.1}, features, labels)
```
输出:
```
loss: 0.989744, 0.089780 sec per epoch
loss: 0.243147, 0.040675 sec per epoch
```
<div align=center>
......
......@@ -92,7 +92,7 @@ d2l.train_pytorch_ch7(torch.optim.RMSprop, {'lr': 0.01, 'alpha': 0.9},
输出:
```
loss: 0.995220, 0.083705 sec per epoch
loss: 0.243676, 0.043637 sec per epoch
```
<div align=center>
<img width="300" src="../../img/chapter07/7.6_output3.png"/>
......
......@@ -73,7 +73,7 @@ d2l.train_pytorch_ch7(torch.optim.Adadelta, {'rho': 0.9}, features, labels)
输出:
```
loss: 0.996619, 0.105015 sec per epoch
loss: 0.242104, 0.047702 sec per epoch
```
<div align=center>
<img width="300" src="../../img/chapter07/7.7_output2.png"/>
......
img/chapter07/7.3_output4.png

29.4 KB | W: | H:

img/chapter07/7.3_output4.png

31.4 KB | W: | H:

img/chapter07/7.3_output4.png
img/chapter07/7.3_output4.png
img/chapter07/7.3_output4.png
img/chapter07/7.3_output4.png
  • 2-up
  • Swipe
  • Onion skin
img/chapter07/7.4_output8.png

28.7 KB | W: | H:

img/chapter07/7.4_output8.png

27.9 KB | W: | H:

img/chapter07/7.4_output8.png
img/chapter07/7.4_output8.png
img/chapter07/7.4_output8.png
img/chapter07/7.4_output8.png
  • 2-up
  • Swipe
  • Onion skin
img/chapter07/7.5_output4.png

26.7 KB | W: | H:

img/chapter07/7.5_output4.png

27.4 KB | W: | H:

img/chapter07/7.5_output4.png
img/chapter07/7.5_output4.png
img/chapter07/7.5_output4.png
img/chapter07/7.5_output4.png
  • 2-up
  • Swipe
  • Onion skin
img/chapter07/7.6_output3.png

28.0 KB | W: | H:

img/chapter07/7.6_output3.png

28.1 KB | W: | H:

img/chapter07/7.6_output3.png
img/chapter07/7.6_output3.png
img/chapter07/7.6_output3.png
img/chapter07/7.6_output3.png
  • 2-up
  • Swipe
  • Onion skin
img/chapter07/7.7_output2.png

30.6 KB | W: | H:

img/chapter07/7.7_output2.png

31.4 KB | W: | H:

img/chapter07/7.7_output2.png
img/chapter07/7.7_output2.png
img/chapter07/7.7_output2.png
img/chapter07/7.7_output2.png
  • 2-up
  • Swipe
  • Onion skin
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册