diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index 5d107fe242039565ebfa1b21940779d8dd8a26af..e2453f8dd53287831f6af5e1d1dc3b3f685b1cb6 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -311,8 +311,8 @@ class ResNet(TheseusLayer): self.blocks = nn.Sequential(*block_list) self.avg_pool = AdaptiveAvgPool2D(1) + self.flatten = nn.Flatten() self.avg_pool_channels = self.num_channels[-1] * 2 - stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0) self.fc = Linear( self.avg_pool_channels, @@ -324,7 +324,7 @@ class ResNet(TheseusLayer): x = self.max_pool(x) x = self.blocks(x) x = self.avg_pool(x) - x = paddle.reshape(x, shape=[-1, self.avg_pool_channels]) + x = self.flatten(x) x = self.fc(x) return x