未验证 提交 8e135a4d 编写于 作者: H haoyuying 提交者: GitHub

add mobilenet series

上级 bd8e0ad0
# coding=utf-8
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
from collections import OrderedDict
import numpy as np
import cv2
from PIL import Image, ImageEnhance
from paddle import fluid
# Side length, in pixels, of the square crop that process_image produces.
DATA_DIM = 224
# Per-channel mean and standard deviation that process_image subtracts /
# divides by after scaling pixel values by 1/255. They are shaped (3, 1, 1) so
# they broadcast over the channel-first array built there.
# NOTE(review): presumably the usual ImageNet RGB statistics -- confirm against
# the training pipeline.
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
    """Resize an image so that its shorter side becomes ``target_size``.

    Both sides are multiplied by the same factor and rounded to whole pixels.

    :param img: image exposing ``size`` as (width, height) and ``resize``.
    :param target_size: desired length of the shorter side, in pixels.
    :return: the image returned by ``img.resize`` using ``Image.LANCZOS``.
    """
    width, height = img.size[0], img.size[1]
    ratio = float(target_size) / min(width, height)
    new_size = (int(round(width * ratio)), int(round(height * ratio)))
    return img.resize(new_size, Image.LANCZOS)
def crop_image(img, target_size, center):
    """Crop a ``target_size`` x ``target_size`` square out of ``img``.

    :param img: image exposing ``size`` as (width, height) and ``crop(box)``.
    :param target_size: side length of the square crop, in pixels.
    :param center: when truthy the window is centred; otherwise its top-left
        corner is drawn with ``np.random.randint`` so that the window stays
        inside the image.
    :return: whatever ``img.crop`` returns for the computed
        (left, upper, right, lower) box.
    """
    width, height = img.size
    size = target_size
    if center:
        # Floor division keeps the box coordinates integral. True division
        # produced floats (and x.5 values for odd margins), leaving the final
        # pixel position to the imaging library's rounding rules.
        w_start = (width - size) // 2
        h_start = (height - size) // 2
    else:
        w_start = np.random.randint(0, width - size + 1)
        h_start = np.random.randint(0, height - size + 1)
    return img.crop((w_start, h_start, w_start + size, h_start + size))
def process_image(img):
    """Turn one image into a normalised, channel-first float32 array.

    Steps: resize the shorter side to 256, centre-crop to ``DATA_DIM``,
    convert to RGB if the mode differs, move the channel axis first, scale by
    1/255, then subtract ``img_mean`` and divide by ``img_std``.

    :param img: image object accepted by ``resize_short`` / ``crop_image``.
    :return: the normalised numpy array.
    """
    cropped = crop_image(
        resize_short(img, target_size=256), target_size=DATA_DIM, center=True)
    if cropped.mode != 'RGB':
        cropped = cropped.convert('RGB')
    # The array built from the image is (H, W, C); the model input is (C, H, W).
    chw = np.array(cropped).astype('float32').transpose((2, 0, 1)) / 255
    chw -= img_mean
    chw /= img_std
    return chw
