Commit 29dd1120 authored by LielinJiang

add mobilenetv1

Parent 66062cfe
@@ -77,6 +77,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch
## References
- ResNet: [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385), Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
- MobileNetV1: [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861), Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
- MobileNetV2: [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/pdf/1801.04381v4.pdf), Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
- VGG: [Very Deep Convolutional Networks for Large-scale Image Recognition](https://arxiv.org/pdf/1409.1556), Karen Simonyan, Andrew Zisserman
from .resnet import *
from .mobilenet import *
from .mobilenetv1 import *
from .mobilenetv2 import *
from .vgg import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import sys
import math
import numpy as np
import argparse
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
from model import Model
from .download import get_weights_path
__all__ = [
'MobileNetV1', 'mobilnetv1_x0_25', 'mobilnetv1_x0_5', 'mobilnetv1_x0_75',
'mobilnetv1_x1_0', 'mobilnetv1_x1_25', 'mobilnetv1_x1_5',
'mobilnetv1_x1_75', 'mobilnetv1_x2_0'
]
model_urls = {}
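

# Convolution + batch norm building block. The convolution carries no bias and no
# activation of its own; BatchNorm applies the activation (ReLU by default).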
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
moving_mean_name=self.full_name() + "_bn" + '_mean',
moving_variance_name=self.full_name() + "_bn" + '_variance')
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
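

# Depthwise separable convolution: a 3x3 depthwise ConvBNLayer (groups equal to the
# input channel count) followed by a 1x1 pointwise ConvBNLayer. cuDNN is disabled
# for the depthwise step, which typically runs faster with Paddle's native
# grouped-convolution kernel.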
class DepthwiseSeparable(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
name=None):
super(DepthwiseSeparable, self).__init__()
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=3,
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
use_cudnn=False)
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
y = self._pointwise_conv(y)
return y
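

# MobileNetV1: a strided 3x3 stem convolution, 13 depthwise separable blocks,
# global average pooling, and a softmax classifier. `scale` is the width
# multiplier from the paper and thins or widens every layer uniformly.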
class MobileNetV1(Model):
def __init__(self, scale=1.0, class_dim=1000):
super(MobileNetV1, self).__init__()
self.scale = scale
self.dwsl = []
self.conv1 = ConvBNLayer(
num_channels=3,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
dws21 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale),
name="conv2_1")
self.dwsl.append(dws21)
dws22 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=2,
scale=scale),
name="conv2_2")
self.dwsl.append(dws22)
dws31 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale),
name="conv3_1")
self.dwsl.append(dws31)
dws32 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=2,
scale=scale),
name="conv3_2")
self.dwsl.append(dws32)
dws41 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale),
name="conv4_1")
self.dwsl.append(dws41)
dws42 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=2,
scale=scale),
name="conv4_2")
self.dwsl.append(dws42)
for i in range(5):
tmp = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
scale=scale),
name="conv5_" + str(i + 1))
self.dwsl.append(tmp)
dws56 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=2,
scale=scale),
name="conv5_6")
self.dwsl.append(dws56)
dws6 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=1,
scale=scale),
name="conv6")
self.dwsl.append(dws6)
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
self.out = Linear(
int(1024 * scale),
class_dim,
act='softmax',
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
def forward(self, inputs):
y = self.conv1(inputs)
for dws in self.dwsl:
y = dws(y)
y = self.pool2d_avg(y)
        # Flatten to the pooled channel count, which depends on the width multiplier.
        y = fluid.layers.reshape(y, shape=[-1, int(1024 * self.scale)])
y = self.out(y)
return y
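

# Factory helper: builds the network and, when pretrained=True, downloads and loads
# weights registered in model_urls. model_urls is still empty in this commit, so
# pretrained weights are not available yet.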
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV1(**kwargs)
if pretrained:
        assert arch in model_urls, "{} model does not have pretrained weights yet, you should set pretrained=False".format(
            arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
        assert weight_path.endswith(
            '.pdparams'), "weight file suffix must be .pdparams"
model.load(weight_path[:-9])
return model
def mobilnetv1_x1_0(pretrained=False):
model = _mobilenet('mobilenetv1_1.0', pretrained, scale=1.0)
return model
def mobilnetv1_x0_25(pretrained=False):
model = _mobilenet('mobilenetv1_0.25', pretrained, scale=0.25)
return model
def mobilnetv1_x0_5(pretrained=False):
model = _mobilenet('mobilenetv1_0.5', pretrained, scale=0.5)
return model
def mobilnetv1_x0_75(pretrained=False):
model = _mobilenet('mobilenetv1_0.75', pretrained, scale=0.75)
return model
def mobilnetv1_x1_25(pretrained=False):
model = _mobilenet('mobilenetv1_1.25', pretrained, scale=1.25)
return model
def mobilnetv1_x1_5(pretrained=False):
model = _mobilenet('mobilenetv1_1.5', pretrained, scale=1.5)
return model
def mobilnetv1_x1_75(pretrained=False):
model = _mobilenet('mobilenetv1_1.75', pretrained, scale=1.75)
return model
def mobilnetv1_x2_0(pretrained=False):
model = _mobilenet('mobilenetv1_2.0', pretrained, scale=2.0)
return model
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import math
import sys
import numpy as np
import argparse
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
from model import Model
from .download import get_weights_path
__all__ = [
'MobileNetV2', 'mobilnetv2_x0_25', 'mobilnetv2_x0_5', 'mobilnetv2_x0_75',
'mobilnetv2_x1_0', 'mobilnetv2_x1_25', 'mobilnetv2_x1_5',
'mobilnetv2_x1_75', 'mobilnetv2_x2_0'
]
model_urls = {}
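

# Conv + batch norm building block for MobileNetV2. Unlike the V1 variant above,
# the activation is not fused into BatchNorm; forward() applies ReLU6 only when
# if_act is True, so the final linear projection of a bottleneck can skip it.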
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
use_cudnn=True):
super(ConvBNLayer, self).__init__()
tmp_param = ParamAttr(name=self.full_name() + "_weights")
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=tmp_param,
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
moving_mean_name=self.full_name() + "_bn" + '_mean',
moving_variance_name=self.full_name() + "_bn" + '_variance')
def forward(self, inputs, if_act=True):
y = self._conv(inputs)
y = self._batch_norm(y)
if if_act:
y = fluid.layers.relu6(y)
return y
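

# Inverted residual unit: 1x1 expansion (ReLU6) -> 3x3 depthwise (ReLU6) ->
# 1x1 linear projection (no activation). The identity shortcut is added only when
# the caller passes ifshortcut=True, i.e. stride 1 and matching channel counts.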
class InvertedResidualUnit(fluid.dygraph.Layer):
def __init__(
self,
num_channels,
num_in_filter,
num_filters,
stride,
filter_size,
padding,
expansion_factor, ):
super(InvertedResidualUnit, self).__init__()
num_expfilter = int(round(num_in_filter * expansion_factor))
self._expand_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
self._bottleneck_conv = ConvBNLayer(
num_channels=num_expfilter,
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding=padding,
num_groups=num_expfilter,
use_cudnn=False)
self._linear_conv = ConvBNLayer(
num_channels=num_expfilter,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
def forward(self, inputs, ifshortcut):
y = self._expand_conv(inputs, if_act=True)
y = self._bottleneck_conv(y, if_act=True)
y = self._linear_conv(y, if_act=False)
if ifshortcut:
y = fluid.layers.elementwise_add(inputs, y)
return y
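

# A stage of n inverted residual units: the first uses stride s without a shortcut
# (the channel count changes there); the remaining n - 1 use stride 1 with the
# shortcut enabled.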
class InvresiBlocks(fluid.dygraph.Layer):
def __init__(self, in_c, t, c, n, s):
super(InvresiBlocks, self).__init__()
self._first_block = InvertedResidualUnit(
num_channels=in_c,
num_in_filter=in_c,
num_filters=c,
stride=s,
filter_size=3,
padding=1,
expansion_factor=t)
self._inv_blocks = []
for i in range(1, n):
tmp = self.add_sublayer(
sublayer=InvertedResidualUnit(
num_channels=c,
num_in_filter=c,
num_filters=c,
stride=1,
filter_size=3,
padding=1,
expansion_factor=t),
name=self.full_name() + "_" + str(i + 1))
self._inv_blocks.append(tmp)
def forward(self, inputs):
y = self._first_block(inputs, ifshortcut=False)
for inv_block in self._inv_blocks:
y = inv_block(y, ifshortcut=True)
return y
class MobileNetV2(Model):
def __init__(self, class_dim=1000, scale=1.0):
super(MobileNetV2, self).__init__()
self.scale = scale
self.class_dim = class_dim
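        # Each tuple is (expansion factor t, output channels c, repeat count n,
        # stride s of the first unit), following Table 2 of the MobileNetV2 paper.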
bottleneck_params_list = [
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 3, 2),
(6, 64, 4, 2),
(6, 96, 3, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
]
#1. conv1
self._conv1 = ConvBNLayer(
num_channels=3,
num_filters=int(32 * scale),
filter_size=3,
stride=2,
padding=1)
#2. bottleneck sequences
self._invl = []
i = 1
in_c = int(32 * scale)
for layer_setting in bottleneck_params_list:
t, c, n, s = layer_setting
i += 1
tmp = self.add_sublayer(
sublayer=InvresiBlocks(
in_c=in_c, t=t, c=int(c * scale), n=n, s=s),
name='conv' + str(i))
self._invl.append(tmp)
in_c = int(c * scale)
#3. last_conv
self._out_c = int(1280 * scale) if scale > 1.0 else 1280
self._conv9 = ConvBNLayer(
num_channels=in_c,
num_filters=self._out_c,
filter_size=1,
stride=1,
padding=0)
#4. pool
self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
#5. fc
tmp_param = ParamAttr(name=self.full_name() + "fc10_weights")
self._fc = Linear(
self._out_c,
class_dim,
act='softmax',
param_attr=tmp_param,
bias_attr=ParamAttr(name="fc10_offset"))
def forward(self, inputs):
y = self._conv1(inputs, if_act=True)
for inv in self._invl:
y = inv(y)
y = self._conv9(y, if_act=True)
y = self._pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self._out_c])
y = self._fc(y)
return y
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV2(**kwargs)
if pretrained:
        assert arch in model_urls, "{} model does not have pretrained weights yet, you should set pretrained=False".format(
            arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
        assert weight_path.endswith(
            '.pdparams'), "weight file suffix must be .pdparams"
model.load(weight_path[:-9])
return model
def mobilnetv2_x1_0(pretrained=False):
model = _mobilenet('mobilenetv2_1.0', pretrained, scale=1.0)
return model
def mobilnetv2_x0_25(pretrained=False):
model = _mobilenet('mobilenetv2_0.25', pretrained, scale=0.25)
return model
def mobilnetv2_x0_5(pretrained=False):
model = _mobilenet('mobilenetv2_0.5', pretrained, scale=0.5)
return model
def mobilnetv2_x0_75(pretrained=False):
model = _mobilenet('mobilenetv2_0.75', pretrained, scale=0.75)
return model
def mobilnetv2_x1_25(pretrained=False):
model = _mobilenet('mobilenetv2_1.25', pretrained, scale=1.25)
return model
def mobilnetv2_x1_5(pretrained=False):
model = _mobilenet('mobilenetv2_1.5', pretrained, scale=1.5)
return model
def mobilnetv2_x1_75(pretrained=False):
model = _mobilenet('mobilenetv2_1.75', pretrained, scale=1.75)
return model
def mobilnetv2_x2_0(pretrained=False):
model = _mobilenet('mobilenetv2_2.0', pretrained, scale=2.0)
return model
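
A minimal usage sketch for the new entry points, assuming dygraph (eager) execution, that the hapi `Model` base class can be called like a regular `fluid.dygraph.Layer`, and that these modules are importable from a package named `models` (the package name here is illustrative):

import numpy as np
import paddle.fluid as fluid

# 'models' is an assumed package name; adjust to wherever mobilenetv1/v2 live.
from models import mobilnetv1_x1_0, mobilnetv2_x1_0

with fluid.dygraph.guard():
    # Dummy ImageNet-sized input; batch size and values are illustrative only.
    x = fluid.dygraph.to_variable(
        np.random.rand(1, 3, 224, 224).astype('float32'))

    # pretrained=False because model_urls is still empty in this commit.
    v1 = mobilnetv1_x1_0(pretrained=False)
    v2 = mobilnetv2_x1_0(pretrained=False)

    print(v1(x).shape)  # expected: [1, 1000] softmax probabilities
    print(v2(x).shape)  # expected: [1, 1000]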