提交 8f41bdc4 编写于 作者: B breezedeus

Revert "new arch for densenet: less pooling for height axis"

This reverts commit 4e57dc56.
上级 4e57dc56
......@@ -140,15 +140,8 @@ class DenseNet(HybridBlock):
# self.output = nn.Dense(classes)
def hybrid_forward(self, F, x):
"""
:param F:
:param x: with shape (batch_size, channels, img_height, img_width)
:return: with shape (batch_size, embed_size, 1, img_width // 4)
"""
x = self.features(x) # res: (batch_size, embed_size, 2, img_width // 4)
x = F.reshape(x, (0, -3, 0)) # res: (batch_size, embed_size * 2, img_width // 4)
x = F.expand_dims(x, axis=2) # res: (batch_size, embed_size * 2, 1, img_width // 4)
x = self.features(x)
# x = self.output(x)
return x
......@@ -220,10 +213,7 @@ def _make_final_stage_net(stage_index, out_channels):
features.add(nn.Conv2D(out_channels // 4, kernel_size=1, use_bias=False))
features.add(nn.BatchNorm())
features.add(nn.Activation('relu'))
features.add(
nn.Conv2D(out_channels, kernel_size=(2, 1), strides=(2, 1), use_bias=False)
)
features.add(nn.Conv2D(out_channels, kernel_size=(4, 1), use_bias=False))
features.add(nn.BatchNorm())
features.add(nn.Activation('relu'))
# features.add(nn.MaxPool2D(pool_size=(2, 1), strides=(2, 1)))
return features
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册