提交 18d372ff 编写于 作者: W wqz960

fix format for ghostnet

上级 e8c3d72b
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math import math
import paddle import paddle.fluid as fluid
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.contrib.model_stat import summary
__all__ = ["GhostNet", "GhostNet_0_5", "GhostNet_1_0", "GhostNet_1_3"] __all__ = ["GhostNet", "GhostNet_0_5", "GhostNet_1_0", "GhostNet_1_3"]
class GhostNet(): class GhostNet():
def __init__(self, width_mult):
    """Store the per-block configuration table and the width multiplier.

    Args:
        width_mult: factor that ``net`` multiplies into every channel
            count before rounding it with ``_make_divisible``.
    """
    # One row per GhostBottleneck. ``net`` unpacks each row as
    # (k, exp_size, c, use_se, s) and forwards them as kernel size,
    # hidden (expansion) channels, output channels, squeeze-excite flag
    # and stride respectively.
    self.cfgs = [
        # k, t, c, SE, s
        [3, 16, 16, 0, 1],
        [3, 48, 24, 0, 2],
        [3, 72, 24, 0, 1],
        [5, 72, 40, 1, 2],
        [5, 120, 40, 1, 1],
        [3, 240, 80, 0, 2],
        [3, 200, 80, 0, 1],
        [3, 184, 80, 0, 1],
        [3, 184, 80, 0, 1],
        [3, 480, 112, 1, 1],
        [3, 672, 112, 1, 1],
        [5, 672, 160, 1, 2],
        [5, 960, 160, 0, 1],
        [5, 960, 160, 1, 1],
        [5, 960, 160, 0, 1],
        [5, 960, 160, 1, 1],
    ]
    self.width_mult = width_mult
def _make_divisible(self, v, divisor, min_value=None): def _make_divisible(self, v, divisor, min_value=None):
""" """
This function is taken from the original tf repo. This function is taken from the original tf repo.
...@@ -50,9 +48,9 @@ class GhostNet(): ...@@ -50,9 +48,9 @@ class GhostNet():
if new_v < 0.9 * v: if new_v < 0.9 * v:
new_v += divisor new_v += divisor
return new_v return new_v
def conv_bn_layer(self, def conv_bn_layer(self,
input, input,
num_filters, num_filters,
filter_size, filter_size,
stride=1, stride=1,
...@@ -60,35 +58,38 @@ class GhostNet(): ...@@ -60,35 +58,38 @@ class GhostNet():
act=None, act=None,
name=None, name=None,
data_format="NCHW"): data_format="NCHW"):
x = fluid.layers.conv2d(input=input, x = fluid.layers.conv2d(
num_filters=num_filters, input=input,
filter_size=filter_size, num_filters=num_filters,
stride=stride, filter_size=filter_size,
padding=(filter_size-1)//2, stride=stride,
groups=groups, padding=(filter_size - 1) // 2,
act=None, groups=groups,
param_attr=ParamAttr( act=None,
initializer=fluid.initializer.MSRA(),name=name+"_weights"), param_attr=ParamAttr(
bias_attr=False, initializer=fluid.initializer.MSRA(), name=name + "_weights"),
name=name+"_conv_op", bias_attr=False,
data_format=data_format) name=name + "_conv_op",
data_format=data_format)
x = fluid.layers.batch_norm(input=x,
act=act, x = fluid.layers.batch_norm(
name=name+"_bn", input=x,
param_attr=ParamAttr(name=name+"_bn_scale", regularizer=fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)), act=act,
bias_attr=ParamAttr(name=name+"_bn_offset", regularizer=fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)), name=name + "_bn",
moving_mean_name=name+"_bn_mean", param_attr=ParamAttr(
moving_variance_name=name+"_bn_variance", name=name + "_bn_scale",
data_layout=data_format) regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
bias_attr=ParamAttr(
name=name + "_bn_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance",
data_layout=data_format)
return x return x
def SElayer(self, input, num_channels, reduction_ratio=4, name=None):
def SElayer(self,
input,
num_channels,
reduction_ratio=4,
name=None):
pool = fluid.layers.pool2d( pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True) input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
...@@ -109,30 +110,29 @@ class GhostNet(): ...@@ -109,30 +110,29 @@ class GhostNet():
initializer=fluid.initializer.Uniform(-stdv, stdv), initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_exc_weights'), name=name + '_exc_weights'),
bias_attr=ParamAttr(name=name + '_exc_offset')) bias_attr=ParamAttr(name=name + '_exc_offset'))
excitation = fluid.layers.clip(x=excitation, excitation = fluid.layers.clip(
min=0, x=excitation, min=0, max=1, name=name + '_clip')
max=1,
name=name+'_clip')
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale return scale
def depthwise_conv(self,
                   inp,
                   oup,
                   kernel_size,
                   stride=1,
                   relu=False,
                   name=None,
                   data_format="NCHW"):
    """Apply a conv + batch-norm layer with one group per input channel.

    Thin wrapper around ``conv_bn_layer``.

    Args:
        inp: input variable; only its ``shape`` attribute is read here.
        oup: number of output filters.
        kernel_size: filter size forwarded to ``conv_bn_layer``.
        stride: convolution stride.
        relu: when true the layer is built with ``act="relu"``, otherwise
            with no activation.
        name: prefix for the layer; ``"_dw"`` is appended to it.
            NOTE(review): the ``None`` default cannot be concatenated with
            a string, so callers are expected to pass a name.
        data_format: ``"NCHW"`` reads the channel count from
            ``inp.shape[1]``; any other value reads it from
            ``inp.shape[-1]``.

    Returns:
        Whatever ``conv_bn_layer`` returns for these arguments.
    """
    # The group count equals the input channel count, whose axis depends
    # on the layout.
    if data_format == "NCHW":
        channels = inp.shape[1]
    else:
        channels = inp.shape[-1]
    activation = "relu" if relu else None
    return self.conv_bn_layer(
        input=inp,
        num_filters=oup,
        filter_size=kernel_size,
        stride=stride,
        groups=channels,
        act=activation,
        name=name + "_dw",
        data_format=data_format)