def test_reader(paths=None, images=None):
    """Yield preprocessed images one at a time.

    All inputs are opened/converted first (paths before in-memory images);
    only then are they run through ``process_image`` and yielded in that
    order.

    :param paths: path to images.
    :type paths: list, each element is a str
    :param images: data of images, [N, H, W, C]
    :type images: numpy.ndarray
    :raises AssertionError: if an entry of ``paths`` is not an existing file.
    """
    pending = []
    if paths:
        for img_path in paths:
            assert os.path.isfile(
                img_path), "The {} isn't a valid file path.".format(img_path)
            pending.append(Image.open(img_path))
    if images is not None:
        pending.extend(Image.fromarray(np.uint8(arr)) for arr in images)
    for pil_img in pending:
        yield process_image(pil_img)
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
__all__ = ['MobileNet']
class MobileNet(object):
    """
    MobileNet v1, see https://arxiv.org/abs/1704.04861

    Args:
        norm_type (str): normalization type, 'bn' and 'sync_bn' are supported.
            The value is stored but not consulted by the methods of this
            class.
        norm_decay (float): weight decay for normalization layer weights
        conv_group_scale (int): scaling factor for convolution groups
        conv_learning_rate (float): passed as ``learning_rate`` to the
            ParamAttr of every convolution weight
        with_extra_blocks (bool): if extra blocks should be added
        extra_block_filters (list): number of filter for each extra block;
            ``None`` selects [[256, 512], [128, 256], [128, 256], [64, 128]]
        weight_prefix_name (str): prefix prepended to the names of the
            convolution / batch-norm parameters (the final fc layer uses the
            fixed names "fc7_weights" / "fc7_offset")
        class_dim (int): number of class while classification
        yolo_v3 (bool): whether to output layers which yolo_v3 needs
    """
    __shared__ = ['norm_type', 'weight_prefix_name']

    def __init__(self,
                 norm_type='bn',
                 norm_decay=0.,
                 conv_group_scale=1,
                 conv_learning_rate=1.0,
                 with_extra_blocks=False,
                 extra_block_filters=None,
                 weight_prefix_name='',
                 class_dim=1000,
                 yolo_v3=False):
        # A fresh list is built per instance: a list literal used as the
        # default argument would be shared (and mutable) across all instances.
        if extra_block_filters is None:
            extra_block_filters = [[256, 512], [128, 256], [128, 256],
                                   [64, 128]]
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        self.conv_group_scale = conv_group_scale
        self.conv_learning_rate = conv_learning_rate
        self.with_extra_blocks = with_extra_blocks
        self.extra_block_filters = extra_block_filters
        self.prefix_name = weight_prefix_name
        self.class_dim = class_dim
        self.yolo_v3 = yolo_v3

    def _conv_norm(self,
                   input,
                   filter_size,
                   num_filters,
                   stride,
                   padding,
                   num_groups=1,
                   act='relu',
                   use_cudnn=True,
                   name=None):
        """Bias-free conv2d followed by batch_norm (activation applied there).

        ``name`` is concatenated with suffixes to build every parameter name,
        so despite the ``None`` default a string must be supplied.
        """
        parameter_attr = ParamAttr(
            learning_rate=self.conv_learning_rate,
            initializer=fluid.initializer.MSRA(),
            name=name + "_weights")
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            act=None,
            use_cudnn=use_cudnn,
            param_attr=parameter_attr,
            bias_attr=False)

        bn_name = name + "_bn"
        norm_decay = self.norm_decay
        bn_param_attr = ParamAttr(
            regularizer=L2Decay(norm_decay), name=bn_name + '_scale')
        bn_bias_attr = ParamAttr(
            regularizer=L2Decay(norm_decay), name=bn_name + '_offset')
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=bn_param_attr,
            bias_attr=bn_bias_attr,
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def depthwise_separable(self,
                            input,
                            num_filters1,
                            num_filters2,
                            num_groups,
                            stride,
                            scale,
                            name=None):
        """3x3 grouped conv ("_dw") followed by a 1x1 conv ("_sep").

        Filter and group counts are multiplied by ``scale`` and truncated to
        int. ``name`` must be a string (it is used as a name prefix).
        """
        depthwise_conv = self._conv_norm(
            input=input,
            filter_size=3,
            num_filters=int(num_filters1 * scale),
            stride=stride,
            padding=1,
            num_groups=int(num_groups * scale),
            use_cudnn=False,
            name=name + "_dw")

        pointwise_conv = self._conv_norm(
            input=depthwise_conv,
            filter_size=1,
            num_filters=int(num_filters2 * scale),
            stride=1,
            padding=0,
            name=name + "_sep")
        return pointwise_conv

    def _extra_block(self,
                     input,
                     num_filters1,
                     num_filters2,
                     num_groups,
                     stride,
                     name=None):
        """1x1 conv ("_extra1") followed by a 3x3 stride-2 conv ("_extra2").

        NOTE(review): ``stride`` is accepted but not used; the second
        convolution always uses stride 2, which is also the value every call
        in this class passes.
        """
        pointwise_conv = self._conv_norm(
            input=input,
            filter_size=1,
            num_filters=int(num_filters1),
            stride=1,
            num_groups=int(num_groups),
            padding=0,
            name=name + "_extra1")
        normal_conv = self._conv_norm(
            input=pointwise_conv,
            filter_size=3,
            num_filters=int(num_filters2),
            stride=2,
            num_groups=int(num_groups),
            padding=1,
            name=name + "_extra2")
        return normal_conv

    def __call__(self, input):
        """Build the network on ``input`` and return its outputs.

        Returns:
            * ``yolo_v3`` set: list of the three feature maps taken after
              conv3_2, conv4_2 and conv6.
            * otherwise, without extra blocks: that same list with the softmax
              classification output appended.
            * with extra blocks: tuple ``(module11, module13, module14,
              module15, module16, module17)``.
        """
        scale = self.conv_group_scale

        blocks = []
        # input 1/1
        out = self._conv_norm(
            input, 3, int(32 * scale), 2, 1, name=self.prefix_name + "conv1")
        # 1/2
        out = self.depthwise_separable(
            out, 32, 64, 32, 1, scale, name=self.prefix_name + "conv2_1")
        out = self.depthwise_separable(
            out, 64, 128, 64, 2, scale, name=self.prefix_name + "conv2_2")
        # 1/4
        out = self.depthwise_separable(
            out, 128, 128, 128, 1, scale, name=self.prefix_name + "conv3_1")
        out = self.depthwise_separable(
            out, 128, 256, 128, 2, scale, name=self.prefix_name + "conv3_2")
        # 1/8
        blocks.append(out)
        out = self.depthwise_separable(
            out, 256, 256, 256, 1, scale, name=self.prefix_name + "conv4_1")
        out = self.depthwise_separable(
            out, 256, 512, 256, 2, scale, name=self.prefix_name + "conv4_2")
        # 1/16
        blocks.append(out)
        for i in range(5):
            out = self.depthwise_separable(
                out,
                512,
                512,
                512,
                1,
                scale,
                name=self.prefix_name + "conv5_" + str(i + 1))
        module11 = out

        out = self.depthwise_separable(
            out, 512, 1024, 512, 2, scale, name=self.prefix_name + "conv5_6")
        # 1/32
        out = self.depthwise_separable(
            out, 1024, 1024, 1024, 1, scale, name=self.prefix_name + "conv6")
        module13 = out
        blocks.append(out)
        if self.yolo_v3:
            return blocks

        if not self.with_extra_blocks:
            out = fluid.layers.pool2d(
                input=out, pool_type='avg', global_pooling=True)
            out = fluid.layers.fc(
                input=out,
                size=self.class_dim,
                param_attr=ParamAttr(
                    initializer=fluid.initializer.MSRA(), name="fc7_weights"),
                bias_attr=ParamAttr(name="fc7_offset"))
            out = fluid.layers.softmax(out)
            blocks.append(out)
            return blocks

        num_filters = self.extra_block_filters
        module14 = self._extra_block(module13, num_filters[0][0],
                                     num_filters[0][1], 1, 2,
                                     self.prefix_name + "conv7_1")
        module15 = self._extra_block(module14, num_filters[1][0],
                                     num_filters[1][1], 1, 2,
                                     self.prefix_name + "conv7_2")
        module16 = self._extra_block(module15, num_filters[2][0],
                                     num_filters[2][1], 1, 2,
                                     self.prefix_name + "conv7_3")
        module17 = self._extra_block(module16, num_filters[3][0],
                                     num_filters[3][1], 1, 2,
                                     self.prefix_name + "conv7_4")
        return module11, module13, module14, module15, module16, module17
