未验证 提交 f2e17b25 编写于 作者: L LielinJiang 提交者: GitHub

Merge pull request #25 from LielinJiang/more-cls-models

add some image classification model
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import sys import sys
import cv2 import cv2
...@@ -6,11 +20,11 @@ from paddle.fluid.io import Dataset ...@@ -6,11 +20,11 @@ from paddle.fluid.io import Dataset
def has_valid_extension(filename, extensions): def has_valid_extension(filename, extensions):
"""Checks if a file is an allowed extension. """Checks if a file is a vilid extension.
Args: Args:
filename (string): path to a file filename (str): path to a file
extensions (tuple of strings): extensions to consider (lowercase) extensions (tuple of str): extensions to consider (lowercase)
Returns: Returns:
bool: True if the filename ends with one of given extensions bool: True if the filename ends with one of given extensions
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
```bash ```bash
python -u main.py --arch resnet50 /path/to/imagenet -d python -u main.py --arch resnet50 /path/to/imagenet -d
``` ```
-d 是使用动态模式训练,默认为静态图模式。
### 多卡训练 ### 多卡训练
执行如下命令进行训练 执行如下命令进行训练
...@@ -64,11 +65,28 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch ...@@ -64,11 +65,28 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch main.py --arch
* **output-dir**: 模型文件保存的文件夹,默认值:'output' * **output-dir**: 模型文件保存的文件夹,默认值:'output'
* **num-workers**: dataloader的进程数,默认值:4 * **num-workers**: dataloader的进程数,默认值:4
* **resume**: 恢复训练的模型路径,默认值:None * **resume**: 恢复训练的模型路径,默认值:None
* **eval-only**: 仅仅进行预测,默认值:False * **eval-only**: 是否仅仅进行预测
* **lr-scheduler**: 学习率衰减策略,默认值:piecewise
* **milestones**: piecewise学习率衰减策略的边界,默认值:[30, 60, 80]
* **weight-decay**: 模型权重正则化系数,默认值:1e-4
* **momentum**: SGD优化器的动量,默认值:0.9
## 模型 ## 模型
| 模型 | top1 acc | top5 acc | | 模型 | top1 acc | top5 acc |
| --- | --- | --- | | --- | --- | --- |
| ResNet50 | 76.28 | 93.04 | | [ResNet50](https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams) | 76.28 | 93.04 |
| [vgg16](https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams) | 71.84 | 90.71 |
| [mobilenet_v1](https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams) | 71.25 | 89.92 |
| [mobilenet_v2](https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams) | 72.27 | 90.66 |
上述模型的复现参数请参考scripts下的脚本。
## 参考文献
- ResNet: [Deep Residual Learning for Image Recognitio](https://arxiv.org/abs/1512.03385), Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
- MobileNetV1: [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861), Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, Tobias Weyand, Marco Andreetto, Hartwig Adam
- MobileNetV2: [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/pdf/1801.04381v4.pdf), Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen
- VGG: [Very Deep Convolutional Networks for Large-scale Image Recognition](https://arxiv.org/pdf/1409.1556), Karen Simonyan, Andrew Zisserman
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import cv2 import cv2
import math import math
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -37,23 +37,36 @@ from paddle.fluid.io import BatchSampler, DataLoader ...@@ -37,23 +37,36 @@ from paddle.fluid.io import BatchSampler, DataLoader
def make_optimizer(step_per_epoch, parameter_list=None): def make_optimizer(step_per_epoch, parameter_list=None):
base_lr = FLAGS.lr base_lr = FLAGS.lr
momentum = 0.9 lr_scheduler = FLAGS.lr_scheduler
weight_decay = 1e-4 momentum = FLAGS.momentum
weight_decay = FLAGS.weight_decay
if lr_scheduler == 'piecewise':
milestones = FLAGS.milestones
boundaries = [step_per_epoch * e for e in milestones]
values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values)
elif lr_scheduler == 'cosine':
learning_rate = fluid.layers.cosine_decay(base_lr, step_per_epoch,
FLAGS.epoch)
else:
raise ValueError(
"Expected lr_scheduler in ['piecewise', 'cosine'], but got {}".
format(lr_scheduler))
boundaries = [step_per_epoch * e for e in [30, 60, 80]]
values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values)
learning_rate = fluid.layers.linear_lr_warmup( learning_rate = fluid.layers.linear_lr_warmup(
learning_rate=learning_rate, learning_rate=learning_rate,
warmup_steps=5 * step_per_epoch, warmup_steps=5 * step_per_epoch,
start_lr=0., start_lr=0.,
end_lr=base_lr) end_lr=base_lr)
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate, learning_rate=learning_rate,
momentum=momentum, momentum=momentum,
regularization=fluid.regularizer.L2Decay(weight_decay), regularization=fluid.regularizer.L2Decay(weight_decay),
parameter_list=parameter_list) parameter_list=parameter_list)
return optimizer return optimizer
...@@ -138,6 +151,20 @@ if __name__ == '__main__': ...@@ -138,6 +151,20 @@ if __name__ == '__main__':
help="checkpoint path to resume") help="checkpoint path to resume")
parser.add_argument( parser.add_argument(
"--eval-only", action='store_true', help="enable dygraph mode") "--eval-only", action='store_true', help="enable dygraph mode")
parser.add_argument(
"--lr-scheduler",
default='piecewise',
type=str,
help="learning rate scheduler")
parser.add_argument(
"--milestones",
nargs='+',
type=int,
default=[30, 60, 80],
help="piecewise decay milestones")
parser.add_argument(
"--weight-decay", default=1e-4, type=float, help="weight decay")
parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path" assert FLAGS.data, "error: must provide data path"
main() main()
...@@ -42,6 +42,14 @@ __all__ = ['Model', 'Loss', 'CrossEntropy', 'Input', 'set_device'] ...@@ -42,6 +42,14 @@ __all__ = ['Model', 'Loss', 'CrossEntropy', 'Input', 'set_device']
def set_device(device): def set_device(device):
"""
Args:
device (str): specify device type, 'cpu' or 'gpu'.
Returns:
fluid.CUDAPlace or fluid.CPUPlace: Created GPU or CPU place.
"""
assert isinstance(device, six.string_types) and device.lower() in ['cpu', 'gpu'], \ assert isinstance(device, six.string_types) and device.lower() in ['cpu', 'gpu'], \
"Expected device in ['cpu', 'gpu'], but got {}".format(device) "Expected device in ['cpu', 'gpu'], but got {}".format(device)
...@@ -1082,7 +1090,11 @@ class Model(fluid.dygraph.Layer): ...@@ -1082,7 +1090,11 @@ class Model(fluid.dygraph.Layer):
return eval_result return eval_result
def predict(self, test_data, batch_size=1, num_workers=0, stack_outputs=True): def predict(self,
test_data,
batch_size=1,
num_workers=0,
stack_outputs=True):
""" """
FIXME: add more comments and usage FIXME: add more comments and usage
Args: Args:
......
...@@ -13,13 +13,22 @@ ...@@ -13,13 +13,22 @@
#limitations under the License. #limitations under the License.
from . import resnet from . import resnet
from . import vgg
from . import mobilenetv1
from . import mobilenetv2
from . import darknet from . import darknet
from . import yolov3 from . import yolov3
from .resnet import * from .resnet import *
from .mobilenetv1 import *
from .mobilenetv2 import *
from .vgg import *
from .darknet import * from .darknet import *
from .yolov3 import * from .yolov3 import *
__all__ = resnet.__all__ \ __all__ = resnet.__all__ \
+ vgg.__all__ \
+ mobilenetv1.__all__ \
+ mobilenetv2.__all__ \
+ darknet.__all__ \ + darknet.__all__ \
+ yolov3.__all__ + yolov3.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from model import Model
from .download import get_weights_path
__all__ = ['MobileNetV1', 'mobilenet_v1']
model_urls = {
'mobilenetv1_1.0':
('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v1_x1.0.pdparams',
'bf0d25cb0bed1114d9dac9384ce2b4a6')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
moving_mean_name=self.full_name() + "_bn" + '_mean',
moving_variance_name=self.full_name() + "_bn" + '_variance')
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class DepthwiseSeparable(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
name=None):
super(DepthwiseSeparable, self).__init__()
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=3,
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
use_cudnn=False)
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
y = self._pointwise_conv(y)
return y
class MobileNetV1(Model):
"""MobileNetV1 model from
`"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" <https://arxiv.org/abs/1704.04861>`_.
Args:
scale (float): scale of channels in each layer. Default: 1.0.
class_dim (int): output dim of last fc layer. Default: 1000.
"""
def __init__(self, scale=1.0, class_dim=1000):
super(MobileNetV1, self).__init__()
self.scale = scale
self.dwsl = []
self.conv1 = ConvBNLayer(
num_channels=3,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
dws21 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale),
name="conv2_1")
self.dwsl.append(dws21)
dws22 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=2,
scale=scale),
name="conv2_2")
self.dwsl.append(dws22)
dws31 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale),
name="conv3_1")
self.dwsl.append(dws31)
dws32 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=2,
scale=scale),
name="conv3_2")
self.dwsl.append(dws32)
dws41 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale),
name="conv4_1")
self.dwsl.append(dws41)
dws42 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=2,
scale=scale),
name="conv4_2")
self.dwsl.append(dws42)
for i in range(5):
tmp = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
scale=scale),
name="conv5_" + str(i + 1))
self.dwsl.append(tmp)
dws56 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=2,
scale=scale),
name="conv5_6")
self.dwsl.append(dws56)
dws6 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=1,
scale=scale),
name="conv6")
self.dwsl.append(dws6)
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
self.out = Linear(
int(1024 * scale),
class_dim,
act='softmax',
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
def forward(self, inputs):
y = self.conv1(inputs)
for dws in self.dwsl:
y = dws(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, 1024])
y = self.out(y)
return y
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV1(**kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def mobilenet_v1(pretrained=False, scale=1.0):
model = _mobilenet('mobilenetv1_' + str(scale), pretrained, scale=scale)
return model
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from model import Model
from .download import get_weights_path
__all__ = ['MobileNetV2', 'mobilenet_v2']
model_urls = {
'mobilenetv2_1.0':
('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams',
'8ff74f291f72533f2a7956a4efff9d88')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
use_cudnn=True):
super(ConvBNLayer, self).__init__()
tmp_param = ParamAttr(name=self.full_name() + "_weights")
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=tmp_param,
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
moving_mean_name=self.full_name() + "_bn" + '_mean',
moving_variance_name=self.full_name() + "_bn" + '_variance')
def forward(self, inputs, if_act=True):
y = self._conv(inputs)
y = self._batch_norm(y)
if if_act:
y = fluid.layers.relu6(y)
return y
class InvertedResidualUnit(fluid.dygraph.Layer):
def __init__(
self,
num_channels,
num_in_filter,
num_filters,
stride,
filter_size,
padding,
expansion_factor, ):
super(InvertedResidualUnit, self).__init__()
num_expfilter = int(round(num_in_filter * expansion_factor))
self._expand_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
self._bottleneck_conv = ConvBNLayer(
num_channels=num_expfilter,
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding=padding,
num_groups=num_expfilter,
use_cudnn=False)
self._linear_conv = ConvBNLayer(
num_channels=num_expfilter,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
num_groups=1)
def forward(self, inputs, ifshortcut):
y = self._expand_conv(inputs, if_act=True)
y = self._bottleneck_conv(y, if_act=True)
y = self._linear_conv(y, if_act=False)
if ifshortcut:
y = fluid.layers.elementwise_add(inputs, y)
return y
class InvresiBlocks(fluid.dygraph.Layer):
def __init__(self, in_c, t, c, n, s):
super(InvresiBlocks, self).__init__()
self._first_block = InvertedResidualUnit(
num_channels=in_c,
num_in_filter=in_c,
num_filters=c,
stride=s,
filter_size=3,
padding=1,
expansion_factor=t)
self._inv_blocks = []
for i in range(1, n):
tmp = self.add_sublayer(
sublayer=InvertedResidualUnit(
num_channels=c,
num_in_filter=c,
num_filters=c,
stride=1,
filter_size=3,
padding=1,
expansion_factor=t),
name=self.full_name() + "_" + str(i + 1))
self._inv_blocks.append(tmp)
def forward(self, inputs):
y = self._first_block(inputs, ifshortcut=False)
for inv_block in self._inv_blocks:
y = inv_block(y, ifshortcut=True)
return y
class MobileNetV2(Model):
"""MobileNetV2 model from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
scale (float): scale of channels in each layer. Default: 1.0.
class_dim (int): output dim of last fc layer. Default: 1000.
"""
def __init__(self, scale=1.0, class_dim=1000):
super(MobileNetV2, self).__init__()
self.scale = scale
self.class_dim = class_dim
bottleneck_params_list = [
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 3, 2),
(6, 64, 4, 2),
(6, 96, 3, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
]
#1. conv1
self._conv1 = ConvBNLayer(
num_channels=3,
num_filters=int(32 * scale),
filter_size=3,
stride=2,
padding=1)
#2. bottleneck sequences
self._invl = []
i = 1
in_c = int(32 * scale)
for layer_setting in bottleneck_params_list:
t, c, n, s = layer_setting
i += 1
tmp = self.add_sublayer(
sublayer=InvresiBlocks(
in_c=in_c, t=t, c=int(c * scale), n=n, s=s),
name='conv' + str(i))
self._invl.append(tmp)
in_c = int(c * scale)
#3. last_conv
self._out_c = int(1280 * scale) if scale > 1.0 else 1280
self._conv9 = ConvBNLayer(
num_channels=in_c,
num_filters=self._out_c,
filter_size=1,
stride=1,
padding=0)
#4. pool
self._pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
#5. fc
tmp_param = ParamAttr(name=self.full_name() + "fc10_weights")
self._fc = Linear(
self._out_c,
class_dim,
act='softmax',
param_attr=tmp_param,
bias_attr=ParamAttr(name="fc10_offset"))
def forward(self, inputs):
y = self._conv1(inputs, if_act=True)
for inv in self._invl:
y = inv(y)
y = self._conv9(y, if_act=True)
y = self._pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self._out_c])
y = self._fc(y)
return y
def _mobilenet(arch, pretrained=False, **kwargs):
model = MobileNetV2(**kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def mobilenet_v2(pretrained=False, scale=1.0):
"""MobileNetV2
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = _mobilenet('mobilenetv2_' + str(scale), pretrained, scale=scale)
return model
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
...@@ -11,7 +25,9 @@ from paddle.fluid.dygraph.container import Sequential ...@@ -11,7 +25,9 @@ from paddle.fluid.dygraph.container import Sequential
from model import Model from model import Model
from .download import get_weights_path from .download import get_weights_path
__all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152'] __all__ = [
'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
]
model_urls = { model_urls = {
'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams', 'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
...@@ -48,7 +64,52 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -48,7 +64,52 @@ class ConvBNLayer(fluid.dygraph.Layer):
return x return x
class BasicBlock(fluid.dygraph.Layer):
expansion = 1
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BasicBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=stride)
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = short + conv1
return fluid.layers.relu(y)
class BottleneckBlock(fluid.dygraph.Layer): class BottleneckBlock(fluid.dygraph.Layer):
expansion = 4
def __init__(self, num_channels, num_filters, stride, shortcut=True): def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BottleneckBlock, self).__init__() super(BottleneckBlock, self).__init__()
...@@ -65,20 +126,20 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -65,20 +126,20 @@ class BottleneckBlock(fluid.dygraph.Layer):
act='relu') act='relu')
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
num_channels=num_filters, num_channels=num_filters,
num_filters=num_filters * 4, num_filters=num_filters * self.expansion,
filter_size=1, filter_size=1,
act=None) act=None)
if not shortcut: if not shortcut:
self.short = ConvBNLayer( self.short = ConvBNLayer(
num_channels=num_channels, num_channels=num_channels,
num_filters=num_filters * 4, num_filters=num_filters * self.expansion,
filter_size=1, filter_size=1,
stride=stride) stride=stride)
self.shortcut = shortcut self.shortcut = shortcut
self._num_channels_out = num_filters * 4 self._num_channels_out = num_filters * self.expansion
def forward(self, inputs): def forward(self, inputs):
x = self.conv0(inputs) x = self.conv0(inputs)
...@@ -92,16 +153,25 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -92,16 +153,25 @@ class BottleneckBlock(fluid.dygraph.Layer):
x = fluid.layers.elementwise_add(x=short, y=conv2) x = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu') return fluid.layers.relu(x)
return layer_helper.append_activation(x)
# return fluid.layers.relu(x)
class ResNet(Model): class ResNet(Model):
"""ResNet model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Block (BasicBlock|BottleneckBlock): block module of model.
depth (int): layers of resnet, default: 50.
num_classes (int): output dim of last fc layer, default: 1000.
"""
def __init__(self, Block, depth=50, num_classes=1000): def __init__(self, Block, depth=50, num_classes=1000):
super(ResNet, self).__init__() super(ResNet, self).__init__()
layer_config = { layer_config = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3], 50: [3, 4, 6, 3],
101: [3, 4, 23, 3], 101: [3, 4, 23, 3],
152: [3, 8, 36, 3], 152: [3, 8, 36, 3],
...@@ -111,8 +181,9 @@ class ResNet(Model): ...@@ -111,8 +181,9 @@ class ResNet(Model):
layer_config.keys(), depth) layer_config.keys(), depth)
layers = layer_config[depth] layers = layer_config[depth]
num_in = [64, 256, 512, 1024]
num_out = [64, 128, 256, 512] in_channels = 64
out_channels = [64, 128, 256, 512]
self.conv = ConvBNLayer( self.conv = ConvBNLayer(
num_channels=3, num_channels=3,
...@@ -128,9 +199,11 @@ class ResNet(Model): ...@@ -128,9 +199,11 @@ class ResNet(Model):
blocks = [] blocks = []
shortcut = False shortcut = False
for b in range(num_blocks): for b in range(num_blocks):
if b == 1:
in_channels = out_channels[idx] * Block.expansion
block = Block( block = Block(
num_channels=num_in[idx] if b == 0 else num_out[idx] * 4, num_channels=in_channels,
num_filters=num_out[idx], num_filters=out_channels[idx],
stride=2 if b == 0 and idx != 0 else 1, stride=2 if b == 0 and idx != 0 else 1,
shortcut=shortcut) shortcut=shortcut)
blocks.append(block) blocks.append(block)
...@@ -142,8 +215,8 @@ class ResNet(Model): ...@@ -142,8 +215,8 @@ class ResNet(Model):
self.global_pool = Pool2D( self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True) pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0) stdv = 1.0 / math.sqrt(out_channels[-1] * Block.expansion * 1.0)
self.fc_input_dim = num_out[-1] * 4 * 1 * 1 self.fc_input_dim = out_channels[-1] * Block.expansion * 1 * 1
self.fc = Linear( self.fc = Linear(
self.fc_input_dim, self.fc_input_dim,
num_classes, num_classes,
...@@ -175,13 +248,46 @@ def _resnet(arch, Block, depth, pretrained): ...@@ -175,13 +248,46 @@ def _resnet(arch, Block, depth, pretrained):
return model return model
def resnet18(pretrained=False):
"""ResNet 18-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _resnet('resnet18', BasicBlock, 18, pretrained)
def resnet34(pretrained=False):
"""ResNet 34-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _resnet('resnet34', BasicBlock, 34, pretrained)
def resnet50(pretrained=False): def resnet50(pretrained=False):
"""ResNet 50-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _resnet('resnet50', BottleneckBlock, 50, pretrained) return _resnet('resnet50', BottleneckBlock, 50, pretrained)
def resnet101(pretrained=False): def resnet101(pretrained=False):
"""ResNet 101-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _resnet('resnet101', BottleneckBlock, 101, pretrained) return _resnet('resnet101', BottleneckBlock, 101, pretrained)
def resnet152(pretrained=False): def resnet152(pretrained=False):
"""ResNet 152-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _resnet('resnet152', BottleneckBlock, 152, pretrained) return _resnet('resnet152', BottleneckBlock, 152, pretrained)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from model import Model
from .download import get_weights_path
__all__ = [
'VGG',
'vgg11',
'vgg11_bn',
'vgg13',
'vgg13_bn',
'vgg16',
'vgg16_bn',
'vgg19_bn',
'vgg19',
]
model_urls = {
'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams',
'c788f453a3b999063e8da043456281ee')
}
class Classifier(fluid.dygraph.Layer):
def __init__(self, num_classes):
super(Classifier, self).__init__()
self.linear1 = Linear(512 * 7 * 7, 4096)
self.linear2 = Linear(4096, 4096)
self.linear3 = Linear(4096, num_classes, act='softmax')
def forward(self, x):
x = self.linear1(x)
x = fluid.layers.relu(x)
x = fluid.layers.dropout(x, 0.5)
x = self.linear2(x)
x = fluid.layers.relu(x)
x = fluid.layers.dropout(x, 0.5)
out = self.linear3(x)
return out
class VGG(Model):
"""VGG model from
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
Args:
features (fluid.dygraph.Layer): vgg features create by function make_layers.
num_classes (int): output dim of last fc layer. Default: 1000.
"""
def __init__(self, features, num_classes=1000):
super(VGG, self).__init__()
self.features = features
classifier = Classifier(num_classes)
self.classifier = self.add_sublayer("classifier",
Sequential(classifier))
def forward(self, x):
x = self.features(x)
x = fluid.layers.flatten(x, 1)
x = self.classifier(x)
return x
def make_layers(cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [Pool2D(pool_size=2, pool_stride=2)]
else:
if batch_norm:
conv2d = Conv2D(in_channels, v, filter_size=3, padding=1)
layers += [conv2d, BatchNorm(v, act='relu')]
else:
conv2d = Conv2D(
in_channels, v, filter_size=3, padding=1, act='relu')
layers += [conv2d]
in_channels = v
return Sequential(*layers)
cfgs = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B':
[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [
64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
512, 512, 512, 'M'
],
'E': [
64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512,
512, 'M', 512, 512, 512, 512, 'M'
],
}
def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def vgg11(pretrained=False, **kwargs):
"""VGG 11-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg11', 'A', False, pretrained, **kwargs)
def vgg11_bn(pretrained=False, **kwargs):
"""VGG 11-layer model with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg11_bn', 'A', True, pretrained, **kwargs)
def vgg13(pretrained=False, **kwargs):
"""VGG 13-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg13', 'B', False, pretrained, **kwargs)
def vgg13_bn(pretrained=False, **kwargs):
"""VGG 13-layer model with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg13_bn', 'B', True, pretrained, **kwargs)
def vgg16(pretrained=False, **kwargs):
"""VGG 16-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg16', 'D', False, pretrained, **kwargs)
def vgg16_bn(pretrained=False, **kwargs):
"""VGG 16-layer with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg16_bn', 'D', True, pretrained, **kwargs)
def vgg19(pretrained=False, **kwargs):
"""VGG 19-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg19', 'E', False, pretrained, **kwargs)
def vgg19_bn(pretrained=False, **kwargs):
"""VGG 19-layer model with batch normalization
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg('vgg19_bn', 'E', True, pretrained, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册