提交 447ebeaf 编写于 作者: S ShusenTang

update 5.8

上级 636b1966
...@@ -10,6 +10,8 @@ import zipfile ...@@ -10,6 +10,8 @@ import zipfile
from IPython import display from IPython import display
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import torch import torch
from torch import nn
import torch.nn.functional as F
import torchvision import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
# import mxnet as mx # import mxnet as mx
...@@ -258,3 +260,13 @@ def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNI ...@@ -258,3 +260,13 @@ def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNI
############################# 5.8 ##############################
class GlobalAvgPool2d(nn.Module):
# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
def __init__(self):
super(GlobalAvgPool2d, self).__init__()
def forward(self, x):
return F.avg_pool2d(x, kernel_size=x.size()[2:])
...@@ -36,6 +36,14 @@ NiN是在AlexNet问世不久后提出的。它们的卷积层设定有类似之 ...@@ -36,6 +36,14 @@ NiN是在AlexNet问世不久后提出的。它们的卷积层设定有类似之
除使用NiN块以外,NiN还有一个设计与AlexNet显著不同:NiN去掉了AlexNet最后的3个全连接层,取而代之地,NiN使用了输出通道数等于标签类别数的NiN块,然后使用全局平均池化层对每个通道中所有元素求平均并直接用于分类。这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层。NiN的这个设计的好处是可以显著减小模型参数尺寸,从而缓解过拟合。然而,该设计有时会造成获得有效模型的训练时间的增加。 除使用NiN块以外,NiN还有一个设计与AlexNet显著不同:NiN去掉了AlexNet最后的3个全连接层,取而代之地,NiN使用了输出通道数等于标签类别数的NiN块,然后使用全局平均池化层对每个通道中所有元素求平均并直接用于分类。这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层。NiN的这个设计的好处是可以显著减小模型参数尺寸,从而缓解过拟合。然而,该设计有时会造成获得有效模型的训练时间的增加。
``` python ``` python
# 已保存在d2lzh_pytorch
class GlobalAvgPool2d(nn.Module):
# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
def __init__(self):
super(GlobalAvgPool2d, self).__init__()
def forward(self, x):
return F.avg_pool2d(x, kernel_size=x.size()[2:])
net = nn.Sequential( net = nn.Sequential(
nin_block(1, 96, kernel_size=11, stride=4, padding=0), nin_block(1, 96, kernel_size=11, stride=4, padding=0),
nn.MaxPool2d(kernel_size=3, stride=2), nn.MaxPool2d(kernel_size=3, stride=2),
...@@ -46,8 +54,7 @@ net = nn.Sequential( ...@@ -46,8 +54,7 @@ net = nn.Sequential(
nn.Dropout(0.5), nn.Dropout(0.5),
# 标签类别数是10 # 标签类别数是10
nin_block(384, 10, kernel_size=3, stride=1, padding=1), nin_block(384, 10, kernel_size=3, stride=1, padding=1),
# 全局平均池化层可通过将窗口形状设置成输入的高和宽实现 GlobalAvgPool2d(),
nn.AvgPool2d(kernel_size=5),
# 将四维的输出转成二维的输出,其形状为(批量大小, 10) # 将四维的输出转成二维的输出,其形状为(批量大小, 10)
d2l.FlattenLayer()) d2l.FlattenLayer())
``` ```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册