# coding=utf-8
def load_label_info(file_path):
    """Read a label file and return its lines as a list of strings.

    Lines are split on "\\n". Only the empty string produced by a trailing
    newline is discarded, so a file whose last line has no terminating newline
    keeps that last label.

    :param file_path: path of the text file, one label per line.
    :return: list of labels in file order (empty list for an empty file).
    """
    with open(file_path, 'r') as fr:
        labels = fr.read().split("\n")
    if labels and labels[-1] == "":
        labels.pop()
    return labels
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
from paddle.nn.initializer import MSRA
from paddlehub.module.module import moduleinfo
from paddlehub.module.cv_module import ImageClassifierModule
class ConvBNLayer(nn.Layer):
    """Bias-free convolution followed by batch normalisation.

    The activation named by ``act`` is handed to the BatchNorm layer.
    Parameter names are derived from ``name``, which therefore must be a
    string despite its ``None`` default. ``channels`` is accepted but unused.
    """

    def __init__(self,
                 num_channels: int,
                 filter_size: int,
                 num_filters: int,
                 stride: int,
                 padding: int,
                 channels: int = None,
                 num_groups: int = 1,
                 act: str = 'relu',
                 name: str = None):
        super(ConvBNLayer, self).__init__()

        conv_weight = ParamAttr(initializer=MSRA(), name=name + "_weights")
        self._conv = Conv2d(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=conv_weight,
            bias_attr=False)

        bn_scale = ParamAttr(name=name + "_bn_scale")
        bn_offset = ParamAttr(name=name + "_bn_offset")
        self._batch_norm = BatchNorm(
            num_filters,
            act=act,
            param_attr=bn_scale,
            bias_attr=bn_offset,
            moving_mean_name=name + "_bn_mean",
            moving_variance_name=name + "_bn_variance")

    def forward(self, inputs: paddle.Tensor):
        """Apply the convolution, then batch norm (with its activation)."""
        return self._batch_norm(self._conv(inputs))
class DepthwiseSeparable(nn.Layer):
    """A 3x3 grouped ConvBNLayer ("_dw") followed by a 1x1 ConvBNLayer ("_sep").

    Filter and group counts are multiplied by ``scale`` and truncated to int.
    ``name`` is used as a prefix for the sub-layer names and must be a string.
    """

    def __init__(self,
                 num_channels: int,
                 num_filters1: int,
                 num_filters2: int,
                 num_groups: int,
                 stride: int,
                 scale: float,
                 name: str = None):
        super(DepthwiseSeparable, self).__init__()

        dw_filters = int(num_filters1 * scale)
        self._depthwise_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=dw_filters,
            filter_size=3,
            stride=stride,
            padding=1,
            num_groups=int(num_groups * scale),
            name=name + "_dw")

        # The pointwise stage consumes exactly what the depthwise stage emits.
        self._pointwise_conv = ConvBNLayer(
            num_channels=dw_filters,
            filter_size=1,
            num_filters=int(num_filters2 * scale),
            stride=1,
            padding=0,
            name=name + "_sep")

    def forward(self, inputs: paddle.Tensor):
        """Run the depthwise stage, then the pointwise stage."""
        return self._pointwise_conv(self._depthwise_conv(inputs))
