diff --git a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py index 5ef34e2dbee0471e71116d558a9040ea041202f1..750b488ff941e8b2e2624a7e704345f9ca690920 100644 --- a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py +++ b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py @@ -256,10 +256,10 @@ class PPLCNetV2(TheseusLayer): depths=[2, 2, 6, 2], class_num=1000, dropout_prob=0.2, - class_expand=1280): + expand=1280): super().__init__() self.scale = scale - self.class_expand = class_expand + self.expand = expand self.stem = nn.Sequential(* [ ConvBNLayer( @@ -339,19 +339,22 @@ class PPLCNetV2(TheseusLayer): self.avg_pool = AdaptiveAvgPool2D(1) - self.last_conv = Conv2D( - in_channels=make_divisible(NET_CONFIG["stage4"][0] * 2 * scale), - out_channels=self.class_expand, - kernel_size=1, - stride=1, - padding=0, - bias_attr=False) + if self.expand: + self.last_conv = Conv2D( + in_channels=make_divisible(NET_CONFIG["stage4"][0] * 2 * + scale), + out_channels=self.expand, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + self.act = nn.ReLU() + self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer") - self.act = nn.ReLU() - self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer") self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) - - self.fc = Linear(self.class_expand, class_num) + in_features = self.expand if self.expand else NET_CONFIG["stage4"][ + 0] * 2 * scale + self.fc = Linear(in_features, class_num) def forward(self, x): x = self.stem(x) @@ -360,9 +363,10 @@ class PPLCNetV2(TheseusLayer): x = self.stage3(x) x = self.stage4(x) x = self.avg_pool(x) - x = self.last_conv(x) - x = self.act(x) - x = self.dropout(x) + if self.expand: + x = self.last_conv(x) + x = self.act(x) + x = self.dropout(x) x = self.flatten(x) x = self.fc(x) return x @@ -393,4 +397,4 @@ def PPLCNetV2_base(pretrained=False, use_ssld=False, **kwargs): """ model = PPLCNetV2(scale=1.0, depths=[2, 2, 6, 2], **kwargs) _load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld) - return model \ No newline at end of file + return model