diff --git a/image_classification/.gitignore b/image_classification/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..dc7c62b06287ad333dd41082e566b0553d3a5341 --- /dev/null +++ b/image_classification/.gitignore @@ -0,0 +1,8 @@ +*.pyc +train.log +output +data/cifar-10-batches-py/ +data/cifar-10-python.tar.gz +data/*.txt +data/*.list +data/mean.meta diff --git a/image_classification/README.md b/image_classification/README.md index 77e5b5afa279823200b9127c554f076a86488149..032b56c8c27c6006b18e9f0d15e8501d67c67dcc 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -1 +1,271 @@ -TODO: Write about https://github.com/PaddlePaddle/Paddle/tree/develop/demo/image_classification +图像分类 +======= + +## 背景介绍 + +图像分类是计算机视觉中的一个核心问题,图像分类是根据图像的语义信息将不同类别图像区分开来的方法。一般来说,图像分类包括训练和预测两个阶段。在训练阶段,输入训练图片集合和每一张训练图片对应的标签,计算机学习得到预测函数。在测试阶段,输入无标签的测试图片,计算机输出测试图片所属的类别标签。 + +在图像分类任务中,如何提取图像的特征是至关重要的。图像的颜色、纹理、形状各自描述了图像的视觉特性,但各自丢失了一部分原始图像中的信息。基于深度学习的图像分类方法,利用图像像素信息作为输入,最大程度上保留了输入图像的所有信息;与此同时,采用卷积神经网络进行特征的提取和高层抽象,从而得到远超过传统方法的分类性能。 + +## 效果展示 + +图像分类包括通用图像分类、细粒度图像分类等。下图展示了通用图像分类效果,即模型可以正确识别图像上的主要物体。 + +

+
+图1. 通用物体分类展示 +

+ + +下图展示了细粒度图像分类-花卉识别的效果,要求模型可以正确识别花的类别。 + + +

+
+图2. 细粒度图像分类展示 +

+ + +一个好的模型即要对不同类别识别正确,同时也应该能够对变形、扰动后的图像正确识别,下图展示了一些图像的扰动,较好的模型会像人一样能够正确识别。 + +

+
+图3. 扰动图片展示 +

+ +## 模型概览 + +### VGG + +[VGG](https://arxiv.org/abs/1405.3531) 模型的核心是五组卷积操作,每两组之间做max-pooling空间降维。同一组内采用多次连续的3X3卷积,卷积核的数目由较浅组的64增多到最深组的512,同一组内的卷积核数目是一样的。卷积之后接两层全连接层,之后是分类层。VGG模型的计算量较大,收敛较慢。 + +### ResNet + +[ResNet](https://arxiv.org/abs/1512.03385) 是2015年ImageNet分类定位、检测比赛的冠军。针对训练卷积神经网络时加深网络导致准确度下降的问题,提出了采用残差学习。在已有设计思路(Batch Norm, 小卷积核,全卷积网络)的基础上,引入了残差模块。每个残差模块包含两条路径,其中一条路径是输入特征的直连通路,另一条路径对该特征做两到三次卷积操作得到该特征的残差,最后再将两条路径上的特征相加。ResNet成功的训练了上百乃至近千层的卷积神经网络,训练时收敛快,速度也较VGG有所提升。 + + +## 数据准备 + +### 数据介绍与下载 + +在本教程中,我们使用[CIFAR-10]()数据集训练一个卷积神经网络。CIFAR-10数据集包含60,000张32x32的彩色图片,10个类别,每个类包含6,000张。其中50,000张图片作为训练集,10000张作为测试集。下图从每个类别中随机抽取了10张图片,展示了所有的类别。 + +

+
+图3. CIFAR-10数据集 +

+ + +执行下面命令下载数据,同时,会基于训练集计算图像均值,在训练阶段,输入数据会基于该均值做预处理,再传输给系统。 + +```bash +./data/get_data.sh +``` + +### 数据提供器 + +我们使用Python接口传递数据给系统,下面 `dataprovider.py` 针对Cfiar-10数据给出了完整示例。 + +`initializer` 函数进行dataprovider的初始化,这里加载图像的均值,定义了输入image和label两个字段的类型。 + +`process` 函数将数据逐条传输给系统,在图像分类做可以完整数据扰动操作,再传输给PaddlePaddle。这里将原始图片减去均值后传输给系统。 + + +```python +def initializer(settings, mean_path, is_train, **kwargs): + settings.is_train = is_train + settings.input_size = 3 * 32 * 32 + settings.mean = np.load(mean_path)['mean'] + settings.input_types = { + 'image': dense_vector(settings.input_size), + 'label': integer_value(10) + } + + +@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM) +def process(settings, file_list): + with open(file_list, 'r') as fdata: + for fname in fdata: + fo = open(fname.strip(), 'rb') + batch = cPickle.load(fo) + fo.close() + images = batch['data'] + labels = batch['labels'] + for im, lab in zip(images, labels): + im = im - settings.mean + yield { + 'image': im.astype('float32'), + 'label': int(lab) + } +``` + +## 模型配置说明 + +### 数据定义 + +在模型配置中,定义通过 `define_py_data_sources2` 从 dataprovider 中读入数据, 其中 args 指定均值文件的路径。 + +```python +define_py_data_sources2( + train_list='data/train.list', + test_list='data/test.list', + module='dataprovider', + obj='process', + args={'mean_path': 'data/mean.meta'}) +``` + +### 算法配置 + +在模型配置中,通过 `seetings` 设置训练使用的优化算法,这里指定batch size 、初始学习率、momentum以及L2正则。 + +```python +settings( + batch_size=128, + learning_rate=0.1 / 128.0, + learning_method=MomentumOptimizer(0.9), + regularization=L2Regularization(0.0005 * 128)) +``` + +### 模型结构 + +在模型概览部分已经介绍了VGG和ResNet模型,本教程中我们提供了这两个模型的网络配置。 +下面是VGG模型结构,在Cifar-10数据集上,卷积部分引入了Batch Norm和Dropout操作。 + +1. 首先预定义了一组卷积网络,即conv_block, 所使用的 `img_conv_group` 是我们预定义的一个模块,由若干组 `Conv->BatchNorm->Relu->Dropout` 和 一组 `Pooling` 组成,其中卷积操作采用3x3的卷积核。下面定义中根据 groups 决定是几次连续的卷积操作。 + +2. 五组卷积操作,即 5个conv_block。 第一、二组采用两次连续的卷积操作。第三、四、五组采用三次连续的卷积操作。 + +3. 由两层512维的全连接网络和一个分类层组成。 + + +```python +def vgg_bn_drop(input, num_channels): + def conv_block(ipt, num_filter, groups, dropouts, num_channels_=None): + return img_conv_group( + input=ipt, + num_channels=num_channels_, + pool_size=2, + pool_stride=2, + conv_num_filter=[num_filter] * groups, + conv_filter_size=3, + conv_act=ReluActivation(), + conv_with_batchnorm=True, + conv_batchnorm_drop_rate=dropouts, + pool_type=MaxPooling()) + + tmp = conv_block(input, 64, 2, [0.3, 0], num_channels) + tmp = conv_block(tmp, 128, 2, [0.4, 0]) + tmp = conv_block(tmp, 256, 3, [0.4, 0.4, 0]) + tmp = conv_block(tmp, 512, 3, [0.4, 0.4, 0]) + tmp = conv_block(tmp, 512, 3, [0.4, 0.4, 0]) + + tmp = dropout_layer(input=tmp, dropout_rate=0.5) + tmp = fc_layer( + input=tmp, + size=512, + act=LinearActivation()) + tmp = batch_norm_layer(input=tmp, + act=ReluActivation(), + layer_attr=ExtraAttr(drop_rate=0.5)) + tmp = fc_layer( + input=tmp, + size=512, + act=LinearActivation()) + tmp = fc_layer(input=tmp, size=10, act=SoftmaxActivation()) + return tmp + +``` + +## 模型训练 + +``` bash +sh train.sh +``` + +执行脚本 train.sh 进行模型训练, 其中指定了总共需要执行500个pass。 + +在第一行中我们载入用于定义网络的函数。 + +配置创建完毕后,可以运行脚本train.sh来训练模型。 + + +```bash +#cfg=models/resnet.py +cfg=models/vgg.py +output=./output +log=train.log + +paddle train \ + --config=$cfg \ + --use_gpu=true \ + --trainer_count=1 \ + --log_period=100 \ + --save_dir=$output \ + 2>&1 | tee $log +``` + +- `--config=$cfg` : 指定配置文件,默认是 `models/vgg.py`。 +- `--use_gpu=true` : 指定使用GPU训练,若使用CPU,设置为false。 +- `--trainer_count=1` : 指定线程个数或GPU个数。 +- `--log_period=100` : 指定日志打印的batch间隔。 +- `--save_dir=$output` : 指定模型存储路径。 + +一轮训练log示例如下所示,经过1个pass, 训练集上平均error为classification_error_evaluator=0.79958 ,测试集上平均error为 classification_error_evaluator=0.7858 。 + +```text +I1226 12:33:20.257822 25576 TrainerInternal.cpp:165] Batch=300 samples=38400 AvgCost=2.07708 CurrentCost=1.96158 Eval: classification_error_evaluator=0.81151 CurrentEval: classification_error_evaluator=0.789297 +.........I1226 12:33:37.720484 25576 TrainerInternal.cpp:181] Pass=0 Batch=391 samples=50000 AvgCost=2.03348 Eval: classification_error_evaluator=0.79958 +I1226 12:33:42.413450 25576 Tester.cpp:115] Test samples=10000 cost=1.99246 Eval: classification_error_evaluator=0.7858 +``` + + + +下图是训练的分类错误率曲线图: + +
![Training and testing curves.](image/plot.png)
+ +## 模型应用 + +在训练完成后,模型及参数会被保存在路径`./cifar_vgg_model/pass-%05d`下。例如第300个pass的模型会被保存在`./cifar_vgg_model/pass-00299`。 + +要对一个图片的进行分类预测,我们可以使用`predict.sh`,该脚本将输出预测分类的标签: + +``` +sh predict.sh +``` + +predict.sh: +``` +model=cifar_vgg_model/pass-00299/ +image=data/cifar-out/test/airplane/seaplane_s_000978.png +use_gpu=1 +python prediction.py $model $image $use_gpu +``` + +## 练习 +在CUB-200数据集上使用VGG模型训练一个鸟类图片分类模型。相关的鸟类数据集可以从如下地址下载,其中包含了200种鸟类的照片(主要来自北美洲)。 + + + + +## 细节探究 +### 卷积神经网络 +卷积神经网络是一种使用卷积层的前向神经网络,很适合构建用于理解图片内容的模型。一个典型的神经网络如下图所示: + +![Convolutional Neural Network](image/lenet.png) + +一个卷积神经网络包含如下层: + +- 卷积层:通过卷积操作从图片或特征图中提取特征 +- 池化层:使用max-pooling对特征图下采样 +- 全连接层:使输入层到隐藏层的神经元是全部连接的。 + +卷积神经网络在图片分类上有着惊人的性能,这是因为它发掘出了图片的两类重要信息:局部关联性质和空间不变性质。通过交替使用卷积和池化处理, 卷积神经网络能够很好的表示这两类信息。 + +关于如何定义网络中的层,以及如何在层之间进行连接,请参考Layer文档。 + + +## 参考文献 + +[1]. K. Chatfield, K. Simonyan, A. Vedaldi, A. Zisserman. Return of the Devil in the Details: Delving Deep into Convolutional Nets. BMVC, 2014。 +[2]. K. He, X. Zhang, S. Ren, J. Sun. Deep Residual Learning for Image Recognition. CVPR 2016. diff --git a/image_classification/data/cifar10.py b/image_classification/data/cifar10.py new file mode 100755 index 0000000000000000000000000000000000000000..965757296d4a5bbf9b70c27ae36b4565eaa3c60a --- /dev/null +++ b/image_classification/data/cifar10.py @@ -0,0 +1,43 @@ +import os +import numpy as np +import cPickle + +DATA = "cifar-10-batches-py" +CHANNEL = 3 +HEIGHT = 32 +WIDTH = 32 + +def create_mean(dataset): + if not os.path.isfile("mean.meta"): + mean = np.zeros(CHANNEL * HEIGHT * WIDTH) + num = 0 + for f in dataset: + batch = np.load(f) + mean += batch['data'].sum(0) + num += len(batch['data']) + mean /= num + print mean.size + data = {"mean": mean, "size": mean.size} + cPickle.dump(data, open("mean.meta", 'w'), protocol=cPickle.HIGHEST_PROTOCOL) + + +def create_data(): + train_set = [DATA + "/data_batch_%d" % (i + 1) for i in xrange(0,5)] + test_set = [DATA + "/test_batch"] + + # create mean values + create_mean(train_set) + + # create dataset lists + if not os.path.isfile("train.txt"): + train = ["data/" + i for i in train_set] + open("train.txt", "w").write("\n".join(train)) + open("train.list", "w").write("\n".join(["data/train.txt"])) + + if not os.path.isfile("text.txt"): + test = ["data/" + i for i in test_set] + open("test.txt", "w").write("\n".join(test)) + open("test.list", "w").write("\n".join(["data/test.txt"])) + +if __name__ == '__main__': + create_data() diff --git a/image_classification/data/get_data.sh b/image_classification/data/get_data.sh new file mode 100755 index 0000000000000000000000000000000000000000..519521ffc5b8e3b5100392c32bf5d6b494d7907f --- /dev/null +++ b/image_classification/data/get_data.sh @@ -0,0 +1,19 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz +tar zxf cifar-10-python.tar.gz +rm cifar-10-python.tar.gz + +python cifar10.py diff --git a/image_classification/dataprovider.py b/image_classification/dataprovider.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ba430c63539e17214bea5102bb91b90b226519 --- /dev/null +++ b/image_classification/dataprovider.py @@ -0,0 +1,44 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import cPickle + +from paddle.trainer.PyDataProvider2 import * + +def initializer(settings, mean_path, is_train, **kwargs): + settings.is_train = is_train + settings.input_size = 3 * 32 * 32 + settings.mean = np.load(mean_path)['mean'] + settings.input_types = { + 'image': dense_vector(settings.input_size), + 'label': integer_value(10) + } + + +@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM) +def process(settings, file_list): + with open(file_list, 'r') as fdata: + for fname in fdata: + fo = open(fname.strip(), 'rb') + batch = cPickle.load(fo) + fo.close() + images = batch['data'] + labels = batch['labels'] + for im, lab in zip(images, labels): + im = im - settings.mean + yield { + 'image': im.astype('float32'), + 'label': int(lab) + } diff --git a/image_classification/image/cifar.png b/image_classification/image/cifar.png new file mode 100644 index 0000000000000000000000000000000000000000..f54a0c58837cb3385b32dc57d02cec92666ef0f1 Binary files /dev/null and b/image_classification/image/cifar.png differ diff --git a/image_classification/image/flowers.png b/image_classification/image/flowers.png new file mode 100644 index 0000000000000000000000000000000000000000..04245cef60fe7126ae4c92ba8085273965078bee Binary files /dev/null and b/image_classification/image/flowers.png differ diff --git a/image_classification/image/image_classification.png b/image_classification/image/image_classification.png new file mode 100644 index 0000000000000000000000000000000000000000..14f255805081c1b4fab27eaf336fd389fa93ca19 Binary files /dev/null and b/image_classification/image/image_classification.png differ diff --git a/image_classification/image/lenet.png b/image_classification/image/lenet.png new file mode 100644 index 0000000000000000000000000000000000000000..1e6f2b32bad797f3fccb929c72a121fc935b0cbb Binary files /dev/null and b/image_classification/image/lenet.png differ diff --git a/image_classification/image/plot.png b/image_classification/image/plot.png new file mode 100644 index 0000000000000000000000000000000000000000..a31f99791c670e18bb8c62b7604ec8cb0284ffb4 Binary files /dev/null and b/image_classification/image/plot.png differ diff --git a/image_classification/image/variations.png b/image_classification/image/variations.png new file mode 100644 index 0000000000000000000000000000000000000000..2134587e4b4be15a92c4b400d1b67f021fa28fa6 Binary files /dev/null and b/image_classification/image/variations.png differ diff --git a/image_classification/models/resnet.py b/image_classification/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ba732cf810e348343a7a887d87bbe7ad0956790e --- /dev/null +++ b/image_classification/models/resnet.py @@ -0,0 +1,134 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + ch_in=None, + active_type=ReluActivation()): + tmp = img_conv_layer( + input=input, + filter_size=filter_size, + num_channels=ch_in, + num_filters=ch_out, + stride=stride, + padding=padding, + act=LinearActivation(), + bias_attr=False) + return batch_norm_layer(input=tmp, act=active_type) + + +def shortcut(ipt, n_in, n_out, stride): + if n_in != n_out: + return conv_bn_layer(ipt, n_out, 1, stride=stride, LinearActivation()) + else: + return ipt + +def basicblock(ipt, ch_out, stride): + ch_in = ipt.num_filter + tmp = conv_bn_layer(ipt, ch_out, 3, stride, 1) + tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, LinearActivation()) + short = shortcut(ipt, ch_in, ch_out, stride) + return addto_layer(input=[input, short], act=ReluActivation()) + +def bottleneck(ipt, ch_out, stride): + ch_in = ipt.num_filter + tmp = conv_bn_layer(ipt, ch_out, 1, stride, 0) + tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1) + tmp = conv_bn_layer(tmp, ch_out * 4, 1, 1, 0, LinearActivation()) + short = shortcut(ipt, ch_in, ch_out, stride) + return addto_layer(input=[input, short], act=ReluActivation()) + +def layer_warp(block_func, ipt, features, count, stride): + tmp = block_func(tmp, features, stride) + for i in range(1, count): + tmp = block_func(tmp, features, 1) + return tmp + +def resnet_imagenet(ipt, depth=50): + cfg = {18 : ([2,2,2,1], basicblock), + 34 : ([3,4,6,3], basicblock), + 50 : ([3,4,6,3], bottleneck), + 101: ([3,4,23,3], bottleneck), + 152: ([3,8,36,3], bottleneck)} + stages, block_func = cfg[depth] + tmp = conv_bn_layer(ipt, + ch_in=3, + ch_out=64, + filter_size=7, + stride=2, + padding=3) + tmp = img_pool_layer(input=tmp, pool_size=3, stride=2) + tmp = layer_warp(block_func, tmp, 64, stages[0], 1) + tmp = layer_warp(block_func, tmp, 128, stages[1], 2) + tmp = layer_warp(block_func, tmp, 256, stages[2], 2) + tmp = layer_warp(block_func, tmp, 512, stages[3], 2) + tmp = img_pool_layer(input=tmp, + pool_size=7, + stride=1, + pool_type=AvgPooling()) + + tmp = fc_layer(input=tmp, size=1000, act=SoftmaxActivation()) + return tmp + +def resnet_cifar10(ipt, depth=56): + assert((depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110, 1202') + n = (depth - 2) / 6 + nStages = {16, 64, 128} + tmp = conv_bn_layer(ipt, + ch_in=3, + ch_out=16, + filter_size=3, + stride=1, + padding=1) + tmp = layer_warp(basicblock, tmp, 16, n) + tmp = layer_warp(basicblock, tmp, 32, n, 2) + tmp = layer_warp(basicblock, tmp, 64, n, 2) + tmp = img_pool_layer(input=tmp, + pool_size=8, + stride=1, + pool_type=AvgPooling()) + tmp = fc_layer(input=tmp, size=10, act=SoftmaxActivation()) + return tmp + + +is_predict = get_config_arg("is_predict", bool, False) +if not is_predict: + args = {'meta': 'data/mean.meta'} + define_py_data_sources2( + train_list='data/train.list', + test_list='data/test.list', + module='dataprovider', + obj='process', + args=args) + +settings( + batch_size=128, + learning_rate=0.1 / 128.0, + learning_method=MomentumOptimizer(0.9), + regularization=L2Regularization(0.0005 * 128)) + +data_size = 3 * 32 * 32 +class_num = 10 +data = data_layer(name='image', size=data_size) +out = resnet_cifar10(data, depth=50) +if not is_predict: + lbl = data_layer(name="label", size=class_num) + outputs(classification_cost(input=out, label=lbl)) +else: + outputs(out) diff --git a/image_classification/models/vgg.py b/image_classification/models/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..25995c6633ab7ad55121eab90ebeaba5f9d4de6d --- /dev/null +++ b/image_classification/models/vgg.py @@ -0,0 +1,75 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +def vgg_bn_drop(input, num_channels): + def conv_block(ipt, num_filter, groups, dropouts, num_channels_=None): + return img_conv_group( + input=ipt, + num_channels=num_channels_, + pool_size=2, + pool_stride=2, + conv_num_filter=[num_filter] * groups, + conv_filter_size=3, + conv_act=ReluActivation(), + conv_with_batchnorm=True, + conv_batchnorm_drop_rate=dropouts, + pool_type=MaxPooling()) + + tmp = conv_block(input, 64, 2, [0.3, 0], num_channels) + tmp = conv_block(tmp, 128, 2, [0.4, 0]) + tmp = conv_block(tmp, 256, 3, [0.4, 0.4, 0]) + tmp = conv_block(tmp, 512, 3, [0.4, 0.4, 0]) + tmp = conv_block(tmp, 512, 3, [0.4, 0.4, 0]) + + tmp = dropout_layer(input=tmp, dropout_rate=0.5) + tmp = fc_layer( + input=tmp, + size=512, + act=LinearActivation()) + tmp = batch_norm_layer(input=tmp, + act=ReluActivation(), + layer_attr=ExtraAttr(drop_rate=0.5)) + tmp = fc_layer( + input=tmp, + size=512, + act=LinearActivation()) + tmp = fc_layer(input=tmp, size=10, act=SoftmaxActivation()) + return tmp + +is_predict = get_config_arg("is_predict", bool, False) +if not is_predict: + define_py_data_sources2( + train_list='data/train.list', + test_list='data/test.list', + module='dataprovider', + obj='process', + args={'mean_path': 'data/mean.meta'}) + +settings( + batch_size=128, + learning_rate=0.1 / 128.0, + learning_method=MomentumOptimizer(0.9), + regularization=L2Regularization(0.0005 * 128)) + +data_size = 3 * 32 * 32 +class_num = 10 +data = data_layer(name='image', size=data_size) +out = vgg_bn_drop(data, 3) +if not is_predict: + lbl = data_layer(name="label", size=class_num) + outputs(classification_cost(input=out, label=lbl)) +else: + outputs(out) diff --git a/image_classification/predict.sh b/image_classification/predict.sh new file mode 100755 index 0000000000000000000000000000000000000000..3bba2a94f70d3a687d8526bf10ba0554f9a1bd38 --- /dev/null +++ b/image_classification/predict.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +model=output/pass-00299/ +image=data/cifar-out/test/airplane/seaplane_s_000978.png +use_gpu=1 +python prediction.py $model $image $use_gpu diff --git a/image_classification/prediction.py b/image_classification/prediction.py new file mode 100755 index 0000000000000000000000000000000000000000..519b864221c490de71fbeeebcc607935540db208 --- /dev/null +++ b/image_classification/prediction.py @@ -0,0 +1,159 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os, sys +import numpy as np +import logging +from PIL import Image +from optparse import OptionParser + +import paddle.utils.image_util as image_util + +from py_paddle import swig_paddle, DataProviderConverter +from paddle.trainer.PyDataProvider2 import dense_vector +from paddle.trainer.config_parser import parse_config + +logging.basicConfig( + format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s') +logging.getLogger().setLevel(logging.INFO) + + +class ImageClassifier(): + def __init__(self, + train_conf, + use_gpu=True, + model_dir=None, + resize_dim=None, + crop_dim=None, + mean_file=None, + oversample=False, + is_color=True): + """ + train_conf: network configure. + model_dir: string, directory of model. + resize_dim: int, resized image size. + crop_dim: int, crop size. + mean_file: string, image mean file. + oversample: bool, oversample means multiple crops, namely five + patches (the four corner patches and the center + patch) as well as their horizontal reflections, + ten crops in all. + """ + self.train_conf = train_conf + self.model_dir = model_dir + if model_dir is None: + self.model_dir = os.path.dirname(train_conf) + + self.resize_dim = resize_dim + self.crop_dims = [crop_dim, crop_dim] + self.oversample = oversample + self.is_color = is_color + + self.transformer = image_util.ImageTransformer(is_color=is_color) + self.transformer.set_transpose((2, 0, 1)) + + self.mean_file = mean_file + mean = np.load(self.mean_file)['data_mean'] + mean = mean.reshape(3, self.crop_dims[0], self.crop_dims[1]) + self.transformer.set_mean(mean) # mean pixel + gpu = 1 if use_gpu else 0 + conf_args = "is_test=1,use_gpu=%d,is_predict=1" % (gpu) + conf = parse_config(train_conf, conf_args) + swig_paddle.initPaddle("--use_gpu=%d" % (gpu)) + self.network = swig_paddle.GradientMachine.createFromConfigProto( + conf.model_config) + assert isinstance(self.network, swig_paddle.GradientMachine) + self.network.loadParameters(self.model_dir) + + data_size = 3 * self.crop_dims[0] * self.crop_dims[1] + slots = [dense_vector(data_size)] + self.converter = DataProviderConverter(slots) + + def get_data(self, img_path): + """ + 1. load image from img_path. + 2. resize or oversampling. + 3. transformer data: transpose, sub mean. + return K x H x W ndarray. + img_path: image path. + """ + image = image_util.load_image(img_path, self.is_color) + if self.oversample: + # image_util.resize_image: short side is self.resize_dim + image = image_util.resize_image(image, self.resize_dim) + image = np.array(image) + input = np.zeros( + (1, image.shape[0], image.shape[1], 3), dtype=np.float32) + input[0] = image.astype(np.float32) + input = image_util.oversample(input, self.crop_dims) + else: + image = image.resize(self.crop_dims, Image.ANTIALIAS) + input = np.zeros( + (1, self.crop_dims[0], self.crop_dims[1], 3), dtype=np.float32) + input[0] = np.array(image).astype(np.float32) + + data_in = [] + for img in input: + img = self.transformer.transformer(img).flatten() + data_in.append([img.tolist()]) + return data_in + + def forward(self, input_data): + in_arg = self.converter(input_data) + return self.network.forwardTest(in_arg) + + def forward(self, data, output_layer): + """ + input_data: py_paddle input data. + output_layer: specify the name of probability, namely the layer with + softmax activation. + return: the predicting probability of each label. + """ + input = self.converter(data) + self.network.forwardTest(input) + output = self.network.getLayerOutputs(output_layer) + # For oversampling, average predictions across crops. + # If not, the shape of output[name]: (1, class_number), + # the mean is also applicable. + return output[output_layer].mean(0) + + def predict(self, image=None, output_layer=None): + assert isinstance(image, basestring) + assert isinstance(output_layer, basestring) + data = self.get_data(image) + prob = self.forward(data, output_layer) + lab = np.argsort(-prob) + logging.info("Label of %s is: %d", image, lab[0]) + + +if __name__ == '__main__': + image_size = 32 + crop_size = 32 + multi_crop = True + config = "vgg_16_cifar.py" + output_layer = "__fc_layer_1__" + mean_path = "data/batches.meta" + model_path = sys.argv[1] + image = sys.argv[2] + use_gpu = bool(int(sys.argv[3])) + + obj = ImageClassifier( + train_conf=config, + model_dir=model_path, + resize_dim=image_size, + crop_dim=crop_size, + mean_file=mean_path, + use_gpu=use_gpu, + oversample=multi_crop) + obj.predict(image, output_layer) diff --git a/image_classification/train.sh b/image_classification/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..5fda4f92ee6a40328fc3924dc78c6f963550dc9e --- /dev/null +++ b/image_classification/train.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +#config=models/resnet.py +config=models/vgg.py +output=./output +log=train.log + +paddle train \ + --config=$config \ + --use_gpu=1 \ + --trainer_count=4 \ + --log_period=100 \ + --num_passes=300 \ + --save_dir=$output \ + 2>&1 | tee $log