@moduleinfo(name="mobilenet_v1_imagenet_ssld",
            type="cv/classification",
            author="paddlepaddle",
            author_email="",
            summary="mobilenet_v1_imagenet_ssld is a classification model, "
            "this module is trained with Imagenet dataset.",
            version="1.1.0",
            meta=ImageClassifierModule)
class MobileNetV1(nn.Layer):
    """MobileNetV1 classifier: a stem conv, 13 depthwise-separable blocks,
    global average pooling and a linear layer.

    Args:
        class_dim: output size of the final linear layer.
        load_checkpoint: path of a checkpoint to load. When ``None`` the
            pretrained file ``mobilenet_v1_ssld_imagenet.pdparams`` inside
            ``self.directory`` is used, and downloaded with ``wget`` first if
            it is missing.

    NOTE(review): ``self.directory`` is not defined in this class; presumably
    it is supplied by the ``moduleinfo`` / ``ImageClassifierModule`` machinery
    -- confirm.
    """

    _PRETRAINED_URL = ('https://paddlehub.bj.bcebos.com/dygraph/image_classification/'
                       'mobilenet_v1_ssld_imagenet.pdparams')

    def __init__(self, class_dim: int = 1000, load_checkpoint: str = None):
        super(MobileNetV1, self).__init__()
        self.block_list = []

        self.conv1 = ConvBNLayer(num_channels=3,
                                 filter_size=3,
                                 channels=3,
                                 num_filters=int(32),
                                 stride=2,
                                 padding=1,
                                 name="conv1")

        # (sublayer name, input channels, output channels, stride). In every
        # block the depthwise filter count and group count equal the input
        # channel count, so a single number covers all three.
        block_specs = [
            ("conv2_1", 32, 64, 1),
            ("conv2_2", 64, 128, 2),
            ("conv3_1", 128, 128, 1),
            ("conv3_2", 128, 256, 2),
            ("conv4_1", 256, 256, 1),
            ("conv4_2", 256, 512, 2),
        ]
        block_specs += [("conv5_" + str(i + 1), 512, 512, 1) for i in range(5)]
        block_specs += [
            ("conv5_6", 512, 1024, 2),
            ("conv6", 1024, 1024, 1),
        ]
        for block_name, in_channels, out_channels, stride in block_specs:
            block = self.add_sublayer(block_name,
                                      sublayer=DepthwiseSeparable(num_channels=in_channels,
                                                                  num_filters1=in_channels,
                                                                  num_filters2=out_channels,
                                                                  num_groups=in_channels,
                                                                  stride=stride,
                                                                  scale=1,
                                                                  name=block_name))
            self.block_list.append(block)

        self.pool2d_avg = AdaptiveAvgPool2d(1)

        self.out = Linear(int(1024),
                          class_dim,
                          weight_attr=ParamAttr(initializer=MSRA(), name="fc7_weights"),
                          bias_attr=ParamAttr(name="fc7_offset"))

        if load_checkpoint is not None:
            # Only the first element of what paddle.load returns is used.
            model_dict = paddle.load(load_checkpoint)[0]
            self.set_dict(model_dict)
            print("load custom checkpoint success")
        else:
            checkpoint = os.path.join(self.directory, 'mobilenet_v1_ssld_imagenet.pdparams')
            if not os.path.exists(checkpoint):
                self._download_checkpoint(self._PRETRAINED_URL, checkpoint)
            model_dict = paddle.load(checkpoint)[0]
            self.set_dict(model_dict)
            print("load pretrained checkpoint success")

    @staticmethod
    def _download_checkpoint(url: str, checkpoint: str):
        """Download ``url`` to ``checkpoint`` with wget, leaving no partial file on failure."""
        # Local import so this class does not depend on a file-level import.
        import subprocess
        try:
            # An argument list (no shell) keeps paths with spaces or shell
            # metacharacters intact; check=True surfaces a failed download.
            subprocess.run(['wget', url, '-O', checkpoint], check=True)
        except (OSError, subprocess.CalledProcessError):
            # "wget -O" creates the target even when the transfer fails; a
            # leftover file would make the os.path.exists check skip the
            # download on every later run.
            if os.path.exists(checkpoint):
                os.remove(checkpoint)
            raise

    def forward(self, inputs: paddle.Tensor):
        """Return the output of the final linear layer for ``inputs``."""
        y = self.conv1(inputs)
        for block in self.block_list:
            y = block(y)
        y = self.pool2d_avg(y)
        # Flatten the pooled map to (-1, 1024), the input size of self.out.
        y = paddle.reshape(y, shape=[-1, 1024])
        y = self.out(y)
        return y
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
from paddlehub.module.module import moduleinfo
from paddlehub.module.cv_module import ImageClassifierModule
class ConvBNLayer(nn.Layer):
    """Bias-free convolution followed by batch normalisation.

    ReLU6 is applied in ``forward`` when requested. Parameter names are
    derived from ``name``, which therefore must be a string despite its
    ``None`` default.
    """

    def __init__(self,
                 num_channels: int,
                 filter_size: int,
                 num_filters: int,
                 stride: int,
                 padding: int,
                 num_groups: int = 1,
                 name: str = None):
        super(ConvBNLayer, self).__init__()

        conv_weight = ParamAttr(name=name + "_weights")
        self._conv = Conv2d(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=conv_weight,
            bias_attr=False)

        bn_scale = ParamAttr(name=name + "_bn_scale")
        bn_offset = ParamAttr(name=name + "_bn_offset")
        self._batch_norm = BatchNorm(
            num_filters,
            param_attr=bn_scale,
            bias_attr=bn_offset,
            moving_mean_name=name + "_bn_mean",
            moving_variance_name=name + "_bn_variance")

    def forward(self, inputs: paddle.Tensor, if_act: bool = True):
        """Apply conv and batch norm; add ReLU6 when ``if_act`` is true."""
        normed = self._batch_norm(self._conv(inputs))
        if not if_act:
            return normed
        return F.relu6(normed)
