提交 962fc4ee 编写于 作者: B breezedeus

Revert "Revert "new arch for densenet: less pooling for height axis""

This reverts commit 8f41bdc4.
上级 e988dbf7
......@@ -140,8 +140,15 @@ class DenseNet(HybridBlock):
# self.output = nn.Dense(classes)
def hybrid_forward(self, F, x):
x = self.features(x)
# x = self.output(x)
"""
:param F:
:param x: with shape (batch_size, channels, img_height, img_width)
:return: with shape (batch_size, embed_size, 1, img_width // 4)
"""
x = self.features(x) # res: (batch_size, embed_size, 2, img_width // 4)
x = F.reshape(x, (0, -3, 0)) # res: (batch_size, embed_size * 2, img_width // 4)
x = F.expand_dims(x, axis=2) # res: (batch_size, embed_size * 2, 1, img_width // 4)
return x
......@@ -213,7 +220,10 @@ def _make_final_stage_net(stage_index, out_channels):
features.add(nn.Conv2D(out_channels // 4, kernel_size=1, use_bias=False))
features.add(nn.BatchNorm())
features.add(nn.Activation('relu'))
features.add(nn.Conv2D(out_channels, kernel_size=(4, 1), use_bias=False))
features.add(
nn.Conv2D(out_channels, kernel_size=(2, 1), strides=(2, 1), use_bias=False)
)
features.add(nn.BatchNorm())
features.add(nn.Activation('relu'))
# features.add(nn.MaxPool2D(pool_size=(2, 1), strides=(2, 1)))
return features
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册