提交 88481641 编写于 作者: W wwhu

add doc and reorginize net output

上级 6dd3895e
TBD
图像分类
=======================
这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet和ResNet模型进行图像分类。图像分类问题的描述和这四种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)
## 数据格式
reader.py定义了数据格式,它读取一个图像列表文件,并从中解析出图像路径和类别标签。
图像列表文件是一个文本文件,其中每一行由一个图像路径和类别标签构成,二者以跳格符(Tab)隔开。类别标签用整数表示,其最小值为0。下面给出一个图像列表文件的片段示例:
```
dataset_100/train_images/n03982430_23191.jpeg 1
dataset_100/train_images/n04461696_23653.jpeg 7
dataset_100/train_images/n02441942_3170.jpeg 8
dataset_100/train_images/n03733281_31716.jpeg 2
dataset_100/train_images/n03424325_240.jpeg 0
dataset_100/train_images/n02643566_75.jpeg 8
```
## 训练模型
### 初始化
在初始化阶段需要导入所用的包,并对PaddlePaddle进行初始化。
```python
import gzip
import paddle.v2 as paddle
import reader
import vgg
import resnet
import alexnet
import googlenet
import argparse
import os
# PaddlePaddle init
paddle.init(use_gpu=False, trainer_count=1)
```
### 定义参数和输入
设置算法参数(如数据维度、类别数目和batch size等参数),定义数据输入层`image`和类别标签`lbl`
```python
DATA_DIM = 3 * 224 * 224
CLASS_DIM = 100
BATCH_SIZE = 128
image = paddle.layer.data(
name="image", type=paddle.data_type.dense_vector(DATA_DIM))
lbl = paddle.layer.data(
name="label", type=paddle.data_type.integer_value(CLASS_DIM))
```
### 获得所用模型
这里可以选择使用AlexNet、VGG、GoogLeNet和ResNet模型中的一个模型进行图像分类。通过调用相应的方法可以获得网络最后的Softmax层。
1. 使用AlexNet模型
指定输入层`image`和类别数目`CLASS_DIM`后,可以通过下面的代码得到AlexNet的Softmax层。
```python
out = alexnet.alexnet(image, class_dim=CLASS_DIM)
```
2. 使用VGG模型
根据层数的不同,VGG分为VGG13、VGG16和VGG19。使用VGG16模型的代码如下:
```python
out = vgg.vgg16(image, class_dim=CLASS_DIM)
```
类似地,VGG13和VGG19可以分别通过`vgg.vgg13``vgg.vgg19`方法获得。
3. 使用GoogLeNet模型
GoogLeNet在训练阶段使用两个辅助的分类器强化梯度信息并进行额外的正则化。因此`googlenet.googlenet`共返回三个Softmax层,如下面的代码所示:
```python
out, out1, out2 = googlenet.googlenet(image, class_dim=CLASS_DIM)
loss1 = paddle.layer.cross_entropy_cost(
input=out1, label=lbl, coeff=0.3)
paddle.evaluator.classification_error(input=out1, label=lbl)
loss2 = paddle.layer.cross_entropy_cost(
input=out2, label=lbl, coeff=0.3)
paddle.evaluator.classification_error(input=out2, label=lbl)
extra_layers = [loss1, loss2]
```
对于两个辅助的输出,这里分别对其计算损失函数并评价错误率,然后将损失作为后文SGD的extra_layers。
4. 使用ResNet模型
ResNet模型可以通过下面的代码获取:
```python
out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM)
```
### 定义损失函数
```python
cost = paddle.layer.classification_cost(input=out, label=lbl)
```
### 创建参数和优化方法
```python
# Create parameters
parameters = paddle.parameters.create(cost)
# Create optimizer
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
regularization=paddle.optimizer.L2Regularization(rate=0.0005 *
BATCH_SIZE),
learning_rate=0.001 / BATCH_SIZE,
learning_rate_decay_a=0.1,
learning_rate_decay_b=128000 * 35,
learning_rate_schedule="discexp", )
```
### 定义数据读取方法和事件处理程序
读取数据时需要分别指定训练集和验证集的图像列表文件,这里假设这两个文件分别为`train.list``val.list`
```python
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.test_reader('train.list'),
buf_size=1000),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
reader.train_reader('val.list'),
batch_size=BATCH_SIZE)
# End batch and end pass event handler
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 1 == 0:
print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
if isinstance(event, paddle.event.EndPass):
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
parameters.to_tar(f)
result = trainer.test(reader=test_reader)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
```
### 定义训练方法
对于AlexNet、VGG和ResNet,可以按下面的代码定义训练方法:
```python
# Create trainer
trainer = paddle.trainer.SGD(
cost=cost,
parameters=parameters,
update_equation=optimizer)
```
GoogLeNet有两个额外的输出层,因此需要指定`extra_layers`,如下所示:
```python
# Create trainer
trainer = paddle.trainer.SGD(
cost=cost,
parameters=parameters,
update_equation=optimizer,
extra_layers=extra_layers)
```
### 开始训练
```python
trainer.train(
reader=train_reader, num_passes=200, event_handler=event_handler)
```
......@@ -3,7 +3,7 @@ import paddle.v2 as paddle
__all__ = ['alexnet']
def alexnet(input):
def alexnet(input, class_dim=100):
conv1 = paddle.layer.img_conv(
input=input,
filter_size=11,
......@@ -45,4 +45,6 @@ def alexnet(input):
act=paddle.activation.Relu(),
layer_attr=paddle.attr.Extra(drop_rate=0.5))
return fc2
out = paddle.layer.fc(
input=fc2, size=class_dim, act=paddle.activation.Softmax())
return out
......@@ -53,7 +53,69 @@ def inception(name, input, channels, filter1, filter3R, filter3, filter5R,
return cat
def googlenet(input):
def inception2(name, input, channels, filter1, filter3R, filter3, filter5R,
filter5, proj):
cov1 = paddle.layer.img_conv(
name=name + '_1',
input=input,
filter_size=1,
num_channels=channels,
num_filters=filter1,
stride=1,
padding=0)
cov3r = paddle.layer.img_conv(
name=name + '_3r',
input=input,
filter_size=1,
num_channels=channels,
num_filters=filter3R,
stride=1,
padding=0)
cov3 = paddle.layer.img_conv(
name=name + '_3',
input=cov3r,
filter_size=3,
num_filters=filter3,
stride=1,
padding=1)
cov5r = paddle.layer.img_conv(
name=name + '_5r',
input=input,
filter_size=1,
num_channels=channels,
num_filters=filter5R,
stride=1,
padding=0)
cov5 = paddle.layer.img_conv(
name=name + '_5',
input=cov5r,
filter_size=5,
num_filters=filter5,
stride=1,
padding=2)
pool1 = paddle.layer.img_pool(
name=name + '_max',
input=input,
pool_size=3,
num_channels=channels,
stride=1,
padding=1)
covprj = paddle.layer.img_conv(
name=name + '_proj',
input=pool1,
filter_size=1,
num_filters=proj,
stride=1,
padding=0)
cat = paddle.layer.concat(name=name, input=[cov1, cov3, cov5, covprj])
return cat
def googlenet(input, class_dim=100):
# stage 1
conv1 = paddle.layer.img_conv(
name="conv1",
......@@ -85,23 +147,23 @@ def googlenet(input):
name="pool2", input=conv2_2, pool_size=3, num_channels=192, stride=2)
# stage 3
ince3a = inception("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32)
ince3b = inception("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64)
ince3a = inception2("ince3a", pool2, 192, 64, 96, 128, 16, 32, 32)
ince3b = inception2("ince3b", ince3a, 256, 128, 128, 192, 32, 96, 64)
pool3 = paddle.layer.img_pool(
name="pool3", input=ince3b, num_channels=480, pool_size=3, stride=2)
# stage 4
ince4a = inception("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64)
ince4b = inception("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64)
ince4c = inception("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64)
ince4d = inception("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64)
ince4e = inception("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128)
ince4a = inception2("ince4a", pool3, 480, 192, 96, 208, 16, 48, 64)
ince4b = inception2("ince4b", ince4a, 512, 160, 112, 224, 24, 64, 64)
ince4c = inception2("ince4c", ince4b, 512, 128, 128, 256, 24, 64, 64)
ince4d = inception2("ince4d", ince4c, 512, 112, 144, 288, 32, 64, 64)
ince4e = inception2("ince4e", ince4d, 528, 256, 160, 320, 32, 128, 128)
pool4 = paddle.layer.img_pool(
name="pool4", input=ince4e, num_channels=832, pool_size=3, stride=2)
# stage 5
ince5a = inception("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128)
ince5b = inception("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128)
ince5a = inception2("ince5a", pool4, 832, 256, 160, 320, 32, 128, 128)
ince5b = inception2("ince5b", ince5a, 832, 384, 192, 384, 48, 128, 128)
pool5 = paddle.layer.img_pool(
name="pool5",
input=ince5b,
......@@ -114,6 +176,9 @@ def googlenet(input):
layer_attr=paddle.attr.Extra(drop_rate=0.4),
act=paddle.activation.Linear())
out = paddle.layer.fc(
input=dropout, size=class_dim, act=paddle.activation.Softmax())
# fc for output 1
pool_o1 = paddle.layer.img_pool(
name="pool_o1",
......@@ -135,6 +200,8 @@ def googlenet(input):
size=1024,
layer_attr=paddle.attr.Extra(drop_rate=0.7),
act=paddle.activation.Relu())
out1 = paddle.layer.fc(
input=fc_o1, size=class_dim, act=paddle.activation.Softmax())
# fc for output 2
pool_o2 = paddle.layer.img_pool(
......@@ -157,5 +224,7 @@ def googlenet(input):
size=1024,
layer_attr=paddle.attr.Extra(drop_rate=0.7),
act=paddle.activation.Relu())
out2 = paddle.layer.fc(
input=fc_o2, size=class_dim, act=paddle.activation.Softmax())
return dropout, fc_o1, fc_o2
return out, out1, out2
......@@ -57,7 +57,7 @@ def layer_warp(block_func, input, features, count, stride):
return conv
def resnet_imagenet(input, depth=50):
def resnet_imagenet(input, depth=50, class_dim=100):
cfg = {
18: ([2, 2, 2, 1], basicblock),
34: ([3, 4, 6, 3], basicblock),
......@@ -75,10 +75,12 @@ def resnet_imagenet(input, depth=50):
res4 = layer_warp(block_func, res3, 512, stages[3], 2)
pool2 = paddle.layer.img_pool(
input=res4, pool_size=7, stride=1, pool_type=paddle.pooling.Avg())
return pool2
out = paddle.layer.fc(
input=pool2, size=class_dim, act=paddle.activation.Softmax())
return out
def resnet_cifar10(input, depth=32):
def resnet_cifar10(input, depth=32, class_dim=10):
# depth should be one of 20, 32, 44, 56, 110, 1202
assert (depth - 2) % 6 == 0
n = (depth - 2) / 6
......@@ -90,4 +92,6 @@ def resnet_cifar10(input, depth=32):
res3 = layer_warp(basicblock, res2, 64, n, 2)
pool = paddle.layer.img_pool(
input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg())
return pool
out = paddle.layer.fc(
input=pool, size=class_dim, act=paddle.activation.Softmax())
return out
......@@ -35,31 +35,25 @@ def main():
extra_layers = None
if args.model == 'alexnet':
net = alexnet.alexnet(image)
out = alexnet.alexnet(image, class_dim=CLASS_DIM)
elif args.model == 'vgg13':
net = vgg.vgg13(image)
out = vgg.vgg13(image, class_dim=CLASS_DIM)
elif args.model == 'vgg16':
net = vgg.vgg16(image)
out = vgg.vgg16(image, class_dim=CLASS_DIM)
elif args.model == 'vgg19':
net = vgg.vgg19(image)
out = vgg.vgg19(image, class_dim=CLASS_DIM)
elif args.model == 'resnet':
net = resnet.resnet_imagenet(image)
out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM)
elif args.model == 'googlenet':
net, fc_o1, fc_o2 = googlenet.googlenet(image)
out1 = paddle.layer.fc(
input=fc_o1, size=CLASS_DIM, act=paddle.activation.Softmax())
out, out1, out2 = googlenet.googlenet(image, class_dim=CLASS_DIM)
loss1 = paddle.layer.cross_entropy_cost(
input=out1, label=lbl, coeff=0.3)
paddle.evaluator.classification_error(input=out1, label=lbl)
out2 = paddle.layer.fc(
input=fc_o2, size=CLASS_DIM, act=paddle.activation.Softmax())
loss2 = paddle.layer.cross_entropy_cost(
input=out2, label=lbl, coeff=0.3)
paddle.evaluator.classification_error(input=out2, label=lbl)
extra_layers = [loss1, loss2]
out = paddle.layer.fc(
input=net, size=CLASS_DIM, act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=out, label=lbl)
# Create parameters
......
......@@ -17,7 +17,7 @@ import paddle.v2 as paddle
__all__ = ['vgg13', 'vgg16', 'vgg19']
def vgg(input, nums):
def vgg(input, nums, class_dim=100):
def conv_block(input, num_filter, groups, num_channels=None):
return paddle.networks.img_conv_group(
input=input,
......@@ -48,19 +48,21 @@ def vgg(input, nums):
size=fc_dim,
act=paddle.activation.Relu(),
layer_attr=paddle.attr.Extra(drop_rate=0.5))
return fc2
out = paddle.layer.fc(
input=fc2, size=class_dim, act=paddle.activation.Softmax())
return out
def vgg13(input):
def vgg13(input, class_dim=100):
nums = [2, 2, 2, 2, 2]
return vgg(input, nums)
return vgg(input, nums, class_dim)
def vgg16(input):
def vgg16(input, class_dim=100):
nums = [2, 2, 3, 3, 3]
return vgg(input, nums)
return vgg(input, nums, class_dim)
def vgg19(input):
def vgg19(input, class_dim=100):
nums = [2, 2, 4, 4, 4]
return vgg(input, nums)
return vgg(input, nums, class_dim)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册