class InvertedResidualUnit(nn.Layer):
    """Inverted residual unit: 1x1 expand, grouped conv, 1x1 linear projection.

    The expanded width is ``round(num_in_filter * expansion_factor)``. The
    middle convolution uses as many groups as it has channels.
    """

    def __init__(self, num_channels: int, num_in_filter: int, num_filters: int, stride: int, filter_size: int,
                 padding: int, expansion_factor: int, name: str):
        super(InvertedResidualUnit, self).__init__()

        expanded = int(round(num_in_filter * expansion_factor))

        self._expand_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=expanded,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            name=name + "_expand")

        self._bottleneck_conv = ConvBNLayer(
            num_channels=expanded,
            num_filters=expanded,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            num_groups=expanded,
            name=name + "_dwise")

        self._linear_conv = ConvBNLayer(
            num_channels=expanded,
            num_filters=num_filters,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            name=name + "_linear")

    def forward(self, inputs: paddle.Tensor, ifshortcut: bool):
        """Run the three convolutions; add ``inputs`` back when ``ifshortcut``."""
        expanded = self._expand_conv(inputs, if_act=True)
        filtered = self._bottleneck_conv(expanded, if_act=True)
        # The projection is linear: no activation after it.
        projected = self._linear_conv(filtered, if_act=False)
        if not ifshortcut:
            return projected
        return paddle.elementwise_add(inputs, projected)
class InversiBlocks(nn.Layer):
    """A stack of ``n`` inverted residual units sharing expansion factor ``t``.

    The first unit maps ``in_c`` to ``c`` channels with stride ``s`` and no
    shortcut; the remaining ``n - 1`` units keep ``c`` channels, use stride 1
    and are called with the shortcut enabled.
    """

    def __init__(self, in_c: int, t: int, c: int, n: int, s: int, name: str):
        super(InversiBlocks, self).__init__()

        self._first_block = InvertedResidualUnit(
            num_channels=in_c,
            num_in_filter=in_c,
            num_filters=c,
            stride=s,
            filter_size=3,
            padding=1,
            expansion_factor=t,
            name=name + "_1")

        self._block_list = []
        for idx in range(2, n + 1):
            unit_name = name + "_" + str(idx)
            unit = InvertedResidualUnit(
                num_channels=c,
                num_in_filter=c,
                num_filters=c,
                stride=1,
                filter_size=3,
                padding=1,
                expansion_factor=t,
                name=unit_name)
            self._block_list.append(self.add_sublayer(unit_name, sublayer=unit))

    def forward(self, inputs: paddle.Tensor):
        """Apply the first unit, then every following unit with its shortcut."""
        out = self._first_block(inputs, ifshortcut=False)
        for unit in self._block_list:
            out = unit(out, ifshortcut=True)
        return out
@moduleinfo(name="mobilenet_v2_imagenet",
            type="cv/classification",
            author="paddlepaddle",
            author_email="",
            summary="mobilenet_v2_imagenet is a classification model, "
            "this module is trained with Imagenet dataset.",
            version="1.1.0",
            meta=ImageClassifierModule)
