提交 7c9e695f 编写于 作者: W weishengyu

change paddle version to 2.0; modify code

上级 ff19b9cf
...@@ -16,212 +16,202 @@ from __future__ import absolute_import ...@@ -16,212 +16,202 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np from paddle import ParamAttr, reshape, transpose, concat, split
import paddle from paddle.nn import Layer, Conv2d, MaxPool2d, AdaptiveAvgPool2d, BatchNorm, Linear
import paddle.fluid as fluid from paddle.nn.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr from paddle.nn.functional import relu, swish
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
from paddle.fluid.initializer import MSRA
import math
__all__ = [ __all__ = [
"ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33", "ShuffleNetV2_x0_5", "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33", "ShuffleNetV2_x0_5",
"ShuffleNetV2_x1_0", "ShuffleNetV2_x1_5", "ShuffleNetV2_x2_0", "ShuffleNetV2", "ShuffleNetV2_x1_5", "ShuffleNetV2_x2_0",
"ShuffleNetV2_swish" "ShuffleNetV2_swish"
] ]
def channel_shuffle(x, groups): def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.shape[0], x.shape[1], x.shape[ batch_size, num_channels, height, width = x.shape[0:4]
2], x.shape[3]
channels_per_group = num_channels // groups channels_per_group = num_channels // groups
# reshape # reshape
x = fluid.layers.reshape( x = reshape(x=x, shape=[batch_size, groups, channels_per_group, height, width])
x=x, shape=[batchsize, groups, channels_per_group, height, width])
# transpose
x = transpose(x=x, perm=[0, 2, 1, 3, 4])
x = fluid.layers.transpose(x=x, perm=[0, 2, 1, 3, 4])
# flatten # flatten
x = fluid.layers.reshape( x = reshape(x=x, shape=[batch_size, num_channels, height, width])
x=x, shape=[batchsize, num_channels, height, width])
return x return x
class ConvBNLayer(fluid.dygraph.Layer): class ConvBNLayer(Layer):
def __init__(self, def __init__(
num_channels, self,
filter_size, in_channels,
num_filters, out_channels,
stride, kernel_size,
padding, stride,
channels=None, padding,
num_groups=1, groups=1,
if_act=True, act=relu,
act='relu', name=None,
name=None, ):
use_cudnn=True):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self._if_act = if_act
assert act in ['relu', 'swish'], \
"supported act are {} but your act is {}".format(
['relu', 'swish'], act)
self._act = act self._act = act
self._conv = Conv2D( self._conv = Conv2d(
num_channels=num_channels, in_channels=in_channels,
num_filters=num_filters, out_channels=out_channels,
filter_size=filter_size, kernel_size=kernel_size,
stride=stride, stride=stride,
padding=padding, padding=padding,
groups=num_groups, groups=groups,
act=None, weight_attr=ParamAttr(initializer=MSRA(), name=name + "_weights"),
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=name + "_weights"),
bias_attr=False) bias_attr=False)
self._batch_norm = BatchNorm( self._batch_norm = BatchNorm(
num_filters, out_channels,
param_attr=ParamAttr(name=name + "_bn_scale"), param_attr=ParamAttr(name=name + "_bn_scale"),
bias_attr=ParamAttr(name=name + "_bn_offset"), bias_attr=ParamAttr(name=name + "_bn_offset"),
moving_mean_name=name + "_bn_mean", moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance") moving_variance_name=name + "_bn_variance"
)
def forward(self, inputs, if_act=True): def forward(self, inputs):
y = self._conv(inputs) y = self._conv(inputs)
y = self._batch_norm(y) y = self._batch_norm(y)
if self._if_act: if self._act:
y = fluid.layers.relu( y = self._act(y)
y) if self._act == 'relu' else fluid.layers.swish(y)
return y return y
class InvertedResidualUnit(fluid.dygraph.Layer): class InvertedResidual(Layer):
def __init__(self, def __init__(self,
num_channels, in_channels,
num_filters, out_channels,
stride, stride,
benchmodel, act=relu,
act='relu',
name=None): name=None):
super(InvertedResidualUnit, self).__init__() super(InvertedResidual, self).__init__()
assert stride in [1, 2], \ self._conv_pw = ConvBNLayer(
"supported stride are {} but your stride is {}".format([ in_channels=in_channels // 2,
1, 2], stride) out_channels=out_channels // 2,
self.benchmodel = benchmodel kernel_size=1,
oup_inc = num_filters // 2 stride=1,
inp = num_channels padding=0,
if benchmodel == 1: groups=1,
self._conv_pw = ConvBNLayer( act=act,
num_channels=num_channels // 2, name='stage_' + name + '_conv1'
num_filters=oup_inc, )
filter_size=1, self._conv_dw = ConvBNLayer(
stride=1, in_channels=out_channels // 2,
padding=0, out_channels=out_channels // 2,
num_groups=1, kernel_size=3,
if_act=True, stride=stride,
act=act, padding=1,
name='stage_' + name + '_conv1') groups=out_channels // 2,
self._conv_dw = ConvBNLayer( act=None,
num_channels=oup_inc, name='stage_' + name + '_conv2'
num_filters=oup_inc, )
filter_size=3, self._conv_linear = ConvBNLayer(
stride=stride, in_channels=out_channels // 2,
padding=1, out_channels=out_channels // 2,
num_groups=oup_inc, kernel_size=1,
if_act=False, stride=1,
act=act, padding=0,
use_cudnn=False, groups=1,
name='stage_' + name + '_conv2') act=act,
self._conv_linear = ConvBNLayer( name='stage_' + name + '_conv3'
num_channels=oup_inc, )
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act=act,
name='stage_' + name + '_conv3')
else:
# branch1
self._conv_dw_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=inp,
filter_size=3,
stride=stride,
padding=1,
num_groups=inp,
if_act=False,
act=act,
use_cudnn=False,
name='stage_' + name + '_conv4')
self._conv_linear_1 = ConvBNLayer(
num_channels=inp,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act=act,
name='stage_' + name + '_conv5')
# branch2
self._conv_pw_2 = ConvBNLayer(
num_channels=num_channels,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act=act,
name='stage_' + name + '_conv1')
self._conv_dw_2 = ConvBNLayer(
num_channels=oup_inc,
num_filters=oup_inc,
filter_size=3,
stride=stride,
padding=1,
num_groups=oup_inc,
if_act=False,
act=act,
use_cudnn=False,
name='stage_' + name + '_conv2')
self._conv_linear_2 = ConvBNLayer(
num_channels=oup_inc,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act=act,
name='stage_' + name + '_conv3')
def forward(self, inputs): def forward(self, inputs):
if self.benchmodel == 1: x1, x2 = split(
x1, x2 = fluid.layers.split( inputs,
inputs, num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2], axis=1)
dim=1) x2 = self._conv_pw(x2)
x2 = self._conv_pw(x2) x2 = self._conv_dw(x2)
x2 = self._conv_dw(x2) x2 = self._conv_linear(x2)
x2 = self._conv_linear(x2) out = concat([x1, x2], axis=1)
out = fluid.layers.concat([x1, x2], axis=1)
else:
x1 = self._conv_dw_1(inputs)
x1 = self._conv_linear_1(x1)
x2 = self._conv_pw_2(inputs) return channel_shuffle(out, 2)
x2 = self._conv_dw_2(x2)
x2 = self._conv_linear_2(x2)
out = fluid.layers.concat([x1, x2], axis=1) class InvertedResidualDS(Layer):
def __init__(self,
in_channels,
out_channels,
stride,
act=relu,
name=None):
super(InvertedResidualDS, self).__init__()
# branch1
self._conv_dw_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
act=None,
name='stage_' + name + '_conv4'
)
self._conv_linear_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act,
name='stage_' + name + '_conv5'
)
# branch2
self._conv_pw_2 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act,
name='stage_' + name + '_conv1'
)
self._conv_dw_2 = ConvBNLayer(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=3,
stride=stride,
padding=1,
groups=out_channels // 2,
act=None,
name='stage_' + name + '_conv2'
)
self._conv_linear_2 = ConvBNLayer(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act,
name='stage_' + name + '_conv3'
)
def forward(self, inputs):
x1 = self._conv_dw_1(inputs)
x1 = self._conv_linear_1(x1)
x2 = self._conv_pw_2(inputs)
x2 = self._conv_dw_2(x2)
x2 = self._conv_linear_2(x2)
out = concat([x1, x2], axis=1)
return channel_shuffle(out, 2) return channel_shuffle(out, 2)
class ShuffleNet(fluid.dygraph.Layer): class ShuffleNet(Layer):
def __init__(self, class_dim=1000, scale=1.0, act='relu'): def __init__(self, class_dim=1000, scale=1.0, act=relu):
super(ShuffleNet, self).__init__() super(ShuffleNet, self).__init__()
self.scale = scale self.scale = scale
self.class_dim = class_dim self.class_dim = class_dim
...@@ -244,67 +234,59 @@ class ShuffleNet(fluid.dygraph.Layer): ...@@ -244,67 +234,59 @@ class ShuffleNet(fluid.dygraph.Layer):
"] is not implemented!") "] is not implemented!")
# 1. conv1 # 1. conv1
self._conv1 = ConvBNLayer( self._conv1 = ConvBNLayer(
num_channels=3, in_channels=3,
num_filters=stage_out_channels[1], out_channels=stage_out_channels[1],
filter_size=3, kernel_size=3,
stride=2, stride=2,
padding=1, padding=1,
if_act=True,
act=act, act=act,
name='stage1_conv') name='stage1_conv')
self._max_pool = Pool2D( self._max_pool = MaxPool2d(
pool_type='max', pool_size=3, pool_stride=2, pool_padding=1) kernel_size=3,
stride=2,
padding=1
)
# 2. bottleneck sequences # 2. bottleneck sequences
self._block_list = [] self._block_list = []
i = 1 for stage_id, num_repeat in enumerate(stage_repeats):
in_c = int(32 * scale) for i in range(num_repeat):
for idxstage in range(len(stage_repeats)):
numrepeat = stage_repeats[idxstage]
output_channel = stage_out_channels[idxstage + 2]
for i in range(numrepeat):
if i == 0: if i == 0:
block = self.add_sublayer( block = self.add_sublayer(
str(idxstage + 2) + '_' + str(i + 1), name=str(stage_id + 2) + '_' + str(i + 1),
InvertedResidualUnit( sublayer=InvertedResidualDS(
num_channels=stage_out_channels[idxstage + 1], in_channels=stage_out_channels[stage_id + 1],
num_filters=output_channel, out_channels=stage_out_channels[stage_id + 2],
stride=2, stride=2,
benchmodel=2,
act=act, act=act,
name=str(idxstage + 2) + '_' + str(i + 1))) name=str(stage_id + 2) + '_' + str(i + 1)))
self._block_list.append(block)
else: else:
block = self.add_sublayer( block = self.add_sublayer(
str(idxstage + 2) + '_' + str(i + 1), name=str(stage_id + 2) + '_' + str(i + 1),
InvertedResidualUnit( sublayer=InvertedResidual(
num_channels=output_channel, in_channels=stage_out_channels[stage_id + 2],
num_filters=output_channel, out_channels=stage_out_channels[stage_id + 2],
stride=1, stride=1,
benchmodel=1,
act=act, act=act,
name=str(idxstage + 2) + '_' + str(i + 1))) name=str(stage_id + 2) + '_' + str(i + 1)))
self._block_list.append(block) self._block_list.append(block)
# 3. last_conv # 3. last_conv
self._last_conv = ConvBNLayer( self._last_conv = ConvBNLayer(
num_channels=stage_out_channels[-2], in_channels=stage_out_channels[-2],
num_filters=stage_out_channels[-1], out_channels=stage_out_channels[-1],
filter_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0,
if_act=True,
act=act, act=act,
name='conv5') name='conv5')
# 4. pool # 4. pool
self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True) self._pool2d_avg = AdaptiveAvgPool2d(1)
self._out_c = stage_out_channels[-1] self._out_c = stage_out_channels[-1]
# 5. fc # 5. fc
self._fc = Linear( self._fc = Linear(
stage_out_channels[-1], stage_out_channels[-1],
class_dim, class_dim,
param_attr=ParamAttr(name='fc6_weights'), weight_attr=ParamAttr(name='fc6_weights'),
bias_attr=ParamAttr(name='fc6_offset')) bias_attr=ParamAttr(name='fc6_offset'))
def forward(self, inputs): def forward(self, inputs):
...@@ -314,7 +296,7 @@ class ShuffleNet(fluid.dygraph.Layer): ...@@ -314,7 +296,7 @@ class ShuffleNet(fluid.dygraph.Layer):
y = inv(y) y = inv(y)
y = self._last_conv(y) y = self._last_conv(y)
y = self._pool2d_avg(y) y = self._pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self._out_c]) y = reshape(y, shape=[-1, self._out_c])
y = self._fc(y) y = self._fc(y)
return y return y
...@@ -350,5 +332,5 @@ def ShuffleNetV2_x2_0(**args): ...@@ -350,5 +332,5 @@ def ShuffleNetV2_x2_0(**args):
def ShuffleNetV2_swish(**args): def ShuffleNetV2_swish(**args):
model = ShuffleNet(scale=1.0, act='swish', **args) model = ShuffleNet(scale=1.0, act=swish, **args)
return model return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册