未验证 提交 1020f418 编写于 作者: W Wei Shengyu 提交者: GitHub

use flatten instead of reshape for hrnet (#1709)

上级 7595ba6d
...@@ -459,6 +459,7 @@ class HRNet(TheseusLayer): ...@@ -459,6 +459,7 @@ class HRNet(TheseusLayer):
self.avg_pool = nn.AdaptiveAvgPool2D(1) self.avg_pool = nn.AdaptiveAvgPool2D(1)
stdv = 1.0 / math.sqrt(2048 * 1.0) stdv = 1.0 / math.sqrt(2048 * 1.0)
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
self.fc = nn.Linear( self.fc = nn.Linear(
2048, 2048,
...@@ -496,7 +497,7 @@ class HRNet(TheseusLayer): ...@@ -496,7 +497,7 @@ class HRNet(TheseusLayer):
y = self.conv_last(y) y = self.conv_last(y)
y = self.avg_pool(y) y = self.avg_pool(y)
y = paddle.reshape(y, shape=[-1, y.shape[1]]) y = self.flatten(y)
y = self.fc(y) y = self.fc(y)
return y return y
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册