class MobileNet(nn.Layer):
    """MobileNetV2 classifier: a stem conv, seven groups of inverted residual
    blocks, a 1x1 conv to 1280 channels, global average pooling and a linear
    layer.

    Args:
        class_dim: output size of the final linear layer.
        load_checkpoint: path of a checkpoint to load. When ``None`` the
            pretrained file ``mobilenet_v2_imagenet.pdparams`` inside
            ``self.directory`` is used, and downloaded with ``wget`` first if
            it is missing.

    NOTE(review): ``self.directory`` is not defined in this class; presumably
    it is supplied by the ``moduleinfo`` / ``ImageClassifierModule`` machinery
    -- confirm.
    """

    _PRETRAINED_URL = ('https://paddlehub.bj.bcebos.com/dygraph/image_classification/'
                       'mobilenet_v2_imagenet.pdparams')

    def __init__(self, class_dim: int = 1000, load_checkpoint: str = None):
        super(MobileNet, self).__init__()

        self.class_dim = class_dim

        # One tuple per group, unpacked as (t, c, n, s) and passed to
        # InversiBlocks under those names.
        bottleneck_params_list = [(1, 16, 1, 1), (6, 24, 2, 2), (6, 32, 3, 2), (6, 64, 4, 2), (6, 96, 3, 1),
                                  (6, 160, 3, 2), (6, 320, 1, 1)]

        self.conv1 = ConvBNLayer(num_channels=3,
                                 num_filters=int(32),
                                 filter_size=3,
                                 stride=2,
                                 padding=1,
                                 name="conv1_1")

        self.block_list = []
        in_c = int(32)
        # Groups are named conv2 ... conv8.
        for i, (t, c, n, s) in enumerate(bottleneck_params_list, start=2):
            block = self.add_sublayer("conv" + str(i),
                                      sublayer=InversiBlocks(in_c=in_c, t=t, c=int(c), n=n, s=s, name="conv" + str(i)))
            self.block_list.append(block)
            in_c = int(c)

        self.out_c = 1280
        self.conv9 = ConvBNLayer(num_channels=in_c,
                                 num_filters=self.out_c,
                                 filter_size=1,
                                 stride=1,
                                 padding=0,
                                 name="conv9")

        self.pool2d_avg = AdaptiveAvgPool2d(1)

        self.out = Linear(self.out_c,
                          class_dim,
                          weight_attr=ParamAttr(name="fc10_weights"),
                          bias_attr=ParamAttr(name="fc10_offset"))

        if load_checkpoint is not None:
            # Only the first element of what paddle.load returns is used.
            model_dict = paddle.load(load_checkpoint)[0]
            self.set_dict(model_dict)
            print("load custom checkpoint success")
        else:
            checkpoint = os.path.join(self.directory, 'mobilenet_v2_imagenet.pdparams')
            if not os.path.exists(checkpoint):
                self._download_checkpoint(self._PRETRAINED_URL, checkpoint)
            model_dict = paddle.load(checkpoint)[0]
            self.set_dict(model_dict)
            print("load pretrained checkpoint success")

    @staticmethod
    def _download_checkpoint(url: str, checkpoint: str):
        """Download ``url`` to ``checkpoint`` with wget, leaving no partial file on failure."""
        # Local import so this class does not depend on a file-level import.
        import subprocess
        try:
            # An argument list (no shell) keeps paths with spaces or shell
            # metacharacters intact; check=True surfaces a failed download.
            subprocess.run(['wget', url, '-O', checkpoint], check=True)
        except (OSError, subprocess.CalledProcessError):
            # "wget -O" creates the target even when the transfer fails; a
            # leftover file would make the os.path.exists check skip the
            # download on every later run.
            if os.path.exists(checkpoint):
                os.remove(checkpoint)
            raise

    def forward(self, inputs: paddle.Tensor):
        """Return the output of the final linear layer for ``inputs``."""
        y = self.conv1(inputs, if_act=True)
        for block in self.block_list:
            y = block(y)
        y = self.conv9(y, if_act=True)
        y = self.pool2d_avg(y)
        # Flatten the pooled map to (-1, out_c), the input size of self.out.
        y = paddle.reshape(y, shape=[-1, self.out_c])
        y = self.out(y)
        return y
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
from paddle.nn.initializer import MSRA
from paddlehub.module.module import moduleinfo
from paddlehub.module.cv_module import ImageClassifierModule
def channel_shuffle(x: paddle.Tensor, groups: int):
    """Shuffle input channels.

    The channel axis (axis 1) is viewed as (groups, channels_per_group), those
    two axes are swapped, and the result is flattened back to the original
    four-axis shape.
    """
    n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
    per_group = c // groups

    # Split the channel axis, swap the two resulting axes, merge them again.
    grouped = paddle.reshape(x=x, shape=[n, groups, per_group, h, w])
    swapped = paddle.transpose(x=grouped, perm=[0, 2, 1, 3, 4])
    return paddle.reshape(x=swapped, shape=[n, c, h, w])