def GhostModule(self, def GhostModule(self,
inp, inp,
oup, oup,
...@@ -143,170 +143,184 @@ class GhostNet(): ...@@ -143,170 +143,184 @@ class GhostNet():
relu=True, relu=True,
name=None, name=None,
data_format="NCHW"): data_format="NCHW"):
self.oup=oup self.oup = oup
init_channels = int(math.ceil(oup/ratio)) init_channels = int(math.ceil(oup / ratio))
new_channels = int(init_channels*(ratio-1)) new_channels = int(init_channels * (ratio - 1))
primary_conv = self.conv_bn_layer(input=inp, primary_conv = self.conv_bn_layer(
num_filters=init_channels, input=inp,
filter_size=kernel_size, num_filters=init_channels,
stride=stride, filter_size=kernel_size,
groups=1, stride=stride,
act="relu" if relu else None, groups=1,
name=name+"_primary_conv", act="relu" if relu else None,
data_format="NCHW") name=name + "_primary_conv",
cheap_operation = self.conv_bn_layer(input=primary_conv, data_format="NCHW")
num_filters=new_channels, cheap_operation = self.conv_bn_layer(
filter_size=dw_size, input=primary_conv,
stride=1, num_filters=new_channels,
groups=init_channels, filter_size=dw_size,
act="relu" if relu else None, stride=1,
name=name+"_cheap_operation", groups=init_channels,
data_format=data_format) act="relu" if relu else None,
out = fluid.layers.concat([primary_conv, cheap_operation], axis=1, name=name+"_concat") name=name + "_cheap_operation",
data_format=data_format)
out = fluid.layers.concat(
[primary_conv, cheap_operation], axis=1, name=name + "_concat")
return out return out
def GhostBottleneck(self,
                    inp,
                    hidden_dim,
                    oup,
                    kernel_size,
                    stride,
                    use_se,
                    name=None,
                    data_format="NCHW"):
    """Build one Ghost bottleneck: expand, optional downsample/SE, project, add.

    Steps, in order:
      1. ``GhostModule`` to ``hidden_dim`` channels with ReLU.
      2. If ``stride == 2``: a strided depthwise conv on the hidden features.
      3. If ``use_se`` is truthy: ``SElayer`` on the hidden features.
      4. ``GhostModule`` to ``oup`` channels without ReLU.
      5. Shortcut: the input itself when ``stride == 1`` and the input
         channel count equals ``oup``; otherwise a depthwise conv followed
         by a 1x1 conv + batch-norm to ``oup`` channels.
      6. Element-wise sum of the main branch and the shortcut.

    Args:
        inp: input variable; ``inp.shape[1]`` is used as its channel count.
        hidden_dim: channel count of the expanded (hidden) features.
        oup: channel count of the block output.
        kernel_size: filter size of the depthwise convolutions.
        stride: stride of the depthwise convolutions.
        use_se: truthy to insert the squeeze-and-excitation layer.
        name: prefix for every layer created here; must be a string.
        data_format: accepted for interface compatibility.
            NOTE(review): every sub-layer below is built with a hard-coded
            ``"NCHW"`` and this argument is not forwarded -- confirm that
            is intended before calling with another layout.

    Returns:
        The result of ``fluid.layers.elementwise_add`` over the main branch
        and the shortcut.
    """
    inp_channels = inp.shape[1]

    # Main branch: expansion.
    x = self.GhostModule(
        inp=inp,
        oup=hidden_dim,
        kernel_size=1,
        stride=1,
        relu=True,
        name=name + "GhostBottle_1",
        data_format="NCHW")
    # Spatial downsampling is only inserted for stride-2 blocks.
    if stride == 2:
        x = self.depthwise_conv(
            inp=x,
            oup=hidden_dim,
            kernel_size=kernel_size,
            stride=stride,
            relu=False,
            name=name + "_dw2",
            data_format="NCHW")
    if use_se:
        x = self.SElayer(
            input=x, num_channels=hidden_dim, name=name + "SElayer")
    # Main branch: linear projection (no activation).
    x = self.GhostModule(
        inp=x,
        oup=oup,
        kernel_size=1,
        relu=False,
        name=name + "GhostModule_2")

    # Shortcut branch: identity when shapes already match, otherwise
    # depthwise conv + 1x1 conv to reach ``oup`` channels.
    if stride == 1 and inp_channels == oup:
        shortcut = inp
    else:
        shortcut = self.depthwise_conv(
            inp=inp,
            oup=inp_channels,
            kernel_size=kernel_size,
            stride=stride,
            relu=False,
            name=name + "shortcut_depthwise_conv",
            data_format="NCHW")
        shortcut = self.conv_bn_layer(
            input=shortcut,
            num_filters=oup,
            filter_size=1,
            stride=1,
            groups=1,
            act=None,
            name=name + "shortcut_conv_bn",
            data_format="NCHW")

    return fluid.layers.elementwise_add(
        x=x, y=shortcut, axis=-1, act=None, name=name + "elementwise_add")
