Commit 70641360 authored by: W weishengyu

dbg ghostnet

Parent 7c9e695f
@@ -27,6 +27,7 @@ from .hrnet import HRNet_W18_C
from .efficientnet import EfficientNetB0
from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50
from .googlenet import GoogLeNet
from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1
from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
@@ -20,7 +20,6 @@ import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, AdaptiveAvgPool2d, Linear
from paddle.fluid.regularizer import L2DecayRegularizer
from paddle.nn.initializer import Uniform
from paddle import fluid
class ConvBNLayer(nn.Layer):
@@ -42,9 +41,12 @@ class ConvBNLayer(nn.Layer):
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
weight_attr=ParamAttr(initializer=nn.initializer.MSRA(), name=name + "_weights"),
bias_attr=False
)
bn_name = name + "_bn"
# In the old version, moving_variance_name was name + "_variance"
self._batch_norm = BatchNorm(
num_filters,
act=act,
@@ -104,7 +106,7 @@ class SEBlock(nn.Layer):
squeeze = self.squeeze(pool)
squeeze = F.relu(squeeze)
excitation = self.excitation(squeeze)
excitation = F.sigmoid(excitation)
excitation = paddle.fluid.layers.clip(x=excitation, min=0, max=1)
excitation = paddle.reshape(
excitation,
shape=[-1, self._num_channels, 1, 1]
@@ -138,7 +140,7 @@ class GhostModule(nn.Layer):
name=name + "_primary_conv"
)
self.cheap_operation = ConvBNLayer(
num_channels=num_channels,
num_channels=init_channels,
num_filters=new_channels,
filter_size=dw_size,
stride=1,
@@ -186,7 +188,7 @@ class GhostBottleneck(nn.Layer):
stride=stride,
groups=hidden_dim,
act=None,
name=name+"_depthwise"
name=name+"_depthwise" # In the old version, name was name + "_depthwise_depthwise"
)
if use_se:
self.se_block = SEBlock(
@@ -194,7 +196,7 @@
name=name + "_se"
)
self.ghost_module_2 = GhostModule(
num_channels=num_channels,
num_channels=hidden_dim,
output_channels=output_channels,
kernel_size=1,
relu=False,
@@ -208,7 +210,7 @@
stride=stride,
groups=num_channels,
act=None,
name=name + "_shotcut_depthwise"
name=name + "_shortcut_depthwise" # In the old version, name was name + "_shortcut_depthwise_depthwise"
)
self.shortcut_conv = ConvBNLayer(
num_channels=num_channels,
@@ -217,11 +219,11 @@
stride=1,
groups=1,
act=None,
name=name + "_shotcut_conv"
name=name + "_shortcut_conv"
)
def forward(self, inputs):
x = self.ghost_module(inputs)
x = self.ghost_module_1(inputs)
if self._stride == 2:
x = self.depthwise_conv(x)
if self._use_se:
@@ -275,14 +277,17 @@ class GhostNet(nn.Layer):
num_channels = output_channels
output_channels = int(self._make_divisible(c * self.scale, 4))
hidden_dim = int(self._make_divisible(exp_size, self.scale, 4))
ghost_bottleneck = GhostBottleneck(
num_channels=num_channels,
hidden_dim=hidden_dim,
output_channels=output_channels,
kernel_size=k,
stride=s,
use_se=use_se,
name="_ghostbottleneck" + str(idx)
ghost_bottleneck = self.add_sublayer(
name="_ghostbottleneck_" + str(idx),
sublayer=GhostBottleneck(
num_channels=num_channels,
hidden_dim=hidden_dim,
output_channels=output_channels,
kernel_size=k,
stride=s,
use_se=use_se,
name="_ghostbottleneck_" + str(idx)
)
)
self.ghost_bottleneck_list.append(ghost_bottleneck)
idx += 1
@@ -300,24 +305,26 @@
)
self.pool2d_gap = AdaptiveAvgPool2d(1)
num_channels = output_channels
output_channels = 1280
self._num_channels = num_channels
self._fc0_output_channels = 1280
self.fc_0 = ConvBNLayer(
num_channels=num_channels,
num_filters=output_channels,
num_filters=self._fc0_output_channels,
filter_size=1,
stride=1,
act="relu",
name="fc_0"
)
self.dropout = nn.Dropout(p=0.2)
stdv = 1.0 / math.sqrt(output_channels * 1.0)
stdv = 1.0 / math.sqrt(self._fc0_output_channels * 1.0)
self.fc_1 = Linear(
output_channels,
self._fc0_output_channels,
class_dim,
param_attr=ParamAttr(
weight_attr=ParamAttr(
name="fc_1_weights",
initializer=Uniform(-stdv, stdv)
)
),
bias_attr=ParamAttr(name="fc_1_offset")
)
def forward(self, inputs):
@@ -328,6 +335,7 @@ class GhostNet(nn.Layer):
x = self.pool2d_gap(x)
x = self.fc_0(x)
x = self.dropout(x)
x = paddle.reshape(x, shape=[-1, self._fc0_output_channels])
x = self.fc_1(x)
return x
@@ -345,3 +353,18 @@ class GhostNet(nn.Layer):
if new_v < 0.9 * v:
new_v += divisor
return new_v
def GhostNet_x0_5():
model = GhostNet(scale=0.5)
return model
def GhostNet_x1_0():
model = GhostNet(scale=1.0)
return model
def GhostNet_x1_3():
model = GhostNet(scale=1.3)
return model
\ No newline at end of file
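For reference, a minimal usage sketch of the three factory functions this commit adds at the end of ghostnet.py. The import path, the dygraph-mode call, and the 1000-class default head are assumptions for illustration, not part of the diff above.

# Minimal usage sketch (assumed import path; adjust to the actual package layout).
import paddle
from ppcls.modeling.architectures import GhostNet_x1_0  # hypothetical path

paddle.disable_static()            # harmless if dygraph mode is already the default
model = GhostNet_x1_0()            # scale=1.0 variant; class_dim assumed to default to 1000
model.eval()
x = paddle.rand([1, 3, 224, 224])  # NCHW, ImageNet-sized input
y = model(x)
print(y.shape)                     # expected: [1, 1000]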