class ConvBNLayer(nn.Layer):
    """Bias-free convolution, batch norm and an optional relu/swish activation.

    Whether an activation is applied is fixed at construction time through
    ``if_act``; the ``if_act`` argument of ``forward`` is accepted but not
    read. ``channels`` is accepted but unused. Parameter names are derived
    from ``name``, which therefore must be a string despite its ``None``
    default.
    """

    def __init__(self,
                 num_channels: int,
                 filter_size: int,
                 num_filters: int,
                 stride: int,
                 padding: int,
                 channels: int = None,
                 num_groups: int = 1,
                 if_act: bool = True,
                 act: str = 'relu',
                 name: str = None):
        super(ConvBNLayer, self).__init__()
        self._if_act = if_act
        assert act in ['relu', 'swish'], \
            "supported act are {} but your act is {}".format(
                ['relu', 'swish'], act)
        self._act = act

        conv_weight = ParamAttr(initializer=MSRA(), name=name + "_weights")
        self._conv = Conv2d(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=conv_weight,
            bias_attr=False)

        bn_scale = ParamAttr(name=name + "_bn_scale")
        bn_offset = ParamAttr(name=name + "_bn_offset")
        self._batch_norm = BatchNorm(
            num_filters,
            param_attr=bn_scale,
            bias_attr=bn_offset,
            moving_mean_name=name + "_bn_mean",
            moving_variance_name=name + "_bn_variance")

    def forward(self, inputs: paddle.Tensor, if_act: bool = True):
        """Apply conv and batch norm, then the configured activation if enabled."""
        normed = self._batch_norm(self._conv(inputs))
        if not self._if_act:
            return normed
        if self._act == 'relu':
            return F.relu(normed)
        return F.swish(normed)
