未验证 提交 0ef1ac3f 编写于 作者: G gaotingquan

fix

上级 6e0900ec
...@@ -256,10 +256,10 @@ class PPLCNetV2(TheseusLayer): ...@@ -256,10 +256,10 @@ class PPLCNetV2(TheseusLayer):
depths=[2, 2, 6, 2], depths=[2, 2, 6, 2],
class_num=1000, class_num=1000,
dropout_prob=0.2, dropout_prob=0.2,
class_expand=1280): expand=1280):
super().__init__() super().__init__()
self.scale = scale self.scale = scale
self.class_expand = class_expand self.expand = expand
self.stem = nn.Sequential(* [ self.stem = nn.Sequential(* [
ConvBNLayer( ConvBNLayer(
...@@ -339,19 +339,22 @@ class PPLCNetV2(TheseusLayer): ...@@ -339,19 +339,22 @@ class PPLCNetV2(TheseusLayer):
self.avg_pool = AdaptiveAvgPool2D(1) self.avg_pool = AdaptiveAvgPool2D(1)
self.last_conv = Conv2D( if self.expand:
in_channels=make_divisible(NET_CONFIG["stage4"][0] * 2 * scale), self.last_conv = Conv2D(
out_channels=self.class_expand, in_channels=make_divisible(NET_CONFIG["stage4"][0] * 2 *
kernel_size=1, scale),
stride=1, out_channels=self.expand,
padding=0, kernel_size=1,
bias_attr=False) stride=1,
padding=0,
bias_attr=False)
self.act = nn.ReLU()
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
self.act = nn.ReLU()
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
in_features = self.expand if self.expand else NET_CONFIG["stage4"][
self.fc = Linear(self.class_expand, class_num) 0] * 2 * scale
self.fc = Linear(in_features, class_num)
def forward(self, x): def forward(self, x):
x = self.stem(x) x = self.stem(x)
...@@ -360,9 +363,10 @@ class PPLCNetV2(TheseusLayer): ...@@ -360,9 +363,10 @@ class PPLCNetV2(TheseusLayer):
x = self.stage3(x) x = self.stage3(x)
x = self.stage4(x) x = self.stage4(x)
x = self.avg_pool(x) x = self.avg_pool(x)
x = self.last_conv(x) if self.expand:
x = self.act(x) x = self.last_conv(x)
x = self.dropout(x) x = self.act(x)
x = self.dropout(x)
x = self.flatten(x) x = self.flatten(x)
x = self.fc(x) x = self.fc(x)
return x return x
...@@ -393,4 +397,4 @@ def PPLCNetV2_base(pretrained=False, use_ssld=False, **kwargs): ...@@ -393,4 +397,4 @@ def PPLCNetV2_base(pretrained=False, use_ssld=False, **kwargs):
""" """
model = PPLCNetV2(scale=1.0, depths=[2, 2, 6, 2], **kwargs) model = PPLCNetV2(scale=1.0, depths=[2, 2, 6, 2], **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld) _load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld)
return model return model
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册