提交 98632f46 编写于 作者: G guosheng

Add the Inception-ResNet-v2 model

上级 bcc36ce6
图像分类
=======================
这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet和ResNet模型进行图像分类。图像分类问题的描述和这四种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)
这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet、ResNet和Inception-ResNet-v2模型进行图像分类。图像分类问题的描述和这五种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)
## 训练模型
......@@ -11,6 +11,8 @@
```python
import gzip
import argparse
import paddle.v2.dataset.flowers as flowers
import paddle.v2 as paddle
import reader
......@@ -18,6 +20,7 @@ import vgg
import resnet
import alexnet
import googlenet
import inception_resnet_v2
# PaddlePaddle init
......@@ -29,7 +32,7 @@ paddle.init(use_gpu=False, trainer_count=1)
设置算法参数(如数据维度、类别数目和batch size等参数),定义数据输入层`image`和类别标签`lbl`
```python
DATA_DIM = 3 * 224 * 224
DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2.
CLASS_DIM = 102
BATCH_SIZE = 128
......@@ -41,7 +44,7 @@ lbl = paddle.layer.data(
### 获得所用模型
这里可以选择使用AlexNet、VGG、GoogLeNet和ResNet模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。
这里可以选择使用AlexNet、VGG、GoogLeNet、ResNet和Inception-ResNet-v2模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。
1. 使用AlexNet模型
......@@ -86,6 +89,16 @@ ResNet模型可以通过下面的代码获取:
out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM)
```
5. 使用Inception-ResNet-v2模型
提供的Inception-ResNet-v2模型支持`3 * 331 * 331``3 * 299 * 299`两种大小的输入,同时可以自行设置dropout概率,可以通过如下的代码使用:
```python
out = inception_resnet_v2.inception_resnet_v2(image, class_dim=CLASS_DIM, dropout_rate=0.5, size=DATA_DIM)
```
注意,由于和其他几种模型输入大小不同,若配合提供的`reader.py`使用Inception-ResNet-v2时请先将`reader.py``paddle.image.simple_transform`中的参数为修改为相应大小。
### 定义损失函数
```python
......@@ -173,7 +186,7 @@ def event_handler(event):
### 定义训练方法
对于AlexNet、VGG和ResNet,可以按下面的代码定义训练方法:
对于AlexNet、VGG、ResNet和Inception-ResNet-v2,可以按下面的代码定义训练方法:
```python
# Create trainer
......
import paddle.v2 as paddle
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding=0,
active_type=paddle.activation.Relu(),
ch_in=None):
tmp = paddle.layer.img_conv(
input=input,
filter_size=filter_size,
num_channels=ch_in,
num_filters=ch_out,
stride=stride,
padding=padding,
act=paddle.activation.Linear(),
bias_attr=False)
return paddle.layer.batch_norm(input=tmp, epsilon=0.001, act=active_type)
def sequential_block(input, *layers):
for layer in layers:
layer_func, layer_conf = layer
input = layer_func(input, **layer_conf)
return input
def mixed_5b_block(input):
branch0 = conv_bn_layer(
input, ch_in=192, ch_out=96, filter_size=1, stride=1)
branch1 = sequential_block(input, (conv_bn_layer, {
"ch_in": 192,
"ch_out": 48,
"filter_size": 1,
"stride": 1
}), (conv_bn_layer, {
"ch_in": 48,
"ch_out": 64,
"filter_size": 5,
"stride": 1,
"padding": 2
}))
branch2 = sequential_block(input, (conv_bn_layer, {
"ch_in": 192,
"ch_out": 64,
"filter_size": 1,
"stride": 1
}), (conv_bn_layer, {
"ch_in": 64,
"ch_out": 96,
"filter_size": 3,
"stride": 1,
"padding": 1
}), (conv_bn_layer, {
"ch_in": 96,
"ch_out": 96,
"filter_size": 3,
"stride": 1,
"padding": 1
}))
branch3 = sequential_block(
input,
(paddle.layer.img_pool, {
"pool_size": 3,
"stride": 1,
"padding": 1,
"pool_type": paddle.pooling.Avg(),
"exclude_mode": False
}),
(conv_bn_layer, {
"ch_in": 192,
"ch_out": 64,
"filter_size": 1,
"stride": 1
}), )
out = paddle.layer.concat(input=[branch0, branch1, branch2, branch3])
return out
def block35(input, scale=1.0):
branch0 = conv_bn_layer(
input, ch_in=320, ch_out=32, filter_size=1, stride=1)
branch1 = sequential_block(input, (conv_bn_layer, {
"ch_in": 320,
"ch_out": 32,
"filter_size": 1,
"stride": 1
}), (conv_bn_layer, {
"ch_in": 32,
"ch_out": 32,
"filter_size": 3,
"stride": 1,
"padding": 1
}))
branch2 = sequential_block(input, (conv_bn_layer, {
"ch_in": 320,
"ch_out": 32,
"filter_size": 1,
"stride": 1
}), (conv_bn_layer, {
"ch_in": 32,
"ch_out": 48,
"filter_size": 3,
"stride": 1,
"padding": 1
}), (conv_bn_layer, {
"ch_in": 48,
"ch_out": 64,
"filter_size": 3,
"stride": 1,
"padding": 1
}))
out = paddle.layer.concat(input=[branch0, branch1, branch2])
out = paddle.layer.img_conv(
input=out,
filter_size=1,
num_channels=128,
num_filters=320,
stride=1,
padding=0,
act=paddle.activation.Linear(),
bias_attr=None)
out = paddle.layer.slope_intercept(out, slope=scale, intercept=0.0)
out = paddle.layer.addto(input=[input, out], act=paddle.activation.Relu())
return out
def mixed_6a_block(input):
branch0 = conv_bn_layer(
input, ch_in=320, ch_out=384, filter_size=3, stride=2)
branch1 = sequential_block(input, (conv_bn_layer, {
"ch_in": 320,
"ch_out": 256,
"filter_size": 1,
"stride": 1
}), (conv_bn_layer, {
"ch_in": 256,
"ch_out": 256,
"filter_size": 3,
"stride": 1,
"padding": 1
}), (conv_bn_layer, {
"ch_in": 256,
"ch_out": 384,
"filter_size": 3,
"stride": 2
}))
branch2 = paddle.layer.img_pool(
input,
num_channels=320,
pool_size=3,
stride=2,
pool_type=paddle.pooling.Max())
out = paddle.layer.concat(input=[branch0, branch1, branch2])
return out
def block17(input, scale=1.0):
branch0 = conv_bn_layer(
input, ch_in=1088, ch_out=192, filter_size=1, stride=1)
branch1 = sequential_block(input, (conv_bn_layer, {
"ch_in": 1088,
"ch_out": 128,
"filter_size": 1,
"stride": 1
}), (conv_bn_layer, {
"ch_in": 128,
"ch_out": 160,
"filter_size": [7, 1],
"stride": 1,
"padding": [3, 0]
}), (conv_bn_layer, {
"ch_in": 160,
"ch_out": 192,
"filter_size": [1, 7],
"stride": 1,
"padding": [0, 3]
}))
out = paddle.layer.concat(input=[branch0, branch1])
out = paddle.layer.img_conv(
input=out,
filter_size=1,
num_channels=384,
num_filters=1088,
stride=1,
padding=0,
act=paddle.activation.Linear(),
bias_attr=None)
out = paddle.layer.slope_intercept(out, slope=scale, intercept=0.0)
out = paddle.layer.addto(input=[input, out], act=paddle.activation.Relu())
return out
def mixed_7a_block(input):
branch0 = sequential_block(
input,
(conv_bn_layer, {
"ch_in": 1088,
"ch_out": 256,
"filter_size": 1,
"stride": 1
}),
(conv_bn_layer, {
"ch_in": 256,
"ch_out": 384,
"filter_size": 3,
"stride": 2
}), )
branch1 = sequential_block(
input,
(conv_bn_layer, {
"ch_in": 1088,
"ch_out": 256,
"filter_size": 1,
"stride": 1
}),
(conv_bn_layer, {
"ch_in": 256,
"ch_out": 288,
"filter_size": 3,
"stride": 2
}), )
branch2 = sequential_block(input, (conv_bn_layer, {
"ch_in": 1088,
"ch_out": 256,
"filter_size": 1,
"stride": 1
}), (conv_bn_layer, {
"ch_in": 256,
"ch_out": 288,
"filter_size": 3,
"stride": 1,
"padding": 1
}), (conv_bn_layer, {
"ch_in": 288,
"ch_out": 320,
"filter_size": 3,
"stride": 2
}))
branch3 = paddle.layer.img_pool(
input,
num_channels=1088,
pool_size=3,
stride=2,
pool_type=paddle.pooling.Max())
out = paddle.layer.concat(input=[branch0, branch1, branch2, branch3])
return out
def block8(input, scale=1.0, no_relu=False):
branch0 = conv_bn_layer(
input, ch_in=2080, ch_out=192, filter_size=1, stride=1)
branch1 = sequential_block(input, (conv_bn_layer, {
"ch_in": 2080,
"ch_out": 192,
"filter_size": 1,
"stride": 1
}), (conv_bn_layer, {
"ch_in": 192,
"ch_out": 224,
"filter_size": [3, 1],
"stride": 1,
"padding": [1, 0]
}), (conv_bn_layer, {
"ch_in": 224,
"ch_out": 256,
"filter_size": [1, 3],
"stride": 1,
"padding": [0, 1]
}))
out = paddle.layer.concat(input=[branch0, branch1])
out = paddle.layer.img_conv(
input=out,
filter_size=1,
num_channels=448,
num_filters=2080,
stride=1,
padding=0,
act=paddle.activation.Linear(),
bias_attr=None)
out = paddle.layer.slope_intercept(out, slope=scale, intercept=0.0)
out = paddle.layer.addto(
input=[input, out],
act=paddle.activation.Linear() if no_relu else paddle.activation.Relu())
return out
def inception_resnet_v2(input,
class_dim,
dropout_rate=0.5,
data_dim=3 * 331 * 331):
conv2d_1a = conv_bn_layer(
input, ch_in=3, ch_out=32, filter_size=3, stride=2)
conv2d_2a = conv_bn_layer(
conv2d_1a, ch_in=32, ch_out=32, filter_size=3, stride=1)
conv2d_2b = conv_bn_layer(
conv2d_2a, ch_in=32, ch_out=64, filter_size=3, stride=1, padding=1)
maxpool_3a = paddle.layer.img_pool(
input=conv2d_2b, pool_size=3, stride=2, pool_type=paddle.pooling.Max())
conv2d_3b = conv_bn_layer(
maxpool_3a, ch_in=64, ch_out=80, filter_size=1, stride=1)
conv2d_4a = conv_bn_layer(
conv2d_3b, ch_in=80, ch_out=192, filter_size=3, stride=1)
maxpool_5a = paddle.layer.img_pool(
input=conv2d_4a, pool_size=3, stride=2, pool_type=paddle.pooling.Max())
mixed_5b = mixed_5b_block(maxpool_5a)
repeat = sequential_block(mixed_5b, *([(block35, {"scale": 0.17})] * 10))
mixed_6a = mixed_6a_block(repeat)
repeat1 = sequential_block(mixed_6a, *([(block17, {"scale": 0.10})] * 20))
mixed_7a = mixed_7a_block(repeat1)
repeat2 = sequential_block(mixed_7a, *([(block8, {"scale": 0.20})] * 9))
block_8 = block8(repeat2, no_relu=True)
conv2d_7b = conv_bn_layer(
block_8, ch_in=2080, ch_out=1536, filter_size=1, stride=1)
avgpool_1a = paddle.layer.img_pool(
input=conv2d_7b,
pool_size=8 if data_dim == 3 * 299 * 299 else 9,
stride=1,
pool_type=paddle.pooling.Avg(),
exclude_mode=False)
drop_out = paddle.layer.dropout(input=avgpool_1a, dropout_rate=dropout_rate)
out = paddle.layer.fc(
input=drop_out, size=class_dim, act=paddle.activation.Softmax())
return out
import os
import gzip
import argparse
import numpy as np
from PIL import Image
import paddle.v2 as paddle
import reader
import vgg
import resnet
import alexnet
import googlenet
import argparse
import os
from PIL import Image
import numpy as np
import inception_resnet_v2
WIDTH = 224
HEIGHT = 224
DATA_DIM = 3 * WIDTH * HEIGHT
DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2.
CLASS_DIM = 102
......@@ -26,7 +26,10 @@ def main():
parser.add_argument(
'model',
help='The model for image classification',
choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet'])
choices=[
'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet',
'inception-resnet-v2'
])
parser.add_argument(
'params_path', help='The file which stores the parameters')
args = parser.parse_args()
......@@ -49,6 +52,10 @@ def main():
out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM)
elif args.model == 'googlenet':
out, _, _ = googlenet.googlenet(image, class_dim=CLASS_DIM)
elif args.model == 'inception-resnet-v2':
assert DATA_DIM == 3 * 331 * 331 or DATA_DIM == 3 * 299 * 299
out = inception_resnet_v2.inception_resnet_v2(
image, class_dim=CLASS_DIM, dropout_rate=0.5, data_dim=DATA_DIM)
# load parameters
with gzip.open(args.params_path, 'r') as f:
......
import gzip
import argparse
import paddle.v2.dataset.flowers as flowers
import paddle.v2 as paddle
import reader
......@@ -6,9 +8,9 @@ import vgg
import resnet
import alexnet
import googlenet
import argparse
import inception_resnet_v2
DATA_DIM = 3 * 224 * 224
DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2.
CLASS_DIM = 102
BATCH_SIZE = 128
......@@ -19,7 +21,10 @@ def main():
parser.add_argument(
'model',
help='The model for image classification',
choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet'])
choices=[
'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet',
'inception-resnet-v2'
])
args = parser.parse_args()
# PaddlePaddle init
......@@ -52,6 +57,10 @@ def main():
input=out2, label=lbl, coeff=0.3)
paddle.evaluator.classification_error(input=out2, label=lbl)
extra_layers = [loss1, loss2]
elif args.model == 'inception-resnet-v2':
assert DATA_DIM == 3 * 331 * 331 or DATA_DIM == 3 * 299 * 299
out = inception_resnet_v2.inception_resnet_v2(
image, class_dim=CLASS_DIM, dropout_rate=0.5, data_dim=DATA_DIM)
cost = paddle.layer.classification_cost(input=out, label=lbl)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册