class InvertedResidualUnit(nn.Layer):
    """ShuffleNet unit with two layouts selected by ``benchmodel``.

    * ``benchmodel == 1``: the input is split in two halves along axis 1; one
      half passes through unchanged, the other goes through a 1x1 conv, a 3x3
      grouped conv and a 1x1 conv.
    * any other value: the whole input feeds two branches (3x3 grouped conv
      + 1x1 conv, and 1x1 conv + 3x3 grouped conv + 1x1 conv).

    In both cases the two results are concatenated on axis 1 and passed to
    ``channel_shuffle`` with 2 groups. Each branch emits ``num_filters // 2``
    channels. ``name`` is used as a name prefix and must be a string.
    """

    def __init__(self,
                 num_channels: int,
                 num_filters: int,
                 stride: int,
                 benchmodel: int,
                 act: str = 'relu',
                 name: str = None):
        super(InvertedResidualUnit, self).__init__()
        assert stride in [1, 2], \
            "supported stride are {} but your stride is {}".format([1, 2], stride)
        self.benchmodel = benchmodel
        half_out = num_filters // 2
        prefix = 'stage_' + name

        def pointwise(in_ch, suffix):
            # 1x1 conv to half_out channels, followed by the activation.
            return ConvBNLayer(num_channels=in_ch,
                               num_filters=half_out,
                               filter_size=1,
                               stride=1,
                               padding=0,
                               num_groups=1,
                               if_act=True,
                               act=act,
                               name=prefix + suffix)

        def depthwise(ch, suffix):
            # 3x3 conv with one group per channel, no activation.
            return ConvBNLayer(num_channels=ch,
                               num_filters=ch,
                               filter_size=3,
                               stride=stride,
                               padding=1,
                               num_groups=ch,
                               if_act=False,
                               act=act,
                               name=prefix + suffix)

        if benchmodel == 1:
            self._conv_pw = pointwise(num_channels // 2, '_conv1')
            self._conv_dw = depthwise(half_out, '_conv2')
            self._conv_linear = pointwise(half_out, '_conv3')
        else:
            # branch1
            self._conv_dw_1 = depthwise(num_channels, '_conv4')
            self._conv_linear_1 = pointwise(num_channels, '_conv5')
            # branch2
            self._conv_pw_2 = pointwise(num_channels, '_conv1')
            self._conv_dw_2 = depthwise(half_out, '_conv2')
            self._conv_linear_2 = pointwise(half_out, '_conv3')

    def forward(self, inputs: paddle.Tensor):
        """Run the layout chosen at construction and shuffle the channels."""
        if self.benchmodel == 1:
            half = inputs.shape[1] // 2
            x1, x2 = paddle.split(inputs, num_or_sections=[half, half], axis=1)
            x2 = self._conv_linear(self._conv_dw(self._conv_pw(x2)))
        else:
            x1 = self._conv_linear_1(self._conv_dw_1(inputs))
            x2 = self._conv_linear_2(self._conv_dw_2(self._conv_pw_2(inputs)))
        merged = paddle.concat([x1, x2], axis=1)
        return channel_shuffle(merged, 2)
@moduleinfo(name="shufflenet_v2_imagenet",
            type="cv/classification",
            author="paddlepaddle",
            author_email="",
            summary="shufflenet_v2_imagenet is a classification model, "
            "this module is trained with Imagenet dataset.",
            version="1.1.0",
            meta=ImageClassifierModule)
class ShuffleNet(nn.Layer):
    """ShuffleNet classifier: a stem conv and max pool, three stages of
    InvertedResidualUnit blocks (4, 8 and 4 units), a 1x1 conv, global average
    pooling and a linear layer.

    Args:
        class_dim: output size of the final linear layer.
        load_checkpoint: path of a checkpoint to load. When ``None`` the
            pretrained file ``shufflenet_v2_imagenet.pdparams`` inside
            ``self.directory`` is used, and downloaded with ``wget`` first if
            it is missing.

    NOTE(review): ``self.directory`` is not defined in this class; presumably
    it is supplied by the ``moduleinfo`` / ``ImageClassifierModule`` machinery
    -- confirm.
    """

    _PRETRAINED_URL = ('https://paddlehub.bj.bcebos.com/dygraph/image_classification/'
                       'shufflenet_v2_imagenet.pdparams')

    def __init__(self, class_dim: int = 1000, load_checkpoint: str = None):
        super(ShuffleNet, self).__init__()
        self.scale = 1
        self.class_dim = class_dim
        stage_repeats = [4, 8, 4]
        # Index 0 is a placeholder; index 1 is the stem width, 2-4 the stage
        # widths and the last entry the width of the final 1x1 conv.
        stage_out_channels = [-1, 24, 116, 232, 464, 1024]

        # 1. conv1
        self._conv1 = ConvBNLayer(num_channels=3,
                                  num_filters=stage_out_channels[1],
                                  filter_size=3,
                                  stride=2,
                                  padding=1,
                                  if_act=True,
                                  act='relu',
                                  name='stage1_conv')
        self._max_pool = MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 2. bottleneck sequences
        self._block_list = []
        for idxstage, numrepeat in enumerate(stage_repeats):
            output_channel = stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                unit_name = str(idxstage + 2) + '_' + str(i + 1)
                # The first unit of a stage changes the channel count with
                # stride 2 (benchmodel 2); the others keep the shape
                # (stride 1, benchmodel 1).
                is_first = i == 0
                block = self.add_sublayer(
                    unit_name,
                    InvertedResidualUnit(
                        num_channels=stage_out_channels[idxstage + 1] if is_first else output_channel,
                        num_filters=output_channel,
                        stride=2 if is_first else 1,
                        benchmodel=2 if is_first else 1,
                        act='relu',
                        name=unit_name))
                self._block_list.append(block)

        # 3. last_conv
        self._last_conv = ConvBNLayer(num_channels=stage_out_channels[-2],
                                      num_filters=stage_out_channels[-1],
                                      filter_size=1,
                                      stride=1,
                                      padding=0,
                                      if_act=True,
                                      act='relu',
                                      name='conv5')

        # 4. pool
        self._pool2d_avg = AdaptiveAvgPool2d(1)
        self._out_c = stage_out_channels[-1]

        # 5. fc
        self._fc = Linear(stage_out_channels[-1],
                          class_dim,
                          weight_attr=ParamAttr(name='fc6_weights'),
                          bias_attr=ParamAttr(name='fc6_offset'))

        if load_checkpoint is not None:
            # Only the first element of what paddle.load returns is used.
            model_dict = paddle.load(load_checkpoint)[0]
            self.set_dict(model_dict)
            print("load custom checkpoint success")
        else:
            checkpoint = os.path.join(self.directory, 'shufflenet_v2_imagenet.pdparams')
            if not os.path.exists(checkpoint):
                self._download_checkpoint(self._PRETRAINED_URL, checkpoint)
            model_dict = paddle.load(checkpoint)[0]
            self.set_dict(model_dict)
            print("load pretrained checkpoint success")

    @staticmethod
    def _download_checkpoint(url: str, checkpoint: str):
        """Download ``url`` to ``checkpoint`` with wget, leaving no partial file on failure."""
        # Local import so this class does not depend on a file-level import.
        import subprocess
        try:
            # An argument list (no shell) keeps paths with spaces or shell
            # metacharacters intact; check=True surfaces a failed download.
            subprocess.run(['wget', url, '-O', checkpoint], check=True)
        except (OSError, subprocess.CalledProcessError):
            # "wget -O" creates the target even when the transfer fails; a
            # leftover file would make the os.path.exists check skip the
            # download on every later run.
            if os.path.exists(checkpoint):
                os.remove(checkpoint)
            raise

    def forward(self, inputs: paddle.Tensor):
        """Return the output of the final linear layer for ``inputs``."""
        y = self._conv1(inputs)
        y = self._max_pool(y)
        for inv in self._block_list:
            y = inv(y)
        y = self._last_conv(y)
        y = self._pool2d_avg(y)
        # Flatten the pooled map to (-1, _out_c), the input size of self._fc.
        y = paddle.reshape(y, shape=[-1, self._out_c])
        y = self._fc(y)
        return y
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册