提交 70641360 编写于 作者: W weishengyu

dbg ghostnet

上级 7c9e695f
...@@ -27,6 +27,7 @@ from .hrnet import HRNet_W18_C ...@@ -27,6 +27,7 @@ from .hrnet import HRNet_W18_C
from .efficientnet import EfficientNetB0 from .efficientnet import EfficientNetB0
from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50 from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50
from .googlenet import GoogLeNet from .googlenet import GoogLeNet
from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1 from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1
from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0 from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25 from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
......
...@@ -20,7 +20,6 @@ import paddle.nn.functional as F ...@@ -20,7 +20,6 @@ import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, AdaptiveAvgPool2d, Linear from paddle.nn import Conv2d, BatchNorm, AdaptiveAvgPool2d, Linear
from paddle.fluid.regularizer import L2DecayRegularizer from paddle.fluid.regularizer import L2DecayRegularizer
from paddle.nn.initializer import Uniform from paddle.nn.initializer import Uniform
from paddle import fluid
class ConvBNLayer(nn.Layer): class ConvBNLayer(nn.Layer):
...@@ -42,9 +41,12 @@ class ConvBNLayer(nn.Layer): ...@@ -42,9 +41,12 @@ class ConvBNLayer(nn.Layer):
stride=stride, stride=stride,
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
weight_attr=ParamAttr(name=name + "_weights"), weight_attr=ParamAttr(initializer=nn.initializer.MSRA(), name=name + "_weights"),
bias_attr=False) bias_attr=False
)
bn_name = name + "_bn" bn_name = name + "_bn"
# In the old version, moving_variance_name was name + "_variance"
self._batch_norm = BatchNorm( self._batch_norm = BatchNorm(
num_filters, num_filters,
act=act, act=act,
...@@ -104,7 +106,7 @@ class SEBlock(nn.Layer): ...@@ -104,7 +106,7 @@ class SEBlock(nn.Layer):
squeeze = self.squeeze(pool) squeeze = self.squeeze(pool)
squeeze = F.relu(squeeze) squeeze = F.relu(squeeze)
excitation = self.excitation(squeeze) excitation = self.excitation(squeeze)
excitation = F.sigmoid(excitation) excitation = paddle.fluid.layers.clip(x=excitation, min=0, max=1)
excitation = paddle.reshape( excitation = paddle.reshape(
excitation, excitation,
shape=[-1, self._num_channels, 1, 1] shape=[-1, self._num_channels, 1, 1]
...@@ -138,7 +140,7 @@ class GhostModule(nn.Layer): ...@@ -138,7 +140,7 @@ class GhostModule(nn.Layer):
name=name + "_primary_conv" name=name + "_primary_conv"
) )
self.cheap_operation = ConvBNLayer( self.cheap_operation = ConvBNLayer(
num_channels=num_channels, num_channels=init_channels,
num_filters=new_channels, num_filters=new_channels,
filter_size=dw_size, filter_size=dw_size,
stride=1, stride=1,
...@@ -186,7 +188,7 @@ class GhostBottleneck(nn.Layer): ...@@ -186,7 +188,7 @@ class GhostBottleneck(nn.Layer):
stride=stride, stride=stride,
groups=hidden_dim, groups=hidden_dim,
act=None, act=None,
name=name+"_depthwise" name=name+"_depthwise" # In the old version, name was name + "_depthwise_depthwise"
) )
if use_se: if use_se:
self.se_block = SEBlock( self.se_block = SEBlock(
...@@ -194,7 +196,7 @@ class GhostBottleneck(nn.Layer): ...@@ -194,7 +196,7 @@ class GhostBottleneck(nn.Layer):
name=name + "_se" name=name + "_se"
) )
self.ghost_module_2 = GhostModule( self.ghost_module_2 = GhostModule(
num_channels=num_channels, num_channels=hidden_dim,
output_channels=output_channels, output_channels=output_channels,
kernel_size=1, kernel_size=1,
relu=False, relu=False,
...@@ -208,7 +210,7 @@ class GhostBottleneck(nn.Layer): ...@@ -208,7 +210,7 @@ class GhostBottleneck(nn.Layer):
stride=stride, stride=stride,
groups=num_channels, groups=num_channels,
act=None, act=None,
name=name + "_shotcut_depthwise" name=name + "_shortcut_depthwise" # In the old version, name was name + "_shortcut_depthwise_depthwise"
) )
self.shortcut_conv = ConvBNLayer( self.shortcut_conv = ConvBNLayer(
num_channels=num_channels, num_channels=num_channels,
...@@ -217,11 +219,11 @@ class GhostBottleneck(nn.Layer): ...@@ -217,11 +219,11 @@ class GhostBottleneck(nn.Layer):
stride=1, stride=1,
groups=1, groups=1,
act=None, act=None,
name=name + "_shotcut_conv" name=name + "_shortcut_conv"
) )
def forward(self, inputs): def forward(self, inputs):
x = self.ghost_module(inputs) x = self.ghost_module_1(inputs)
if self._stride == 2: if self._stride == 2:
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
if self._use_se: if self._use_se:
...@@ -275,14 +277,17 @@ class GhostNet(nn.Layer): ...@@ -275,14 +277,17 @@ class GhostNet(nn.Layer):
num_channels = output_channels num_channels = output_channels
output_channels = int(self._make_divisible(c * self.scale, 4)) output_channels = int(self._make_divisible(c * self.scale, 4))
hidden_dim = int(self._make_divisible(exp_size, self.scale, 4)) hidden_dim = int(self._make_divisible(exp_size, self.scale, 4))
ghost_bottleneck = GhostBottleneck( ghost_bottleneck = self.add_sublayer(
num_channels=num_channels, name="_ghostbottleneck_" + str(idx),
hidden_dim=hidden_dim, sublayer=GhostBottleneck(
output_channels=output_channels, num_channels=num_channels,
kernel_size=k, hidden_dim=hidden_dim,
stride=s, output_channels=output_channels,
use_se=use_se, kernel_size=k,
name="_ghostbottleneck" + str(idx) stride=s,
use_se=use_se,
name="_ghostbottleneck_" + str(idx)
)
) )
self.ghost_bottleneck_list.append(ghost_bottleneck) self.ghost_bottleneck_list.append(ghost_bottleneck)
idx += 1 idx += 1
...@@ -300,24 +305,26 @@ class GhostNet(nn.Layer): ...@@ -300,24 +305,26 @@ class GhostNet(nn.Layer):
) )
self.pool2d_gap = AdaptiveAvgPool2d(1) self.pool2d_gap = AdaptiveAvgPool2d(1)
num_channels = output_channels num_channels = output_channels
output_channels = 1280 self._num_channels = num_channels
self._fc0_output_channels = 1280
self.fc_0 = ConvBNLayer( self.fc_0 = ConvBNLayer(
num_channels=num_channels, num_channels=num_channels,
num_filters=output_channels, num_filters=self._fc0_output_channels,
filter_size=1, filter_size=1,
stride=1, stride=1,
act="relu", act="relu",
name="fc_0" name="fc_0"
) )
self.dropout = nn.Dropout(p=0.2) self.dropout = nn.Dropout(p=0.2)
stdv = 1.0 / math.sqrt(output_channels * 1.0) stdv = 1.0 / math.sqrt(self._fc0_output_channels * 1.0)
self.fc_1 = Linear( self.fc_1 = Linear(
output_channels, self._fc0_output_channels,
class_dim, class_dim,
param_attr=ParamAttr( weight_attr=ParamAttr(
name="fc_1_weights", name="fc_1_weights",
initializer=Uniform(-stdv, stdv) initializer=Uniform(-stdv, stdv)
) ),
bias_attr=ParamAttr(name="fc_1_offset")
) )
def forward(self, inputs): def forward(self, inputs):
...@@ -328,6 +335,7 @@ class GhostNet(nn.Layer): ...@@ -328,6 +335,7 @@ class GhostNet(nn.Layer):
x = self.pool2d_gap(x) x = self.pool2d_gap(x)
x = self.fc_0(x) x = self.fc_0(x)
x = self.dropout(x) x = self.dropout(x)
x = paddle.reshape(x, shape=[-1, self._fc0_output_channels])
x = self.fc_1(x) x = self.fc_1(x)
return x return x
...@@ -345,3 +353,18 @@ class GhostNet(nn.Layer): ...@@ -345,3 +353,18 @@ class GhostNet(nn.Layer):
if new_v < 0.9 * v: if new_v < 0.9 * v:
new_v += divisor new_v += divisor
return new_v return new_v
def GhostNet_x0_5():
model = GhostNet(scale=0.5)
return model
def GhostNet_x1_0():
model = GhostNet(scale=1.0)
return model
def GhostNet_x1_3():
model = GhostNet(scale=1.3)
return model
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册