diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index 5d107fe242039565ebfa1b21940779d8dd8a26af..eaf2ce3d720420d357745bf0713848c347ce66f2 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -311,7 +311,7 @@ class ResNet(TheseusLayer): self.blocks = nn.Sequential(*block_list) self.avg_pool = AdaptiveAvgPool2D(1) - self.avg_pool_channels = self.num_channels[-1] * 2 + self.flatten = nn.Flatten() stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0) self.fc = Linear( @@ -324,7 +324,7 @@ class ResNet(TheseusLayer): x = self.max_pool(x) x = self.blocks(x) x = self.avg_pool(x) - x = paddle.reshape(x, shape=[-1, self.avg_pool_channels]) + x = self.flatten(x) x = self.fc(x) return x