From 8848164129b0e38898c7752915880d38f153edec Mon Sep 17 00:00:00 2001 From: wwhu Date: Thu, 1 Jun 2017 15:29:34 +0800 Subject: [PATCH] add doc and reorginize net output --- image_classification/README.md | 183 +++++++++++++++++++++++++++++- image_classification/alexnet.py | 6 +- image_classification/googlenet.py | 91 +++++++++++++-- image_classification/resnet.py | 12 +- image_classification/train.py | 18 +-- image_classification/vgg.py | 18 +-- 6 files changed, 290 insertions(+), 38 deletions(-) diff --git a/image_classification/README.md b/image_classification/README.md index a0990367..0010fe5b 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -1 +1,182 @@ -TBD +图像分类 +======================= + +这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet和ResNet模型进行图像分类。图像分类问题的描述和这四种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 + +## 数据格式 +reader.py定义了数据格式,它读取一个图像列表文件,并从中解析出图像路径和类别标签。 + +图像列表文件是一个文本文件,其中每一行由一个图像路径和类别标签构成,二者以跳格符(Tab)隔开。类别标签用整数表示,其最小值为0。下面给出一个图像列表文件的片段示例: + +``` +dataset_100/train_images/n03982430_23191.jpeg 1 +dataset_100/train_images/n04461696_23653.jpeg 7 +dataset_100/train_images/n02441942_3170.jpeg 8 +dataset_100/train_images/n03733281_31716.jpeg 2 +dataset_100/train_images/n03424325_240.jpeg 0 +dataset_100/train_images/n02643566_75.jpeg 8 +``` + +## 训练模型 + +### 初始化 + +在初始化阶段需要导入所用的包,并对PaddlePaddle进行初始化。 + +```python +import gzip +import paddle.v2 as paddle +import reader +import vgg +import resnet +import alexnet +import googlenet +import argparse +import os + +# PaddlePaddle init +paddle.init(use_gpu=False, trainer_count=1) +``` + +### 定义参数和输入 + +设置算法参数(如数据维度、类别数目和batch size等参数),定义数据输入层`image`和类别标签`lbl`。 + +```python +DATA_DIM = 3 * 224 * 224 +CLASS_DIM = 100 +BATCH_SIZE = 128 + +image = paddle.layer.data( + name="image", type=paddle.data_type.dense_vector(DATA_DIM)) +lbl = paddle.layer.data( + name="label", type=paddle.data_type.integer_value(CLASS_DIM)) +``` + 
+### 获得所用模型 + +这里可以选择使用AlexNet、VGG、GoogLeNet和ResNet模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。 + +1. 使用AlexNet模型 + +指定输入层`image`和类别数目`CLASS_DIM`后,可以通过下面的代码得到AlexNet的Softmax层。 + +```python +out = alexnet.alexnet(image, class_dim=CLASS_DIM) +``` + +2. 使用VGG模型 + +根据层数的不同,VGG分为VGG13、VGG16和VGG19。使用VGG16模型的代码如下: + +```python +out = vgg.vgg16(image, class_dim=CLASS_DIM) +``` + +类似地,VGG13和VGG19可以分别通过`vgg.vgg13`和`vgg.vgg19`方法获得。 + +3. 使用GoogLeNet模型 + +GoogLeNet在训练阶段使用两个辅助的分类器强化梯度信息并进行额外的正则化。因此`googlenet.googlenet`共返回三个Softmax层,如下面的代码所示: + +```python +out, out1, out2 = googlenet.googlenet(image, class_dim=CLASS_DIM) +loss1 = paddle.layer.cross_entropy_cost( + input=out1, label=lbl, coeff=0.3) +paddle.evaluator.classification_error(input=out1, label=lbl) +loss2 = paddle.layer.cross_entropy_cost( + input=out2, label=lbl, coeff=0.3) +paddle.evaluator.classification_error(input=out2, label=lbl) +extra_layers = [loss1, loss2] +``` + +对于两个辅助的输出,这里分别对其计算损失函数并评价错误率,然后将损失作为后文SGD的extra_layers。 + +4. 使用ResNet模型 + +ResNet模型可以通过下面的代码获取: + +```python +out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) +``` + +### 定义损失函数 + +```python +cost = paddle.layer.classification_cost(input=out, label=lbl) +``` + +### 创建参数和优化方法 + +```python +# Create parameters +parameters = paddle.parameters.create(cost) + +# Create optimizer +optimizer = paddle.optimizer.Momentum( + momentum=0.9, + regularization=paddle.optimizer.L2Regularization(rate=0.0005 * + BATCH_SIZE), + learning_rate=0.001 / BATCH_SIZE, + learning_rate_decay_a=0.1, + learning_rate_decay_b=128000 * 35, + learning_rate_schedule="discexp", ) +``` + +### 定义数据读取方法和事件处理程序 + +读取数据时需要分别指定训练集和验证集的图像列表文件,这里假设这两个文件分别为`train.list`和`val.list`。 + +```python +train_reader = paddle.batch( + paddle.reader.shuffle( + reader.train_reader('train.list'), + buf_size=1000), + batch_size=BATCH_SIZE) +test_reader = paddle.batch( + reader.test_reader('val.list'), + batch_size=BATCH_SIZE) + +# End batch and end pass event handler +def event_handler(event): + if 
isinstance(event, paddle.event.EndIteration): + if event.batch_id % 1 == 0: + print "\nPass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) + if isinstance(event, paddle.event.EndPass): + with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f: + parameters.to_tar(f) + + result = trainer.test(reader=test_reader) + print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) +``` + +### 定义训练方法 + +对于AlexNet、VGG和ResNet,可以按下面的代码定义训练方法: + +```python +# Create trainer +trainer = paddle.trainer.SGD( + cost=cost, + parameters=parameters, + update_equation=optimizer) +``` + +GoogLeNet有两个额外的输出层,因此需要指定`extra_layers`,如下所示: + +```python +# Create trainer +trainer = paddle.trainer.SGD( + cost=cost, + parameters=parameters, + update_equation=optimizer, + extra_layers=extra_layers) +``` + +### 开始训练 + +```python +trainer.train( + reader=train_reader, num_passes=200, event_handler=event_handler) +``` diff --git a/image_classification/alexnet.py b/image_classification/alexnet.py index eaa7a3dc..8aa53814 100644 --- a/image_classification/alexnet.py +++ b/image_classification/alexnet.py @@ -3,7 +3,7 @@ import paddle.v2 as paddle __all__ = ['alexnet'] -def alexnet(input): +def alexnet(input, class_dim=100): conv1 = paddle.layer.img_conv( input=input, filter_size=11, @@ -45,4 +45,6 @@ def alexnet(input): act=paddle.activation.Relu(), layer_attr=paddle.attr.Extra(drop_rate=0.5)) - return fc2 + out = paddle.layer.fc( + input=fc2, size=class_dim, act=paddle.activation.Softmax()) + return out diff --git a/image_classification/googlenet.py b/image_classification/googlenet.py index 60cfa9d4..2e4153cc 100644 --- a/image_classification/googlenet.py +++ b/image_classification/googlenet.py @@ -53,7 +53,69 @@ def inception(name, input, channels, filter1, filter3R, filter3, filter5R, return cat -def googlenet(input): +def inception2(name, input, channels, filter1, filter3R, filter3, filter5R, + filter5, proj): + cov1 = paddle.layer.img_conv( 
+ name=name + '_1', + input=input, + filter_size=1, + num_channels=channels, + num_filters=filter1, + stride=1, + padding=0) + + cov3r = paddle.layer.img_conv( + name=name + '_3r', + input=input, + filter_size=1, + num_channels=channels, + num_filters=filter3R, + stride=1, + padding=0) + cov3 = paddle.layer.img_conv( + name=name + '_3', + input=cov3r, + filter_size=3, + num_filters=filter3, + stride=1, + padding=1) + + cov5r = paddle.layer.img_conv( + name=name + '_5r', + input=input, + filter_size=1, + num_channels=channels, + num_filters=filter5R, + stride=1, + padding=0) + cov5 = paddle.layer.img_conv( + name=name + '_5', + input=cov5r, + filter_size=5, + num_filters=filter5, + stride=1, + padding=2) + + pool1 = paddle.layer.img_pool( + name=name + '_max', + input=input, + pool_size=3, + num_channels=channels, + stride=1, + padding=1) + covprj = paddle.layer.img_conv( + name=name + '_proj', + input=pool1, + filter_size=1, + num_filters=proj, + stride=1, + padding=0) + + cat = paddle.layer.concat(name=name, input=[cov1, cov3, cov5, covprj]) + return cat + + +def googlenet(input, class_dim=100): # stage 1 conv1 = paddle.layer.img_conv( name="conv1", @@ -85,23 +147,23 @@ def googlenet(input): name="pool2", input=conv2_2, pool_size=3, num_channels=192, stride=2) # stage 3 - ince3a = inception("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32) - ince3b = inception("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64) + ince3a = inception2("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32) + ince3b = inception2("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64) pool3 = paddle.layer.img_pool( name="pool3", input=ince3b, num_channels=480, pool_size=3, stride=2) # stage 4 - ince4a = inception("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64) - ince4b = inception("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64) - ince4c = inception("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64) - ince4d = inception("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64) - ince4e = inception("ince4e", 
ince4d, 528, 256, 160, 320, 32, 128, 128) + ince4a = inception2("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64) + ince4b = inception2("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64) + ince4c = inception2("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64) + ince4d = inception2("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64) + ince4e = inception2("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128) pool4 = paddle.layer.img_pool( name="pool4", input=ince4e, num_channels=832, pool_size=3, stride=2) # stage 5 - ince5a = inception("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128) - ince5b = inception("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128) + ince5a = inception2("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128) + ince5b = inception2("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128) pool5 = paddle.layer.img_pool( name="pool5", input=ince5b, @@ -114,6 +176,9 @@ def googlenet(input): layer_attr=paddle.attr.Extra(drop_rate=0.4), act=paddle.activation.Linear()) + out = paddle.layer.fc( + input=dropout, size=class_dim, act=paddle.activation.Softmax()) + # fc for output 1 pool_o1 = paddle.layer.img_pool( name="pool_o1", @@ -135,6 +200,8 @@ def googlenet(input): size=1024, layer_attr=paddle.attr.Extra(drop_rate=0.7), act=paddle.activation.Relu()) + out1 = paddle.layer.fc( + input=fc_o1, size=class_dim, act=paddle.activation.Softmax()) # fc for output 2 pool_o2 = paddle.layer.img_pool( @@ -157,5 +224,7 @@ def googlenet(input): size=1024, layer_attr=paddle.attr.Extra(drop_rate=0.7), act=paddle.activation.Relu()) + out2 = paddle.layer.fc( + input=fc_o2, size=class_dim, act=paddle.activation.Softmax()) - return dropout, fc_o1, fc_o2 + return out, out1, out2 diff --git a/image_classification/resnet.py b/image_classification/resnet.py index 1da44aad..7ef551b3 100644 --- a/image_classification/resnet.py +++ b/image_classification/resnet.py @@ -57,7 +57,7 @@ def layer_warp(block_func, input, features, count, stride): return conv -def resnet_imagenet(input, depth=50): 
+def resnet_imagenet(input, depth=50, class_dim=100): cfg = { 18: ([2, 2, 2, 1], basicblock), 34: ([3, 4, 6, 3], basicblock), @@ -75,10 +75,12 @@ def resnet_imagenet(input, depth=50): res4 = layer_warp(block_func, res3, 512, stages[3], 2) pool2 = paddle.layer.img_pool( input=res4, pool_size=7, stride=1, pool_type=paddle.pooling.Avg()) - return pool2 + out = paddle.layer.fc( + input=pool2, size=class_dim, act=paddle.activation.Softmax()) + return out -def resnet_cifar10(input, depth=32): +def resnet_cifar10(input, depth=32, class_dim=10): # depth should be one of 20, 32, 44, 56, 110, 1202 assert (depth - 2) % 6 == 0 n = (depth - 2) / 6 @@ -90,4 +92,6 @@ def resnet_cifar10(input, depth=32): res3 = layer_warp(basicblock, res2, 64, n, 2) pool = paddle.layer.img_pool( input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) - return pool + out = paddle.layer.fc( + input=pool, size=class_dim, act=paddle.activation.Softmax()) + return out diff --git a/image_classification/train.py b/image_classification/train.py index a8817c60..36135616 100755 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -35,31 +35,25 @@ def main(): extra_layers = None if args.model == 'alexnet': - net = alexnet.alexnet(image) + out = alexnet.alexnet(image, class_dim=CLASS_DIM) elif args.model == 'vgg13': - net = vgg.vgg13(image) + out = vgg.vgg13(image, class_dim=CLASS_DIM) elif args.model == 'vgg16': - net = vgg.vgg16(image) + out = vgg.vgg16(image, class_dim=CLASS_DIM) elif args.model == 'vgg19': - net = vgg.vgg19(image) + out = vgg.vgg19(image, class_dim=CLASS_DIM) elif args.model == 'resnet': - net = resnet.resnet_imagenet(image) + out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) elif args.model == 'googlenet': - net, fc_o1, fc_o2 = googlenet.googlenet(image) - out1 = paddle.layer.fc( - input=fc_o1, size=CLASS_DIM, act=paddle.activation.Softmax()) + out, out1, out2 = googlenet.googlenet(image, class_dim=CLASS_DIM) loss1 = 
paddle.layer.cross_entropy_cost( input=out1, label=lbl, coeff=0.3) paddle.evaluator.classification_error(input=out1, label=lbl) - out2 = paddle.layer.fc( - input=fc_o2, size=CLASS_DIM, act=paddle.activation.Softmax()) loss2 = paddle.layer.cross_entropy_cost( input=out2, label=lbl, coeff=0.3) paddle.evaluator.classification_error(input=out2, label=lbl) extra_layers = [loss1, loss2] - out = paddle.layer.fc( - input=net, size=CLASS_DIM, act=paddle.activation.Softmax()) cost = paddle.layer.classification_cost(input=out, label=lbl) # Create parameters diff --git a/image_classification/vgg.py b/image_classification/vgg.py index e21504ab..b272320b 100644 --- a/image_classification/vgg.py +++ b/image_classification/vgg.py @@ -17,7 +17,7 @@ import paddle.v2 as paddle __all__ = ['vgg13', 'vgg16', 'vgg19'] -def vgg(input, nums): +def vgg(input, nums, class_dim=100): def conv_block(input, num_filter, groups, num_channels=None): return paddle.networks.img_conv_group( input=input, @@ -48,19 +48,21 @@ def vgg(input, nums): size=fc_dim, act=paddle.activation.Relu(), layer_attr=paddle.attr.Extra(drop_rate=0.5)) - return fc2 + out = paddle.layer.fc( + input=fc2, size=class_dim, act=paddle.activation.Softmax()) + return out -def vgg13(input): +def vgg13(input, class_dim=100): nums = [2, 2, 2, 2, 2] - return vgg(input, nums) + return vgg(input, nums, class_dim) -def vgg16(input): +def vgg16(input, class_dim=100): nums = [2, 2, 3, 3, 3] - return vgg(input, nums) + return vgg(input, nums, class_dim) -def vgg19(input): +def vgg19(input, class_dim=100): nums = [2, 2, 4, 4, 4] - return vgg(input, nums) + return vgg(input, nums, class_dim) -- GitLab