提交 9af9f0b3 编写于 作者: A Aston Zhang

update resnet in cifar

上级 ab28f55a
......@@ -184,30 +184,29 @@ test_data = gdata.DataLoader(test_ds.transform_first(transform_test),
## 定义模型
我们这里使用了ResNet-18模型,并使用混合式编程来提升执行效率。
我们在这里定义ResNet-18模型,并使用混合式编程来提升执行效率。
```{.python .input n=6}
class Residual(nn.HybridBlock):
def __init__(self, channels, same_shape=True, **kwargs):
def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
super(Residual, self).__init__(**kwargs)
self.same_shape = same_shape
with self.name_scope():
strides = 1 if same_shape else 2
self.conv1 = nn.Conv2D(channels, kernel_size=3, padding=1,
self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
strides=strides)
self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
if use_1x1conv:
self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
strides=strides)
self.bn1 = nn.BatchNorm()
self.conv2 = nn.Conv2D(channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm()
if not same_shape:
self.conv3 = nn.Conv2D(channels, kernel_size=1,
strides=strides)
else:
self.conv3 = None
self.bn1 = nn.BatchNorm()
self.bn2 = nn.BatchNorm()
def hybrid_forward(self, F, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
if not self.same_shape:
x = self.conv3(x)
return F.relu(out + x)
def hybrid_forward(self, F, X):
Y = F.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3:
X = self.conv3(X)
return F.relu(Y + X)
class ResNet(nn.HybridBlock):
......@@ -223,15 +222,15 @@ class ResNet(nn.HybridBlock):
net.add(nn.Activation(activation='relu'))
# 模块 2。
for _ in range(3):
net.add(Residual(channels=32))
net.add(Residual(num_channels=32))
# 模块 3。
net.add(Residual(channels=64, same_shape=False))
net.add(Residual(num_channels=64, use_1x1conv=True, strides=2))
for _ in range(2):
net.add(Residual(channels=64))
net.add(Residual(num_channels=64))
# 模块 4。
net.add(Residual(channels=128, same_shape=False))
net.add(Residual(num_channels=128, use_1x1conv=True, strides=2))
for _ in range(2):
net.add(Residual(channels=128))
net.add(Residual(num_channels=128))
# 模块 5。
net.add(nn.AvgPool2D(pool_size=8))
net.add(nn.Flatten())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册