提交 c749bdf9 编写于 作者: W weishengyu

replace reshape with Flatten layer in ResNet

上级 1a74e9cb
...@@ -311,7 +311,7 @@ class ResNet(TheseusLayer): ...@@ -311,7 +311,7 @@ class ResNet(TheseusLayer):
self.blocks = nn.Sequential(*block_list) self.blocks = nn.Sequential(*block_list)
self.avg_pool = AdaptiveAvgPool2D(1) self.avg_pool = AdaptiveAvgPool2D(1)
self.avg_pool_channels = self.num_channels[-1] * 2 self.flatten = nn.Flatten()
stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0) stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
self.fc = Linear( self.fc = Linear(
...@@ -324,7 +324,7 @@ class ResNet(TheseusLayer): ...@@ -324,7 +324,7 @@ class ResNet(TheseusLayer):
x = self.max_pool(x) x = self.max_pool(x)
x = self.blocks(x) x = self.blocks(x)
x = self.avg_pool(x) x = self.avg_pool(x)
x = paddle.reshape(x, shape=[-1, self.avg_pool_channels]) x = self.flatten(x)
x = self.fc(x) x = self.fc(x)
return x return x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册