def net(self, input, class_dim=1000):
    """Assemble the full GhostNet graph and return the classifier output.

    Layout: a stride-2 3x3 conv stem, one ``GhostBottleneck`` per row of
    ``self.cfgs``, a 1x1 conv, global average pooling, a 1x1 conv to 1280
    channels with batch-norm and ReLU, dropout (0.2) and a final fully
    connected layer.

    Args:
        input: input image variable. Every layer is built with
            ``data_format="NCHW"``.
        class_dim: size of the final fully connected layer.

    Returns:
        The output variable of the final ``fluid.layers.fc`` layer.
    """
    # Stem.
    output_channel = int(self._make_divisible(16 * self.width_mult, 4))
    x = self.conv_bn_layer(
        input=input,
        num_filters=output_channel,
        filter_size=3,
        stride=2,
        groups=1,
        act="relu",
        name="firstlayer",
        data_format="NCHW")

    # Stack of Ghost bottlenecks, named by their position in the table.
    for idx, (k, exp_size, c, use_se, s) in enumerate(self.cfgs):
        output_channel = int(self._make_divisible(c * self.width_mult, 4))
        hidden_channel = int(
            self._make_divisible(exp_size * self.width_mult, 4))
        x = self.GhostBottleneck(
            inp=x,
            hidden_dim=hidden_channel,
            oup=output_channel,
            kernel_size=k,
            stride=s,
            use_se=use_se,
            name="GhostBottle_" + str(idx),
            data_format="NCHW")

    # Head. Its width is derived from the expansion size of the LAST
    # bottleneck row; read it explicitly instead of relying on the loop
    # variable leaking out of the ``for`` statement.
    last_exp_size = self.cfgs[-1][1]
    output_channel = int(
        self._make_divisible(last_exp_size * self.width_mult, 4))
    x = self.conv_bn_layer(
        input=x,
        num_filters=output_channel,
        filter_size=1,
        stride=1,
        groups=1,
        act="relu",
        name="lastlayer",
        data_format="NCHW")
    x = fluid.layers.pool2d(
        input=x, pool_type='avg', global_pooling=True, data_format="NCHW")

    # 1x1 conv on the pooled features, followed by batch-norm + ReLU.
    output_channel = 1280
    stdv = 1.0 / math.sqrt(x.shape[1] * 1.0)
    out = fluid.layers.conv2d(
        input=x,
        num_filters=output_channel,
        filter_size=1,
        groups=1,
        param_attr=ParamAttr(
            name="fc_0_w",
            initializer=fluid.initializer.Uniform(-stdv, stdv)),
        bias_attr=False,
        name="fc_0")
    # Batch-norm scale and offset are created with an L2 decay coefficient
    # of 0.0.
    out = fluid.layers.batch_norm(
        input=out,
        act="relu",
        name="fc_0_bn",
        param_attr=ParamAttr(
            name="fc_0_bn_scale",
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0)),
        bias_attr=ParamAttr(
            name="fc_0_bn_offset",
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0)),
        moving_mean_name="fc_0_bn_mean",
        moving_variance_name="fc_0_bn_variance",
        data_layout="NCHW")
    out = fluid.layers.dropout(x=out, dropout_prob=0.2)

    # Classifier.
    stdv = 1.0 / math.sqrt(out.shape[1] * 1.0)
    out = fluid.layers.fc(
        input=out,
        size=class_dim,
        param_attr=ParamAttr(
            name="fc_1_w",
            initializer=fluid.initializer.Uniform(-stdv, stdv)),
        bias_attr=ParamAttr(name="fc_1_bias"))
    return out
def GhostNet_0_5():
    """Return a ``GhostNet`` instance constructed with ``width_mult=0.5``."""
    return GhostNet(width_mult=0.5)
def GhostNet_1_0():
    """Return a ``GhostNet`` instance constructed with ``width_mult=1.0``."""
    return GhostNet(width_mult=1.0)
def GhostNet_1_3():
    """Return a ``GhostNet`` instance constructed with ``width_mult=1.3``."""
    return GhostNet(width_mult=1.3)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册