diff --git a/README.cn.md b/README.cn.md index d57811ead7b69166931b1a8ec873a8bdb80ad5ba..72fb35ff3b239d8fa5e226f84aa09f084f593697 100644 --- a/README.cn.md +++ b/README.cn.md @@ -98,7 +98,7 @@ PaddlePaddle提供了丰富的运算单元,帮助大家以模块化的方式 图像相比文字能够提供更加生动、容易理解及更具艺术感的信息,是人们转递与交换信息的重要来源。图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉中重要的基础问题,也是图像检测、图像分割、物体跟踪、行为分析等其他高层视觉任务的基础,在许多领域都有着广泛的应用。如:安防领域的人脸识别和智能视频分析等,交通领域的交通场景识别,互联网领域基于内容的图像检索和相册自动归类,医学领域的图像识别等。 -在图像分类任务中,我们向大家介绍如何训练AlexNet、VGG、GoogLeNet、ResNet、Inception-v4和Inception-Resnet-V2模型。同时提供了能够将Caffe或TensorFlow训练好的模型文件转换为PaddlePaddle模型文件的模型转换工具。 +在图像分类任务中,我们向大家介绍如何训练AlexNet、VGG、GoogLeNet、ResNet、Inception-v4、Inception-Resnet-V2和Xception模型。同时提供了能够将Caffe或TensorFlow训练好的模型文件转换为PaddlePaddle模型文件的模型转换工具。 - 11.1 [将Caffe模型文件转换为PaddlePaddle模型文件](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle) - 11.2 [将TensorFlow模型文件转换为PaddlePaddle模型文件](https://github.com/PaddlePaddle/models/tree/develop/image_classification/tf2paddle) @@ -107,6 +107,7 @@ PaddlePaddle提供了丰富的运算单元,帮助大家以模块化的方式 - 11.5 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 11.6 [Inception-v4](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 11.7 [Inception-Resnet-V2](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 11.8 [Xception](https://github.com/PaddlePaddle/models/tree/develop/image_classification) ## 12. 目标检测 diff --git a/README.md b/README.md index 51fbeb7efcb323fa2a4323275fe9ca88f1a0f3b6..1920b307253c9cc2f4547df1886124f85bfdabb2 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ As an example for sequence-to-sequence learning, we take the machine translation ## 9. Image classification -For the example of image classification, we show you how to train AlexNet, VGG, GoogLeNet, ResNet, Inception-v4 and Inception-Resnet-V2 models in PaddlePaddle. 
It also provides model conversion tools that convert Caffe or TensorFlow trained model files into PaddlePaddle model files. +For the example of image classification, we show you how to train AlexNet, VGG, GoogLeNet, ResNet, Inception-v4, Inception-Resnet-V2 and Xception models in PaddlePaddle. It also provides model conversion tools that convert Caffe or TensorFlow trained model files into PaddlePaddle model files. - 9.1 [convert Caffe model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle) - 9.2 [convert TensorFlow model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/image_classification/tf2paddle) @@ -81,5 +81,6 @@ For the example of image classification, we show you how to train AlexNet, VGG, - 9.5 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 9.6 [Inception-v4](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 9.7 [Inception-Resnet-V2](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 9.8 [Xception](https://github.com/PaddlePaddle/models/tree/develop/image_classification) This tutorial is contributed by [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) and licensed under the [Apache-2.0 license](LICENSE). 
diff --git a/image_classification/README.md b/image_classification/README.md index 49f461fb30937540909480a187060f5047e7ea39..45d8ce5742393ae705e8d16cbf6b0f4e33df5c6a 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -1,7 +1,7 @@ 图像分类 ======================= -这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet、ResNet、Inception-v4和Inception-ResNet-v2模型进行图像分类。图像分类问题的描述和这些模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 +这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet、ResNet、Inception-v4、Inception-ResNet-v2和Xception模型进行图像分类。图像分类问题的描述和这些模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 ## 训练模型 @@ -22,6 +22,7 @@ import alexnet import googlenet import inception_v4 import inception_resnet_v2 +import xception # PaddlePaddle init @@ -47,7 +48,7 @@ lbl = paddle.layer.data( ### 获得所用模型 -这里可以选择使用AlexNet、VGG、GoogLeNet、ResNet、Inception-v4和Inception-ResNet-v2模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。 +这里可以选择使用AlexNet、VGG、GoogLeNet、ResNet、Inception-v4、Inception-ResNet-v2和Xception模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。 1. 使用AlexNet模型 @@ -112,6 +113,14 @@ out = inception_resnet_v2.inception_resnet_v2( 注意,由于和其他几种模型输入大小不同,若配合提供的`reader.py`使用Inception-ResNet-v2时请先将`reader.py`中`paddle.image.simple_transform`中的参数为修改为相应大小。 +7. 
使用Xception模型 + +Xception模型可以通过下面的代码获取: + +```python +out = xception.xception(image, class_dim=CLASS_DIM) +``` + ### 定义损失函数 ```python @@ -199,7 +208,7 @@ def event_handler(event): ### 定义训练方法 -对于AlexNet、VGG、ResNet、Inception-v4和Inception-ResNet-v2,可以按下面的代码定义训练方法: +对于AlexNet、VGG、ResNet、Inception-v4、Inception-ResNet-v2和Xception,可以按下面的代码定义训练方法: ```python # Create trainer diff --git a/image_classification/infer.py b/image_classification/infer.py index d28824bced5d4c7a64ee4d0dc4469af6ce7ea65a..c73a0c811682209d8cea38089b9db0d315057c00 100644 --- a/image_classification/infer.py +++ b/image_classification/infer.py @@ -12,6 +12,7 @@ import alexnet import googlenet import inception_v4 import inception_resnet_v2 +import xception DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2. CLASS_DIM = 102 @@ -29,7 +30,7 @@ def main(): help='The model for image classification', choices=[ 'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet', - 'inception-resnet-v2', 'inception_v4' + 'inception-resnet-v2', 'inception_v4', 'xception' ]) parser.add_argument( 'params_path', help='The file which stores the parameters') @@ -59,6 +60,8 @@ def main(): image, class_dim=CLASS_DIM, dropout_rate=0.5, data_dim=DATA_DIM) elif args.model == 'inception_v4': out = inception_v4.inception_v4(image, class_dim=CLASS_DIM) + elif args.model == 'xception': + out = xception.xception(image, class_dim=CLASS_DIM) # load parameters with gzip.open(args.params_path, 'r') as f: diff --git a/image_classification/train.py b/image_classification/train.py index 5c7e70d9999d5b5014d796e615dd99729eccabc4..237204620169a4f6adf2c53eb4a430d685bddae4 100644 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -10,6 +10,7 @@ import alexnet import googlenet import inception_v4 import inception_resnet_v2 +import xception DATA_DIM = 3 * 224 * 224 # Use 3 * 331 * 331 or 3 * 299 * 299 for Inception-ResNet-v2. 
"""Xception image-classification network built with the PaddlePaddle v2 API."""

import paddle.v2 as paddle

__all__ = ['xception']


def _same_padding(filter_size):
    """Return ``(filter_size - 1) // 2`` as an integer padding.

    Floor division is used on purpose: it gives the same value as the
    classic ``/`` on Python 2 integers, but stays an ``int`` under
    Python 3 (or ``from __future__ import division``), where ``/`` would
    hand a float to the layer configuration.
    """
    return (filter_size - 1) // 2


def _relu(input):
    """Apply a stand-alone ReLU to ``input``.

    Implemented as a mixed layer wrapping an identity projection, so the
    activation is the only thing the layer does.
    """
    return paddle.layer.mixed(
        input=paddle.layer.identity_projection(input=input),
        act=paddle.activation.Relu())


def img_separable_conv_bn(name, input, num_channels, num_out_channels,
                          filter_size, stride, padding, act):
    """Separable convolution followed by batch normalization.

    The convolution itself is created with a linear activation; ``act`` is
    applied by the batch-norm layer, which is named ``name + '_norm'``.

    Args:
        name: Name of the separable convolution layer.
        input: Input layer.
        num_channels: Number of input channels.
        num_out_channels: Number of output channels.
        filter_size: Convolution filter size.
        stride: Convolution stride.
        padding: Convolution padding.
        act: Activation applied after batch normalization.

    Returns:
        The batch-norm layer.
    """
    conv = paddle.networks.img_separable_conv(
        name=name,
        input=input,
        num_channels=num_channels,
        num_out_channels=num_out_channels,
        filter_size=filter_size,
        stride=stride,
        padding=padding,
        act=paddle.activation.Linear())
    return paddle.layer.batch_norm(name=name + '_norm', input=conv, act=act)


def img_conv_bn(name, input, num_channels, num_filters, filter_size, stride,
                padding, act):
    """Regular convolution followed by batch normalization.

    The convolution itself is created with a linear activation; ``act`` is
    applied by the batch-norm layer, which is named ``name + '_norm'``.

    Args:
        name: Name of the convolution layer.
        input: Input layer.
        num_channels: Number of input channels.
        num_filters: Number of output channels.
        filter_size: Convolution filter size.
        stride: Convolution stride.
        padding: Convolution padding.
        act: Activation applied after batch normalization.

    Returns:
        The batch-norm layer.
    """
    conv = paddle.layer.img_conv(
        name=name,
        input=input,
        num_channels=num_channels,
        num_filters=num_filters,
        filter_size=filter_size,
        stride=stride,
        padding=padding,
        act=paddle.activation.Linear())
    return paddle.layer.batch_norm(name=name + '_norm', input=conv, act=act)


def conv_block0(input,
                group,
                num_channels,
                num_filters,
                num_filters2=None,
                filter_size=3,
                pool_padding=0,
                entry_relu=True):
    """Strided residual block with a convolutional shortcut.

    Main branch: optional ReLU, two separable conv + batch-norm layers
    (ReLU after the first, linear after the second), then a 3x3 max pool
    with stride 2. Shortcut branch: a 1x1 stride-2 conv + batch-norm fed by
    the block's original (un-activated) ``input``. The two branches are
    summed with a linear activation.

    Args:
        input: Input layer.
        group: Block index, used to build the layer names
            ``xception_block{group}_*``.
        num_channels: Number of input channels.
        num_filters: Output channels of the first separable convolution.
        num_filters2: Output channels of the second separable convolution
            and of the shortcut; defaults to ``num_filters``.
        filter_size: Filter size of both separable convolutions.
        pool_padding: Padding of the max-pool layer.
        entry_relu: Whether to apply a ReLU to ``input`` before the first
            separable convolution.

    Returns:
        The addto layer combining the pooled branch and the shortcut.
    """
    if num_filters2 is None:
        num_filters2 = num_filters

    prefix = 'xception_block{0}'.format(group)
    padding = _same_padding(filter_size)

    # Layers are created in the same order as before so that any
    # auto-generated layer names stay stable.
    act_input = _relu(input) if entry_relu else input
    conv0 = img_separable_conv_bn(
        name=prefix + '_conv0',
        input=act_input,
        num_channels=num_channels,
        num_out_channels=num_filters,
        filter_size=filter_size,
        stride=1,
        padding=padding,
        act=paddle.activation.Relu())
    conv1 = img_separable_conv_bn(
        name=prefix + '_conv1',
        input=conv0,
        num_channels=num_filters,
        num_out_channels=num_filters2,
        filter_size=filter_size,
        stride=1,
        padding=padding,
        act=paddle.activation.Linear())
    pool0 = paddle.layer.img_pool(
        name=prefix + '_pool',
        input=conv1,
        pool_size=3,
        stride=2,
        padding=pool_padding,
        num_channels=num_filters2,
        pool_type=paddle.pooling.CudnnMax())

    # The shortcut reads the raw block input, not the ReLU-ed copy.
    shortcut = img_conv_bn(
        name=prefix + '_shortcut',
        input=input,
        num_channels=num_channels,
        num_filters=num_filters2,
        filter_size=1,
        stride=2,
        padding=0,
        act=paddle.activation.Linear())

    return paddle.layer.addto(
        input=[pool0, shortcut], act=paddle.activation.Linear())


def conv_block1(input, group, num_channels, num_filters, filter_size=3):
    """Residual block with an identity shortcut.

    Main branch: ReLU, then three stride-1 separable conv + batch-norm
    layers (ReLU after the first two, linear after the third). The result
    is summed with the unmodified ``input``, so ``num_filters`` has to
    match the channel count of ``input`` for the addition to be valid.

    Args:
        input: Input layer.
        group: Block index, used to build the layer names
            ``xception_block{group}_conv{0,1,2}``.
        num_channels: Number of input channels.
        num_filters: Output channels of every separable convolution.
        filter_size: Filter size of the separable convolutions.

    Returns:
        The addto layer combining the convolution branch and ``input``.
    """
    prefix = 'xception_block{0}'.format(group)
    padding = _same_padding(filter_size)

    branch = _relu(input)
    in_channels = num_channels
    # Only the last convolution of the branch is left without a ReLU.
    activations = (paddle.activation.Relu, paddle.activation.Relu,
                   paddle.activation.Linear)
    for index, make_act in enumerate(activations):
        branch = img_separable_conv_bn(
            name='{0}_conv{1}'.format(prefix, index),
            input=branch,
            num_channels=in_channels,
            num_out_channels=num_filters,
            filter_size=filter_size,
            stride=1,
            padding=padding,
            act=make_act())
        in_channels = num_filters

    return paddle.layer.addto(
        input=[branch, input], act=paddle.activation.Linear())


def xception(input, class_dim):
    """Build the Xception classifier and return its softmax output layer.

    Structure, in order: two conv + batch-norm stem layers, three strided
    residual blocks (groups 2-4), eight identity residual blocks (groups
    5-12), one more strided residual block (group 13), two separable
    conv + batch-norm layers, a 7x7 average pool and a softmax
    fully-connected layer.

    NOTE(review): the fixed ``pool_size=7`` presumably assumes a 7x7 final
    feature map, i.e. a 224x224 input reduced by the five stride-2 stages
    -- confirm before feeding other input sizes.

    Args:
        input: Image data layer with 3 channels.
        class_dim: Number of output classes.

    Returns:
        The softmax fully-connected output layer.
    """
    conv = img_conv_bn(
        name='xception_conv0',
        input=input,
        num_channels=3,
        num_filters=32,
        filter_size=3,
        stride=2,
        padding=1,
        act=paddle.activation.Relu())
    conv = img_conv_bn(
        name='xception_conv1',
        input=conv,
        num_channels=32,
        num_filters=64,
        filter_size=3,
        stride=1,
        padding=1,
        act=paddle.activation.Relu())

    # The first block follows a ReLU-activated stem layer, so it skips its
    # own entry ReLU.
    conv = conv_block0(
        input=conv, group=2, num_channels=64, num_filters=128, entry_relu=False)
    conv = conv_block0(input=conv, group=3, num_channels=128, num_filters=256)
    conv = conv_block0(input=conv, group=4, num_channels=256, num_filters=728)

    for group in range(5, 13):
        conv = conv_block1(
            input=conv, group=group, num_channels=728, num_filters=728)

    conv = conv_block0(
        input=conv,
        group=13,
        num_channels=728,
        num_filters=728,
        num_filters2=1024)
    conv = img_separable_conv_bn(
        name='xception_conv14',
        input=conv,
        num_channels=1024,
        num_out_channels=1536,
        filter_size=3,
        stride=1,
        padding=1,
        act=paddle.activation.Relu())
    conv = img_separable_conv_bn(
        name='xception_conv15',
        input=conv,
        num_channels=1536,
        num_out_channels=2048,
        filter_size=3,
        stride=1,
        padding=1,
        act=paddle.activation.Relu())

    pool = paddle.layer.img_pool(
        name='xception_global_pool',
        input=conv,
        pool_size=7,
        stride=1,
        num_channels=2048,
        pool_type=paddle.pooling.CudnnAvg())
    return paddle.layer.fc(
        name='xception_fc',
        input=pool,
        size=class_dim,
        act=paddle.activation.Softmax())