From d264cefcbe736b09063aaaf34b2cb666500e74ce Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Tue, 12 Dec 2017 15:32:50 +0800 Subject: [PATCH] Implement Inception-v4 for image classification --- README.cn.md | 3 +- README.md | 3 +- image_classification/README.md | 13 +- image_classification/inception_v4.py | 526 +++++++++++++++++++++++++++ image_classification/infer.py | 8 +- image_classification/train.py | 8 +- 6 files changed, 555 insertions(+), 6 deletions(-) create mode 100644 image_classification/inception_v4.py diff --git a/README.cn.md b/README.cn.md index 9491690e..3b2403bb 100644 --- a/README.cn.md +++ b/README.cn.md @@ -98,12 +98,13 @@ PaddlePaddle提供了丰富的运算单元,帮助大家以模块化的方式 图像相比文字能够提供更加生动、容易理解及更具艺术感的信息,是人们转递与交换信息的重要来源。图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉中重要的基础问题,也是图像检测、图像分割、物体跟踪、行为分析等其他高层视觉任务的基础,在许多领域都有着广泛的应用。如:安防领域的人脸识别和智能视频分析等,交通领域的交通场景识别,互联网领域基于内容的图像检索和相册自动归类,医学领域的图像识别等。 -在图像分类任务中,我们向大家介绍如何训练AlexNet、VGG、GoogLeNet和ResNet模型。同时提供了一个够将Caffe训练好的模型文件转换为PaddlePaddle模型文件的模型转换工具。 +在图像分类任务中,我们向大家介绍如何训练AlexNet、VGG、GoogLeNet、ResNet和Inception-v4模型。同时提供了一个够将Caffe训练好的模型文件转换为PaddlePaddle模型文件的模型转换工具。 - 11.1 [将Caffe模型文件转换为PaddlePaddle模型文件](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle) - 11.2 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 11.3 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 11.4 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 11.5 [Inception-v4](https://github.com/PaddlePaddle/models/tree/develop/image_classification) ## 12. 目标检测 diff --git a/README.md b/README.md index 8b938a30..876c5621 100644 --- a/README.md +++ b/README.md @@ -72,11 +72,12 @@ As an example for sequence-to-sequence learning, we take the machine translation ## 9. Image classification -For the example of image classification, we show you how to train AlexNet, VGG, GoogLeNet and ResNet models in PaddlePaddle. It also provides a model conversion tool that converts Caffe trained model files into PaddlePaddle model files. +For the example of image classification, we show you how to train AlexNet, VGG, GoogLeNet, ResNet and Inception-v4 models in PaddlePaddle. It also provides a model conversion tool that converts Caffe trained model files into PaddlePaddle model files. - 9.1 [convert Caffe model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle) - 9.2 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 9.3 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 9.4 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 9.5 [Inception-v4](https://github.com/PaddlePaddle/models/tree/develop/image_classification) This tutorial is contributed by [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) and licensed under the [Apache-2.0 license](LICENSE). diff --git a/image_classification/README.md b/image_classification/README.md index 843d683c..b160d3ef 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -1,7 +1,7 @@ 图像分类 ======================= -这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet和ResNet模型进行图像分类。图像分类问题的描述和这四种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 +这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet、ResNet和Inception-v4模型进行图像分类。图像分类问题的描述和这些模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 ## 训练模型 @@ -18,6 +18,7 @@ import vgg import resnet import alexnet import googlenet +import inception_v4 # PaddlePaddle init @@ -41,7 +42,7 @@ lbl = paddle.layer.data( ### 获得所用模型 -这里可以选择使用AlexNet、VGG、GoogLeNet和ResNet模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。 +这里可以选择使用AlexNet、VGG、GoogLeNet、ResNet和Inception-v4模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。 1. 使用AlexNet模型 @@ -86,6 +87,14 @@ ResNet模型可以通过下面的代码获取: out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) ``` +5. 使用Inception-v4模型 + +Inception-v4模型可以通过下面的代码获取: + +```python +out = inception_v4.inception_v4(image, class_dim=CLASS_DIM) +``` + ### 定义损失函数 ```python diff --git a/image_classification/inception_v4.py b/image_classification/inception_v4.py new file mode 100644 index 00000000..9a8c5fa8 --- /dev/null +++ b/image_classification/inception_v4.py @@ -0,0 +1,526 @@ +import paddle.v2 as paddle + +__all__ = ['inception_v4'] + + +def img_conv(name, + input, + num_filters, + filter_size, + stride, + padding, + num_channels=None): + conv = paddle.layer.img_conv( + name=name, + input=input, + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + act=paddle.activation.Linear()) + norm = paddle.layer.batch_norm( + name=name + '_norm', input=conv, act=paddle.activation.Relu()) + return norm + + +def stem(input): + conv0 = img_conv( + name='stem_conv0', + input=input, + num_channels=3, + num_filters=32, + filter_size=3, + stride=2, + padding=1) + conv1 = img_conv( + name='stem_conv1', + input=conv0, + num_channels=32, + num_filters=32, + filter_size=3, + stride=1, + padding=1) + conv2 = img_conv( + name='stem_conv2', + input=conv1, + num_channels=32, + num_filters=64, + filter_size=3, + stride=1, + padding=1) + + def block0(input): + pool0 = paddle.layer.img_pool( + name='stem_branch0_pool0', + input=input, + num_channels=64, + pool_size=3, + stride=2, + pool_type=paddle.pooling.Max()) + conv0 = img_conv( + name='stem_branch0_conv0', + input=input, + num_channels=64, + num_filters=96, + filter_size=3, + stride=2, + padding=1) + return paddle.layer.concat(input=[pool0, conv0]) + + def block1(input): + l_conv0 = img_conv( + name='stem_branch1_l_conv0', + input=input, + num_channels=160, + num_filters=64, + filter_size=1, + stride=1, + padding=0) + l_conv1 = img_conv( + name='stem_branch1_l_conv1', + input=l_conv0, + num_channels=64, + num_filters=96, + filter_size=3, + stride=1, + padding=1) + r_conv0 = img_conv( + name='stem_branch1_r_conv0', + input=input, + num_channels=160, + num_filters=64, + filter_size=1, + stride=1, + padding=0) + r_conv1 = img_conv( + name='stem_branch1_r_conv1', + input=r_conv0, + num_channels=64, + num_filters=64, + filter_size=(7, 1), + stride=1, + padding=(3, 0)) + r_conv2 = img_conv( + name='stem_branch1_r_conv2', + input=r_conv1, + num_channels=64, + num_filters=64, + filter_size=(1, 7), + stride=1, + padding=(0, 3)) + r_conv3 = img_conv( + name='stem_branch1_r_conv3', + input=r_conv2, + num_channels=64, + num_filters=96, + filter_size=3, + stride=1, + padding=1) + return paddle.layer.concat(input=[l_conv1, r_conv3]) + + def block2(input): + conv0 = img_conv( + name='stem_branch2_conv0', + input=input, + num_channels=192, + num_filters=192, + filter_size=3, + stride=2, + padding=1) + pool0 = paddle.layer.img_pool( + name='stem_branch2_pool0', + input=input, + num_channels=192, + pool_size=3, + stride=2, + pool_type=paddle.pooling.Max()) + return paddle.layer.concat(input=[conv0, pool0]) + + conv3 = block0(conv2) + conv4 = block1(conv3) + conv5 = block2(conv4) + return conv5 + + +def Inception_A(input, depth): + b0_pool0 = paddle.layer.img_pool( + name='inceptA{0}_branch0_pool0'.format(depth), + input=input, + num_channels=384, + pool_size=3, + stride=1, + padding=1, + pool_type=paddle.pooling.Avg()) + b0_conv0 = img_conv( + name='inceptA{0}_branch0_conv0'.format(depth), + input=b0_pool0, + num_channels=384, + num_filters=96, + filter_size=1, + stride=1, + padding=0) + b1_conv0 = img_conv( + name='inceptA{0}_branch1_conv0'.format(depth), + input=input, + num_channels=384, + num_filters=96, + filter_size=1, + stride=1, + padding=0) + b2_conv0 = img_conv( + name='inceptA{0}_branch2_conv0'.format(depth), + input=input, + num_channels=384, + num_filters=64, + filter_size=1, + stride=1, + padding=0) + b2_conv1 = img_conv( + name='inceptA{0}_branch2_conv1'.format(depth), + input=b2_conv0, + num_channels=64, + num_filters=96, + filter_size=3, + stride=1, + padding=1) + b3_conv0 = img_conv( + name='inceptA{0}_branch3_conv0'.format(depth), + input=input, + num_channels=384, + num_filters=64, + filter_size=1, + stride=1, + padding=0) + b3_conv1 = img_conv( + name='inceptA{0}_branch3_conv1'.format(depth), + input=b3_conv0, + num_channels=64, + num_filters=96, + filter_size=3, + stride=1, + padding=1) + b3_conv2 = img_conv( + name='inceptA{0}_branch3_conv2'.format(depth), + input=b3_conv1, + num_channels=96, + num_filters=96, + filter_size=3, + stride=1, + padding=1) + return paddle.layer.concat(input=[b0_conv0, b1_conv0, b2_conv1, b3_conv2]) + + +def Inception_B(input, depth): + b0_pool0 = paddle.layer.img_pool( + name='inceptB{0}_branch0_pool0'.format(depth), + input=input, + num_channels=1024, + pool_size=3, + stride=1, + padding=1, + pool_type=paddle.pooling.Avg()) + b0_conv0 = img_conv( + name='inceptB{0}_branch0_conv0'.format(depth), + input=b0_pool0, + num_channels=1024, + num_filters=128, + filter_size=1, + stride=1, + padding=0) + b1_conv0 = img_conv( + name='inceptB{0}_branch1_conv0'.format(depth), + input=input, + num_channels=1024, + num_filters=384, + filter_size=1, + stride=1, + padding=0) + b2_conv0 = img_conv( + name='inceptB{0}_branch2_conv0'.format(depth), + input=input, + num_channels=1024, + num_filters=192, + filter_size=1, + stride=1, + padding=0) + b2_conv1 = img_conv( + name='inceptB{0}_branch2_conv1'.format(depth), + input=b2_conv0, + num_channels=192, + num_filters=224, + filter_size=(1, 7), + stride=1, + padding=(0, 3)) + b2_conv2 = img_conv( + name='inceptB{0}_branch2_conv2'.format(depth), + input=b2_conv1, + num_channels=224, + num_filters=256, + filter_size=(7, 1), + stride=1, + padding=(3, 0)) + b3_conv0 = img_conv( + name='inceptB{0}_branch3_conv0'.format(depth), + input=input, + num_channels=1024, + num_filters=192, + filter_size=1, + stride=1, + padding=0) + b3_conv1 = img_conv( + name='inceptB{0}_branch3_conv1'.format(depth), + input=b3_conv0, + num_channels=192, + num_filters=192, + filter_size=(1, 7), + stride=1, + padding=(0, 3)) + b3_conv2 = img_conv( + name='inceptB{0}_branch3_conv2'.format(depth), + input=b3_conv1, + num_channels=192, + num_filters=224, + filter_size=(7, 1), + stride=1, + padding=(3, 0)) + b3_conv3 = img_conv( + name='inceptB{0}_branch3_conv3'.format(depth), + input=b3_conv2, + num_channels=224, + num_filters=224, + filter_size=(1, 7), + stride=1, + padding=(0, 3)) + b3_conv4 = img_conv( + name='inceptB{0}_branch3_conv4'.format(depth), + input=b3_conv3, + num_channels=224, + num_filters=256, + filter_size=(7, 1), + stride=1, + padding=(3, 0)) + return paddle.layer.concat(input=[b0_conv0, b1_conv0, b2_conv2, b3_conv4]) + + +def Inception_C(input, depth): + b0_pool0 = paddle.layer.img_pool( + name='inceptC{0}_branch0_pool0'.format(depth), + input=input, + num_channels=1536, + pool_size=3, + stride=1, + padding=1, + pool_type=paddle.pooling.Avg()) + b0_conv0 = img_conv( + name='inceptC{0}_branch0_conv0'.format(depth), + input=b0_pool0, + num_channels=1536, + num_filters=256, + filter_size=1, + stride=1, + padding=0) + b1_conv0 = img_conv( + name='inceptC{0}_branch1_conv0'.format(depth), + input=input, + num_channels=1536, + num_filters=256, + filter_size=1, + stride=1, + padding=0) + b2_conv0 = img_conv( + name='inceptC{0}_branch2_conv0'.format(depth), + input=input, + num_channels=1536, + num_filters=384, + filter_size=1, + stride=1, + padding=0) + b2_conv1 = img_conv( + name='inceptC{0}_branch2_conv1'.format(depth), + input=b2_conv0, + num_channels=384, + num_filters=256, + filter_size=(1, 3), + stride=1, + padding=(0, 1)) + b2_conv2 = img_conv( + name='inceptC{0}_branch2_conv2'.format(depth), + input=b2_conv0, + num_channels=384, + num_filters=256, + filter_size=(3, 1), + stride=1, + padding=(1, 0)) + b3_conv0 = img_conv( + name='inceptC{0}_branch3_conv0'.format(depth), + input=input, + num_channels=1536, + num_filters=384, + filter_size=1, + stride=1, + padding=0) + b3_conv1 = img_conv( + name='inceptC{0}_branch3_conv1'.format(depth), + input=b3_conv0, + num_channels=384, + num_filters=448, + filter_size=(1, 3), + stride=1, + padding=(0, 1)) + b3_conv2 = img_conv( + name='inceptC{0}_branch3_conv2'.format(depth), + input=b3_conv1, + num_channels=448, + num_filters=512, + filter_size=(3, 1), + stride=1, + padding=(1, 0)) + b3_conv3 = img_conv( + name='inceptC{0}_branch3_conv3'.format(depth), + input=b3_conv2, + num_channels=512, + num_filters=256, + filter_size=(3, 1), + stride=1, + padding=(1, 0)) + b3_conv4 = img_conv( + name='inceptC{0}_branch3_conv4'.format(depth), + input=b3_conv2, + num_channels=512, + num_filters=256, + filter_size=(1, 3), + stride=1, + padding=(0, 1)) + return paddle.layer.concat( + input=[b0_conv0, b1_conv0, b2_conv1, b2_conv2, b3_conv3, b3_conv4]) + + +def Reduction_A(input): + b0_pool0 = paddle.layer.img_pool( + name='ReductA_branch0_pool0', + input=input, + num_channels=384, + pool_size=3, + stride=2, + pool_type=paddle.pooling.Max()) + b1_conv0 = img_conv( + name='ReductA_branch1_conv0', + input=input, + num_channels=384, + num_filters=384, + filter_size=3, + stride=2, + padding=1) + b2_conv0 = img_conv( + name='ReductA_branch2_conv0', + input=input, + num_channels=384, + num_filters=192, + filter_size=1, + stride=1, + padding=0) + b2_conv1 = img_conv( + name='ReductA_branch2_conv1', + input=b2_conv0, + num_channels=192, + num_filters=224, + filter_size=3, + stride=1, + padding=1) + b2_conv2 = img_conv( + name='ReductA_branch2_conv2', + input=b2_conv1, + num_channels=224, + num_filters=256, + filter_size=3, + stride=2, + padding=1) + return paddle.layer.concat(input=[b0_pool0, b1_conv0, b2_conv2]) + + +def Reduction_B(input): + b0_pool0 = paddle.layer.img_pool( + name='ReductB_branch0_pool0', + input=input, + num_channels=1024, + pool_size=3, + stride=2, + pool_type=paddle.pooling.Max()) + b1_conv0 = img_conv( + name='ReductB_branch1_conv0', + input=input, + num_channels=1024, + num_filters=192, + filter_size=1, + stride=1, + padding=0) + b1_conv1 = img_conv( + name='ReductB_branch1_conv1', + input=b1_conv0, + num_channels=192, + num_filters=192, + filter_size=3, + stride=2, + padding=1) + b2_conv0 = img_conv( + name='ReductB_branch2_conv0', + input=input, + num_channels=1024, + num_filters=256, + filter_size=1, + stride=1, + padding=0) + b2_conv1 = img_conv( + name='ReductB_branch2_conv1', + input=b2_conv0, + num_channels=256, + num_filters=256, + filter_size=(1, 7), + stride=1, + padding=(0, 3)) + b2_conv2 = img_conv( + name='ReductB_branch2_conv2', + input=b2_conv1, + num_channels=256, + num_filters=320, + filter_size=(7, 1), + stride=1, + padding=(3, 0)) + b2_conv3 = img_conv( + name='ReductB_branch2_conv3', + input=b2_conv2, + num_channels=320, + num_filters=320, + filter_size=3, + stride=2, + padding=1) + return paddle.layer.concat(input=[b0_pool0, b1_conv1, b2_conv3]) + + +def inception_v4(input, class_dim): + conv = stem(input) + + for i in range(4): + conv = Inception_A(conv, i) + conv = Reduction_A(conv) + for i in range(7): + conv = Inception_B(conv, i) + conv = Reduction_B(conv) + for i in range(3): + conv = Inception_C(conv, i) + + pool = paddle.layer.img_pool( + name='incept_avg_pool', + input=conv, + num_channels=1536, + pool_size=7, + stride=1, + pool_type=paddle.pooling.Avg()) + drop = paddle.layer.dropout(input=pool, dropout_rate=0.2) + out = paddle.layer.fc( + name='incept_fc', + input=drop, + size=class_dim, + act=paddle.activation.Softmax()) + return out diff --git a/image_classification/infer.py b/image_classification/infer.py index 659c4f2a..e7a0b98e 100644 --- a/image_classification/infer.py +++ b/image_classification/infer.py @@ -5,6 +5,7 @@ import vgg import resnet import alexnet import googlenet +import inception_v4 import argparse import os from PIL import Image @@ -26,7 +27,10 @@ def main(): parser.add_argument( 'model', help='The model for image classification', - choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + choices=[ + 'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet', + 'inception_v4' + ]) parser.add_argument( 'params_path', help='The file which stores the parameters') args = parser.parse_args() @@ -49,6 +53,8 @@ def main(): out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) elif args.model == 'googlenet': out, _, _ = googlenet.googlenet(image, class_dim=CLASS_DIM) + elif args.model == 'inception_v4': + out = inception_v4.inception_v4(image, class_dim=CLASS_DIM) # load parameters with gzip.open(args.params_path, 'r') as f: diff --git a/image_classification/train.py b/image_classification/train.py index 12a582db..8f343f97 100644 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -6,6 +6,7 @@ import vgg import resnet import alexnet import googlenet +import inception_v4 import argparse DATA_DIM = 3 * 224 * 224 @@ -19,7 +20,10 @@ def main(): parser.add_argument( 'model', help='The model for image classification', - choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + choices=[ + 'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet', + 'inception_v4' + ]) args = parser.parse_args() # PaddlePaddle init @@ -52,6 +56,8 @@ def main(): input=out2, label=lbl, coeff=0.3) paddle.evaluator.classification_error(input=out2, label=lbl) extra_layers = [loss1, loss2] + elif args.model == 'inception_v4': + out = inception_v4.inception_v4(image, class_dim=CLASS_DIM) cost = paddle.layer.classification_cost(input=out, label=lbl) -- GitLab