Unverified commit fa6ce23f, authored by littletomatodonkey, committed by GitHub

Merge pull request #331 from weisy11/dygraph

add ghostnet and modify shufflenet
mode: 'train'
ARCHITECTURE:
name: 'GhostNet_x0_5'
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 360
topk: 5
image_shape: [3, 224, 224]
use_mix: False
ls_epsilon: 0.1
LEARNING_RATE:
function: 'CosineWarmup'
params:
lr: 0.8
OPTIMIZER:
function: 'Momentum'
params:
momentum: 0.9
regularizer:
function: 'L2'
factor: 0.0000400
TRAIN:
batch_size: 2048
num_workers: 4
file_list: "./dataset/ILSVRC2012/train_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
VALID:
batch_size: 64
num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
mode: 'train'
ARCHITECTURE:
name: 'GhostNet_x1_0'
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 360
topk: 5
image_shape: [3, 224, 224]
use_mix: False
ls_epsilon: 0.1
LEARNING_RATE:
function: 'CosineWarmup'
params:
lr: 0.4
OPTIMIZER:
function: 'Momentum'
params:
momentum: 0.9
regularizer:
function: 'L2'
factor: 0.0000400
TRAIN:
batch_size: 1024
num_workers: 4
file_list: "./dataset/ILSVRC2012/train_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
VALID:
batch_size: 64
num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
mode: 'train'
ARCHITECTURE:
name: 'GhostNet_x1_3'
pretrained_model: ""
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
save_interval: 1
validate: True
valid_interval: 1
epochs: 360
topk: 5
image_shape: [3, 224, 224]
use_mix: False
ls_epsilon: 0.1
LEARNING_RATE:
function: 'CosineWarmup'
params:
lr: 0.4
OPTIMIZER:
function: 'Momentum'
params:
momentum: 0.9
regularizer:
function: 'L2'
factor: 0.0000400
TRAIN:
batch_size: 1024
num_workers: 4
file_list: "./dataset/ILSVRC2012/train_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- AutoAugment:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
VALID:
batch_size: 64
num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
shuffle_seed: 0
transforms:
- DecodeImage:
to_rgb: True
to_np: False
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
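All three GhostNet configs above share the same schedule and data pipeline; they differ only in the architecture name, the base learning rate, and the global batch size (the learning rate scales linearly with batch size: 0.8/2048 = 0.4/1024 ≈ 3.9e-4 per sample). A minimal sketch of how such a config might be consumed, assuming PyYAML and the architectures package whose __init__ diff appears below (the dotted import path is an assumption about the repository layout):

import yaml

# Assumed module path; its __init__.py (diff shown below) exports GhostNet_x0_5 / x1_0 / x1_3.
from ppcls.modeling import architectures

def build_model_from_config(path):
    """Read a PaddleClas-style YAML config and instantiate the named architecture."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    model_fn = getattr(architectures, cfg["ARCHITECTURE"]["name"])  # e.g. GhostNet_x0_5
    return model_fn()

# model = build_model_from_config("GhostNet_x0_5.yaml")  # hypothetical file name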
......@@ -14,7 +14,7 @@ topk: 5
image_shape: [3, 224, 224]
LEARNING_RATE:
function: 'Cosine'
function: 'CosineWarmup'
params:
lr: 0.5
warmup_epoch: 5
......
......@@ -14,7 +14,7 @@ topk: 5
image_shape: [3, 224, 224]
LEARNING_RATE:
function: 'Cosine'
function: 'CosineWarmup'
params:
lr: 0.5
warmup_epoch: 5
......
......@@ -14,7 +14,7 @@ topk: 5
image_shape: [3, 224, 224]
LEARNING_RATE:
function: 'Cosine'
function: 'CosineWarmup'
params:
lr: 0.5
warmup_epoch: 5
......
......@@ -14,7 +14,7 @@ topk: 5
image_shape: [3, 224, 224]
LEARNING_RATE:
function: 'Cosine'
function: 'CosineWarmup'
params:
lr: 0.5
warmup_epoch: 5
......
......@@ -14,7 +14,7 @@ topk: 5
image_shape: [3, 224, 224]
LEARNING_RATE:
function: 'Cosine'
function: 'CosineWarmup'
params:
lr: 0.5
warmup_epoch: 5
......
......@@ -14,7 +14,7 @@ topk: 5
image_shape: [3, 224, 224]
LEARNING_RATE:
function: 'Cosine'
function: 'CosineWarmup'
params:
lr: 0.25
warmup_epoch: 5
......
......@@ -14,7 +14,7 @@ topk: 5
image_shape: [3, 224, 224]
LEARNING_RATE:
function: 'Cosine'
function: 'CosineWarmup'
params:
lr: 0.25
warmup_epoch: 5
......
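The hunks above switch the schedule from 'Cosine' to 'CosineWarmup' while keeping warmup_epoch: 5. A generic sketch of what such a schedule computes (linear warmup followed by cosine decay; not necessarily the exact PaddleClas implementation):

import math

def cosine_warmup_lr(base_lr, epoch, total_epochs, warmup_epochs=5):
    # Linear warmup for the first `warmup_epochs`, then cosine decay to zero.
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

# e.g. cosine_warmup_lr(0.5, 0, total_epochs=120) == 0.1; the rate reaches 0.5 at epoch 4, then decays.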
......@@ -28,6 +28,7 @@ from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44
from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50
from .googlenet import GoogLeNet
from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75, MobileNetV1
from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, AdaptiveAvgPool2d, Linear
from paddle.fluid.regularizer import L2DecayRegularizer
from paddle.nn.initializer import Uniform
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
act="relu",
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(
initializer=nn.initializer.MSRA(), name=name + "_weights"),
bias_attr=False)
bn_name = name + "_bn"
# In the old version, moving_variance_name was name + "_variance"
self._batch_norm = BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(
name=bn_name + "_scale",
regularizer=L2DecayRegularizer(regularization_coeff=0.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
regularizer=L2DecayRegularizer(regularization_coeff=0.0)),
moving_mean_name=bn_name + "_mean",
moving_variance_name=name +
"_variance" # wrong due to an old typo, will be fixed later.
)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class SEBlock(nn.Layer):
def __init__(self, num_channels, reduction_ratio=4, name=None):
super(SEBlock, self).__init__()
self.pool2d_gap = AdaptiveAvgPool2d(1)
self._num_channels = num_channels
stdv = 1.0 / math.sqrt(num_channels * 1.0)
med_ch = num_channels // reduction_ratio
self.squeeze = Linear(
num_channels,
med_ch,
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv), name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
stdv = 1.0 / math.sqrt(med_ch * 1.0)
self.excitation = Linear(
med_ch,
num_channels,
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv), name=name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
def forward(self, inputs):
pool = self.pool2d_gap(inputs)
pool = paddle.reshape(pool, shape=[-1, self._num_channels])
squeeze = self.squeeze(pool)
squeeze = F.relu(squeeze)
excitation = self.excitation(squeeze)
excitation = paddle.fluid.layers.clip(x=excitation, min=0, max=1)
excitation = paddle.reshape(
excitation, shape=[-1, self._num_channels, 1, 1])
out = inputs * excitation
return out
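# Shape walk-through for SEBlock with illustrative sizes num_channels=64, reduction_ratio=4:
#   inputs                  [N, 64, H, W]
#   pool2d_gap + reshape    [N, 64]
#   squeeze (Linear + relu) [N, 16]   (64 // 4)
#   excitation (Linear)     [N, 64], then clipped to [0, 1] (a clipped-linear gate rather than a sigmoid)
#   reshape to [N, 64, 1, 1] and broadcast-multiplied with `inputs` channel-wise.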
class GhostModule(nn.Layer):
def __init__(self,
in_channels,
output_channels,
kernel_size=1,
ratio=2,
dw_size=3,
stride=1,
relu=True,
name=None):
super(GhostModule, self).__init__()
init_channels = int(math.ceil(output_channels / ratio))
new_channels = int(init_channels * (ratio - 1))
self.primary_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=init_channels,
kernel_size=kernel_size,
stride=stride,
groups=1,
act="relu" if relu else None,
name=name + "_primary_conv")
self.cheap_operation = ConvBNLayer(
in_channels=init_channels,
out_channels=new_channels,
kernel_size=dw_size,
stride=1,
groups=init_channels,
act="relu" if relu else None,
name=name + "_cheap_operation")
def forward(self, inputs):
x = self.primary_conv(inputs)
y = self.cheap_operation(x)
out = paddle.concat([x, y], axis=1)
return out
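# Channel bookkeeping for an illustrative GhostModule(output_channels=64, ratio=2):
#   init_channels = ceil(64 / 2) = 32   -> produced by the 1x1 primary_conv
#   new_channels  = 32 * (2 - 1) = 32   -> produced by the 3x3 depthwise cheap_operation
#   concat([x, y], axis=1)       = 64   -> matches output_channels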
class GhostBottleneck(nn.Layer):
def __init__(self,
in_channels,
hidden_dim,
output_channels,
kernel_size,
stride,
use_se,
name=None):
super(GhostBottleneck, self).__init__()
self._stride = stride
self._use_se = use_se
self._num_channels = in_channels
self._output_channels = output_channels
self.ghost_module_1 = GhostModule(
in_channels=in_channels,
output_channels=hidden_dim,
kernel_size=1,
stride=1,
relu=True,
name=name + "_ghost_module_1")
if stride == 2:
self.depthwise_conv = ConvBNLayer(
in_channels=hidden_dim,
out_channels=hidden_dim,
kernel_size=kernel_size,
stride=stride,
groups=hidden_dim,
act=None,
name=name +
"_depthwise_depthwise" # looks strange due to an old typo, will be fixed later.
)
if use_se:
self.se_block = SEBlock(num_channels=hidden_dim, name=name + "_se")
self.ghost_module_2 = GhostModule(
in_channels=hidden_dim,
output_channels=output_channels,
kernel_size=1,
relu=False,
name=name + "_ghost_module_2")
if stride != 1 or in_channels != output_channels:
self.shortcut_depthwise = ConvBNLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=stride,
groups=in_channels,
act=None,
name=name +
"_shortcut_depthwise_depthwise" # looks strange due to an old typo, will be fixed later.
)
self.shortcut_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=output_channels,
kernel_size=1,
stride=1,
groups=1,
act=None,
name=name + "_shortcut_conv")
def forward(self, inputs):
x = self.ghost_module_1(inputs)
if self._stride == 2:
x = self.depthwise_conv(x)
if self._use_se:
x = self.se_block(x)
x = self.ghost_module_2(x)
if self._stride == 1 and self._num_channels == self._output_channels:
shortcut = inputs
else:
shortcut = self.shortcut_depthwise(inputs)
shortcut = self.shortcut_conv(shortcut)
return paddle.elementwise_add(x=x, y=shortcut, axis=-1)
class GhostNet(nn.Layer):
def __init__(self, scale, class_dim=1000):
super(GhostNet, self).__init__()
self.cfgs = [
# k, t, c, SE, s
[3, 16, 16, 0, 1],
[3, 48, 24, 0, 2],
[3, 72, 24, 0, 1],
[5, 72, 40, 1, 2],
[5, 120, 40, 1, 1],
[3, 240, 80, 0, 2],
[3, 200, 80, 0, 1],
[3, 184, 80, 0, 1],
[3, 184, 80, 0, 1],
[3, 480, 112, 1, 1],
[3, 672, 112, 1, 1],
[5, 672, 160, 1, 2],
[5, 960, 160, 0, 1],
[5, 960, 160, 1, 1],
[5, 960, 160, 0, 1],
[5, 960, 160, 1, 1]
]
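# Each row above defines one GhostBottleneck:
#   k  = depthwise kernel size, t = expansion size (hidden_dim before width scaling),
#   c  = output channels before scaling, SE = whether an SEBlock is inserted, s = stride.
# Both c and t are multiplied by `scale` and rounded with _make_divisible(..., 4) below.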
self.scale = scale
output_channels = int(self._make_divisible(16 * self.scale, 4))
self.conv1 = ConvBNLayer(
in_channels=3,
out_channels=output_channels,
kernel_size=3,
stride=2,
groups=1,
act="relu",
name="conv1")
# build inverted residual blocks
idx = 0
self.ghost_bottleneck_list = []
for k, exp_size, c, use_se, s in self.cfgs:
in_channels = output_channels
output_channels = int(self._make_divisible(c * self.scale, 4))
hidden_dim = int(self._make_divisible(exp_size * self.scale, 4))
ghost_bottleneck = self.add_sublayer(
name="_ghostbottleneck_" + str(idx),
sublayer=GhostBottleneck(
in_channels=in_channels,
hidden_dim=hidden_dim,
output_channels=output_channels,
kernel_size=k,
stride=s,
use_se=use_se,
name="_ghostbottleneck_" + str(idx)))
self.ghost_bottleneck_list.append(ghost_bottleneck)
idx += 1
# build last several layers
in_channels = output_channels
output_channels = int(self._make_divisible(exp_size * self.scale, 4))
self.conv_last = ConvBNLayer(
in_channels=in_channels,
out_channels=output_channels,
kernel_size=1,
stride=1,
groups=1,
act="relu",
name="conv_last")
self.pool2d_gap = AdaptiveAvgPool2d(1)
in_channels = output_channels
self._fc0_output_channels = 1280
self.fc_0 = ConvBNLayer(
in_channels=in_channels,
out_channels=self._fc0_output_channels,
kernel_size=1,
stride=1,
act="relu",
name="fc_0")
self.dropout = nn.Dropout(p=0.2)
stdv = 1.0 / math.sqrt(self._fc0_output_channels * 1.0)
self.fc_1 = Linear(
self._fc0_output_channels,
class_dim,
weight_attr=ParamAttr(
name="fc_1_weights", initializer=Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name="fc_1_offset"))
def forward(self, inputs):
x = self.conv1(inputs)
for ghost_bottleneck in self.ghost_bottleneck_list:
x = ghost_bottleneck(x)
x = self.conv_last(x)
x = self.pool2d_gap(x)
x = self.fc_0(x)
x = self.dropout(x)
x = paddle.reshape(x, shape=[-1, self._fc0_output_channels])
x = self.fc_1(x)
return x
def _make_divisible(self, v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by the given divisor (8 in the original repo; 4 is used in this file).
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that rounding down does not reduce the value by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
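# Illustrative values with divisor=4 (as used throughout this file):
#   _make_divisible(16 * 0.5, 4) -> 8
#   _make_divisible(40 * 1.3, 4) -> 52
#   _make_divisible(24 * 0.5, 4) -> 12
# If the rounded value fell below 90% of v, one extra `divisor` step would be added back.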
def GhostNet_x0_5(**args):
model = GhostNet(scale=0.5)
return model
def GhostNet_x1_0(**args):
model = GhostNet(scale=1.0)
return model
def GhostNet_x1_3(**args):
model = GhostNet(scale=1.3)
return model
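A minimal smoke test for the factory functions above, assuming the paddle 2.0 dygraph API this file targets (not part of the original commit):

import paddle
# from ppcls.modeling.architectures.ghostnet import GhostNet_x1_0  # import path is an assumption

model = GhostNet_x1_0()
model.eval()
x = paddle.rand([1, 3, 224, 224])   # matches image_shape in the configs above
logits = model(x)
print(logits.shape)                 # expected: [1, 1000]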
......@@ -16,15 +16,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
from paddle import ParamAttr, reshape, transpose, concat, split
from paddle.nn import Layer, Conv2d, MaxPool2d, AdaptiveAvgPool2d, BatchNorm, Linear
from paddle.nn.initializer import MSRA
import math
from paddle.nn.functional import swish
__all__ = [
"ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33", "ShuffleNetV2_x0_5",
......@@ -34,188 +29,176 @@ __all__ = [
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.shape[0], x.shape[1], x.shape[
2], x.shape[3]
batch_size, num_channels, height, width = x.shape[0:4]
channels_per_group = num_channels // groups
# reshape
x = paddle.reshape(
x=x, shape=[batchsize, groups, channels_per_group, height, width])
x = reshape(
x=x, shape=[batch_size, groups, channels_per_group, height, width])
# transpose
x = transpose(x=x, perm=[0, 2, 1, 3, 4])
x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
# flatten
x = paddle.reshape(x=x, shape=[batchsize, num_channels, height, width])
x = reshape(x=x, shape=[batch_size, num_channels, height, width])
return x
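# Worked example of channel_shuffle with 4 channels and groups=2:
#   input channel order  [0, 1, 2, 3]
#   reshape   -> [N, 2, 2, H, W]  (groups, channels_per_group)
#   transpose -> [N, 2, 2, H, W]  with the two channel axes swapped
#   reshape   -> [N, 4, H, W], channel order [0, 2, 1, 3]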
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
if_act=True,
act='relu',
name=None):
class ConvBNLayer(Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self._if_act = if_act
assert act in ['relu', 'swish'], \
"supported act are {} but your act is {}".format(
['relu', 'swish'], act)
self._act = act
self._conv = Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=num_groups,
groups=groups,
weight_attr=ParamAttr(
initializer=MSRA(), name=name + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
out_channels,
param_attr=ParamAttr(name=name + "_bn_scale"),
bias_attr=ParamAttr(name=name + "_bn_offset"),
act=act,
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance")
def forward(self, inputs, if_act=True):
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
if self._if_act:
y = F.relu(y) if self._act == 'relu' else F.swish(y)
return y
class InvertedResidualUnit(nn.Layer):
class InvertedResidual(Layer):
def __init__(self,
num_channels,
num_filters,
in_channels,
out_channels,
stride,
benchmodel,
act='relu',
act="relu",
name=None):
super(InvertedResidualUnit, self).__init__()
assert stride in [1, 2], \
"supported stride are {} but your stride is {}".format([
1, 2], stride)
self.benchmodel = benchmodel
oup_inc = num_filters // 2
inp = num_channels
if benchmodel == 1:
self._conv_pw = ConvBNLayer(
num_channels=num_channels // 2,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act=act,
name='stage_' + name + '_conv1')
self._conv_dw = ConvBNLayer(
num_channels=oup_inc,
num_filters=oup_inc,
filter_size=3,
stride=stride,
padding=1,
num_groups=oup_inc,
if_act=False,
act=act,
name='stage_' + name + '_conv2')
self._conv_linear = ConvBNLayer(
num_channels=oup_inc,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act=act,
name='stage_' + name + '_conv3')
else:
# branch1
self._conv_dw_1 = ConvBNLayer(
num_channels=num_channels,
num_filters=inp,
filter_size=3,
stride=stride,
padding=1,
num_groups=inp,
if_act=False,
act=act,
name='stage_' + name + '_conv4')
self._conv_linear_1 = ConvBNLayer(
num_channels=inp,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act=act,
name='stage_' + name + '_conv5')
# branch2
self._conv_pw_2 = ConvBNLayer(
num_channels=num_channels,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act=act,
name='stage_' + name + '_conv1')
self._conv_dw_2 = ConvBNLayer(
num_channels=oup_inc,
num_filters=oup_inc,
filter_size=3,
stride=stride,
padding=1,
num_groups=oup_inc,
if_act=False,
act=act,
name='stage_' + name + '_conv2')
self._conv_linear_2 = ConvBNLayer(
num_channels=oup_inc,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act=act,
name='stage_' + name + '_conv3')
super(InvertedResidual, self).__init__()
self._conv_pw = ConvBNLayer(
in_channels=in_channels // 2,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act,
name='stage_' + name + '_conv1')
self._conv_dw = ConvBNLayer(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=3,
stride=stride,
padding=1,
groups=out_channels // 2,
act=None,
name='stage_' + name + '_conv2')
self._conv_linear = ConvBNLayer(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act,
name='stage_' + name + '_conv3')
def forward(self, inputs):
if self.benchmodel == 1:
x1, x2 = paddle.split(
inputs,
num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
axis=1)
x2 = self._conv_pw(x2)
x2 = self._conv_dw(x2)
x2 = self._conv_linear(x2)
out = paddle.concat([x1, x2], axis=1)
else:
x1 = self._conv_dw_1(inputs)
x1 = self._conv_linear_1(x1)
x1, x2 = split(
inputs,
num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
axis=1)
x2 = self._conv_pw(x2)
x2 = self._conv_dw(x2)
x2 = self._conv_linear(x2)
out = concat([x1, x2], axis=1)
return channel_shuffle(out, 2)
x2 = self._conv_pw_2(inputs)
x2 = self._conv_dw_2(x2)
x2 = self._conv_linear_2(x2)
out = paddle.concat([x1, x2], axis=1)
class InvertedResidualDS(Layer):
def __init__(self,
in_channels,
out_channels,
stride,
act="relu",
name=None):
super(InvertedResidualDS, self).__init__()
# branch1
self._conv_dw_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
act=None,
name='stage_' + name + '_conv4')
self._conv_linear_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act,
name='stage_' + name + '_conv5')
# branch2
self._conv_pw_2 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act,
name='stage_' + name + '_conv1')
self._conv_dw_2 = ConvBNLayer(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=3,
stride=stride,
padding=1,
groups=out_channels // 2,
act=None,
name='stage_' + name + '_conv2')
self._conv_linear_2 = ConvBNLayer(
in_channels=out_channels // 2,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
groups=1,
act=act,
name='stage_' + name + '_conv3')
def forward(self, inputs):
x1 = self._conv_dw_1(inputs)
x1 = self._conv_linear_1(x1)
x2 = self._conv_pw_2(inputs)
x2 = self._conv_dw_2(x2)
x2 = self._conv_linear_2(x2)
out = concat([x1, x2], axis=1)
return channel_shuffle(out, 2)
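# The old single InvertedResidualUnit (selected by `benchmodel`) is split into two layers:
#   InvertedResidual   (stride 1): splits the input channels in half, transforms one half
#                      (1x1 -> 3x3 depthwise -> 1x1), and concatenates it with the untouched half.
#   InvertedResidualDS (stride 2): downsamples with two full branches (depthwise + 1x1, and
#                      1x1 + depthwise + 1x1) and concatenates them.
# Both end with channel_shuffle(out, 2) to mix information across the two halves.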
class ShuffleNet(nn.Layer):
def __init__(self, class_dim=1000, scale=1.0, act='relu'):
class ShuffleNet(Layer):
def __init__(self, class_dim=1000, scale=1.0, act="relu"):
super(ShuffleNet, self).__init__()
self.scale = scale
self.class_dim = class_dim
......@@ -238,58 +221,47 @@ class ShuffleNet(nn.Layer):
"] is not implemented!")
# 1. conv1
self._conv1 = ConvBNLayer(
num_channels=3,
num_filters=stage_out_channels[1],
filter_size=3,
in_channels=3,
out_channels=stage_out_channels[1],
kernel_size=3,
stride=2,
padding=1,
if_act=True,
act=act,
name='stage1_conv')
self._max_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)
# 2. bottleneck sequences
self._block_list = []
i = 1
in_c = int(32 * scale)
for idxstage in range(len(stage_repeats)):
numrepeat = stage_repeats[idxstage]
output_channel = stage_out_channels[idxstage + 2]
for i in range(numrepeat):
for stage_id, num_repeat in enumerate(stage_repeats):
for i in range(num_repeat):
if i == 0:
block = self.add_sublayer(
str(idxstage + 2) + '_' + str(i + 1),
InvertedResidualUnit(
num_channels=stage_out_channels[idxstage + 1],
num_filters=output_channel,
name=str(stage_id + 2) + '_' + str(i + 1),
sublayer=InvertedResidualDS(
in_channels=stage_out_channels[stage_id + 1],
out_channels=stage_out_channels[stage_id + 2],
stride=2,
benchmodel=2,
act=act,
name=str(idxstage + 2) + '_' + str(i + 1)))
self._block_list.append(block)
name=str(stage_id + 2) + '_' + str(i + 1)))
else:
block = self.add_sublayer(
str(idxstage + 2) + '_' + str(i + 1),
InvertedResidualUnit(
num_channels=output_channel,
num_filters=output_channel,
name=str(stage_id + 2) + '_' + str(i + 1),
sublayer=InvertedResidual(
in_channels=stage_out_channels[stage_id + 2],
out_channels=stage_out_channels[stage_id + 2],
stride=1,
benchmodel=1,
act=act,
name=str(idxstage + 2) + '_' + str(i + 1)))
self._block_list.append(block)
name=str(stage_id + 2) + '_' + str(i + 1)))
self._block_list.append(block)
# 3. last_conv
self._last_conv = ConvBNLayer(
num_channels=stage_out_channels[-2],
num_filters=stage_out_channels[-1],
filter_size=1,
in_channels=stage_out_channels[-2],
out_channels=stage_out_channels[-1],
kernel_size=1,
stride=1,
padding=0,
if_act=True,
act=act,
name='conv5')
# 4. pool
self._pool2d_avg = AdaptiveAvgPool2d(1)
self._out_c = stage_out_channels[-1]
......@@ -307,13 +279,13 @@ class ShuffleNet(nn.Layer):
y = inv(y)
y = self._last_conv(y)
y = self._pool2d_avg(y)
y = paddle.reshape(y, shape=[-1, self._out_c])
y = reshape(y, shape=[-1, self._out_c])
y = self._fc(y)
return y
def ShuffleNetV2_x0_25(**args):
model = ShuffleNetV2(scale=0.25, **args)
model = ShuffleNet(scale=0.25, **args)
return model
......@@ -343,5 +315,5 @@ def ShuffleNetV2_x2_0(**args):
def ShuffleNetV2_swish(**args):
model = ShuffleNet(scale=1.0, act='swish', **args)
model = ShuffleNet(scale=1.0, act="swish", **args)
return model