Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
book
提交
644a9d36
B
book
项目概览
PaddlePaddle
/
book
通知
16
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
40
列表
看板
标记
里程碑
合并请求
37
Wiki
5
Wiki
分析
仓库
DevOps
项目成员
Pages
B
book
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
40
Issue
40
列表
看板
标记
里程碑
合并请求
37
合并请求
37
Pages
分析
分析
仓库分析
DevOps
Wiki
5
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
644a9d36
编写于
12月 03, 2018
作者:
L
lujun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update to low level api--02 recognize digits,test=develop
上级
fa35415f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
430 addition
and
236 deletion
+430
-236
02.recognize_digits/README.cn.md
02.recognize_digits/README.cn.md
+134
-70
02.recognize_digits/index.cn.html
02.recognize_digits/index.cn.html
+134
-70
02.recognize_digits/train.py
02.recognize_digits/train.py
+162
-96
未找到文件。
02.recognize_digits/README.cn.md
浏览文件 @
644a9d36
...
...
@@ -157,18 +157,12 @@ PaddlePaddle在API中提供了自动加载[MNIST](http://yann.lecun.com/exdb/mni
加载 PaddlePaddle 的 Fluid API 包。
```
python
import
os
from
PIL
import
Image
import
numpy
import
paddle
import
paddle.fluid
as
fluid
from
__future__
import
print_function
try
:
from
paddle.fluid.contrib.trainer
import
*
from
paddle.fluid.contrib.inferencer
import
*
except
ImportError
:
print
(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib"
,
file
=
sys
.
stderr
)
from
paddle.fluid.trainer
import
*
from
paddle.fluid.inferencer
import
*
```
### Program Functions 配置
...
...
@@ -246,8 +240,7 @@ def train_program():
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
)
return
[
avg_cost
,
acc
]
return
predict
,
[
avg_cost
,
acc
]
```
...
...
@@ -269,18 +262,21 @@ def optimizer_program():
`batch`
是一个特殊的decorator,它的输入是一个reader,输出是一个batched reader。在PaddlePaddle里,一个reader每次yield一条训练数据,而一个batched reader每次yield一个minibatch。
```
python
BATCH_SIZE
=
64
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
buf_size
=
500
),
batch_size
=
64
)
batch_size
=
BATCH_SIZE
)
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
64
)
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
BATCH_SIZE
)
```
### Trainer 配置
现在,我们需要
配置
`Trainer`
。
`Trainer`
需要接受训练程序
`train_program`
,
`place`
和优化器
`optimizer`
。
现在,我们需要
构建一个
`Trainer`
。
`Trainer`
包含一个训练程序
`train_program`
,
`place`
和优化器
`optimizer`
,并包含训练迭代、检查训练期间测试误差以及保存所需要用来预测的模型参数
。
```
python
# 该模型运行在单个CPU上
...
...
@@ -293,47 +289,115 @@ trainer = Trainer(
#### Event Handler 配置
Fluid API 在训练期间为回调函数提供了一个钩子。用户能够通过机制
监控培训进度。
我们可以在训练期间通过调用一个handler函数来
监控培训进度。
我们将在这里演示两个
`event_handler`
程序。请随意修改 Jupyter 笔记本 ,看看有什么不同。
`event_handler`
用来在训练过程中输出训练结果
```
python
# Save the parameter into a directory. The Inferencer can load the parameters from it to do infer
params_dirname
=
"recognize_digits_network.inference.model"
lists
=
[]
def
event_handler
(
event
):
if
isinstance
(
event
,
EndStepEvent
):
if
event
.
step
%
100
==
0
:
# event.metrics maps with train program return arguments.
# event.metrics[0] will yeild avg_cost and event.metrics[1] will yeild acc in this example.
print
(
"Pass %d, Batch %d, Cost %f"
%
(
event
.
step
,
event
.
epoch
,
event
.
metrics
[
0
]))
if
isinstance
(
event
,
EndEpochEvent
):
avg_cost
,
acc
=
trainer
.
test
(
reader
=
test_reader
,
feed_order
=
[
'img'
,
'label'
])
print
(
"Test with Epoch %d, avg_cost: %s, acc: %s"
%
(
event
.
epoch
,
avg_cost
,
acc
))
# save parameters
trainer
.
save_params
(
params_dirname
)
lists
.
append
((
event
.
epoch
,
avg_cost
,
acc
))
def
event_handler
(
pass_id
,
batch_id
,
cost
):
print
(
"Pass %d, Batch %d, Cost %f"
%
(
pass_id
,
batch_id
,
cost
))
```
```
python
from
paddle.v2.plot
import
Ploter
#### 开始训练
train_title
=
"Train cost"
test_title
=
"Test cost"
cost_ploter
=
Ploter
(
train_title
,
test_title
)
# event_handler to plot a figure
def
event_handler_plot
(
ploter_title
,
step
,
cost
):
cost_ploter
.
append
(
ploter_title
,
step
,
cost
)
cost_ploter
.
plot
()
```
既然我们设置了
`event_handler`
和
`data reader`
,我们就可以开始训练模型了。
`event_handler_plot`
可以用来在训练过程中画图如下:
![
png
](
./image/train_and_test.png
)
#### 开始训练
可以加入我们设置的
`event_handler`
和
`data reader`
,然后就可以开始训练模型了。
设置一些运行需要的参数,配置数据描述
`feed_order`
用于将数据目录映射到
`train_program`
创建一个反馈训练过程中误差的
`train_test`
训练完成后,模型参数存入
`save_dirname`
中
```
python
trainer
.
train
(
num_epochs
=
5
,
event_handler
=
event_handler
,
reader
=
train_reader
,
feed_order
=
[
'img'
,
'label'
])
# 该模型运行在单个CPU上
use_cuda
=
False
# set to True if training with GPU
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
prediction
,
[
avg_loss
,
acc
]
=
train_program
()
img
=
fluid
.
layers
.
data
(
name
=
'img'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
img
,
label
],
place
=
place
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
)
optimizer
.
minimize
(
avg_loss
)
PASS_NUM
=
5
epochs
=
[
epoch_id
for
epoch_id
in
range
(
PASS_NUM
)]
save_dirname
=
"recognize_digits.inference.model"
def
train_test
(
train_test_program
,
train_test_feed
,
train_test_reader
):
acc_set
=
[]
avg_loss_set
=
[]
for
test_data
in
train_test_reader
():
acc_np
,
avg_loss_np
=
exe
.
run
(
program
=
train_test_program
,
feed
=
train_test_feed
.
feed
(
test_data
),
fetch_list
=
[
acc
,
avg_loss
])
acc_set
.
append
(
float
(
acc_np
))
avg_loss_set
.
append
(
float
(
avg_loss_np
))
# get test acc and loss
acc_val_mean
=
numpy
.
array
(
acc_set
).
mean
()
avg_loss_val_mean
=
numpy
.
array
(
avg_loss_set
).
mean
()
return
avg_loss_val_mean
,
acc_val_mean
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
main_program
=
fluid
.
default_main_program
()
test_program
=
fluid
.
default_main_program
().
clone
(
for_test
=
True
)
lists
=
[]
step
=
0
for
epoch_id
in
epochs
:
for
step_id
,
data
in
enumerate
(
train_reader
()):
metrics
=
exe
.
run
(
main_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_loss
,
acc
])
if
step
%
100
==
0
:
print
(
"Pass %d, Batch %d, Cost %f"
%
(
step
,
epoch_id
,
metrics
[
0
]))
event_handler_plot
(
train_title
,
step
,
metrics
[
0
])
step
+=
1
# test for epoch
avg_loss_val
,
acc_val
=
train_test
(
train_test_program
=
test_program
,
train_test_reader
=
test_reader
,
train_test_feed
=
feeder
)
print
(
"Test with Epoch %d, avg_cost: %s, acc: %s"
%
(
epoch_id
,
avg_loss_val
,
acc_val
))
event_handler_plot
(
test_title
,
step
,
metrics
[
0
])
lists
.
append
((
epoch_id
,
avg_loss_val
,
acc_val
))
if
save_dirname
is
not
None
:
fluid
.
io
.
save_inference_model
(
save_dirname
,
[
"img"
],
[
prediction
],
exe
,
model_filename
=
None
,
params_filename
=
None
)
# find the best pass
best
=
sorted
(
lists
,
key
=
lambda
list
:
float
(
list
[
1
]))[
0
]
print
(
'Best pass is %s, testing Avgcost is %s'
%
(
best
[
0
],
best
[
1
]))
print
(
'The classification accuracy is %.2f%%'
%
(
float
(
best
[
2
])
*
100
))
```
训练过程是完全自动的,event_handler里打印的日志类似如下所示:
...
...
@@ -357,52 +421,52 @@ Test with Epoch 0, avg_cost: 0.053097883707459624, acc: 0.9822850318471338
## 应用模型
可以使用训练好的模型对手写体数字图片进行分类,下面程序展示了如何使用
`fluid.contrib.inferencer.Inferencer`
接口进行推断。
### Inference 配置
`Inference`
需要一个
`infer_func`
和
`param_path`
来设置网络和经过训练的参数。
我们可以简单地插入在此之前定义的分类器。
```
python
inferencer
=
Inferencer
(
# infer_func=softmax_regression, # uncomment for softmax regression
# infer_func=multilayer_perceptron, # uncomment for MLP
infer_func
=
convolutional_neural_network
,
# uncomment for LeNet5
param_path
=
params_dirname
,
place
=
place
)
```
可以使用训练好的模型对手写体数字图片进行分类,下面程序展示了如何使用训练好的模型进行推断。
### 生成预测输入数据
`infer_3.png`
是数字 3 的一个示例图像。把它变成一个 numpy 数组以匹配数据馈送格式。
```
python
# Prepare the test image
import
os
import
numpy
as
np
from
PIL
import
Image
def
load_image
(
file
):
im
=
Image
.
open
(
file
).
convert
(
'L'
)
im
=
im
.
resize
((
28
,
28
),
Image
.
ANTIALIAS
)
im
=
n
p
.
array
(
im
).
reshape
(
1
,
1
,
28
,
28
).
astype
(
np
.
float32
)
im
=
n
umpy
.
array
(
im
).
reshape
(
1
,
1
,
28
,
28
).
astype
(
numpy
.
float32
)
im
=
im
/
255.0
*
2.0
-
1.0
return
im
cur_dir
=
cur_dir
=
os
.
getcwd
()
img
=
load_image
(
cur_dir
+
'/image/infer_3.png'
)
tensor_
img
=
load_image
(
cur_dir
+
'/image/infer_3.png'
)
```
### 预测
现在我们准备做预测。
### Inference 创建及预测
通过
`load_inference_model`
来设置网络和经过训练的参数。我们可以简单地插入在此之前定义的分类器。
```
python
results
=
inferencer
.
infer
({
'img'
:
img
})
lab
=
np
.
argsort
(
results
)
# probs and lab are the results of one batch data
print
(
"Inference result of image/infer_3.png is: %d"
%
lab
[
0
][
0
][
-
1
])
inference_scope
=
fluid
.
core
.
Scope
()
with
fluid
.
scope_guard
(
inference_scope
):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[
inference_program
,
feed_target_names
,
fetch_targets
]
=
fluid
.
io
.
load_inference_model
(
save_dirname
,
exe
,
None
,
None
)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
tensor_img
},
fetch_list
=
fetch_targets
)
lab
=
numpy
.
argsort
(
results
)
print
(
"Inference result of image/infer_3.png is: %d"
%
lab
[
0
][
0
][
-
1
])
```
### 预测结果
如果顺利,预测结果输入如下:
`Inference result of image/infer_3.png is: 3`
## 总结
本教程的softmax回归、多层感知器和卷积神经网络是最基础的深度学习模型,后续章节中复杂的神经网络都是从它们衍生出来的,因此这几个模型对之后的学习大有裨益。同时,我们也观察到从最简单的softmax回归变换到稍复杂的卷积神经网络的时候,MNIST数据集上的识别准确率有了大幅度的提升,原因是卷积层具有局部连接和共享权重的特性。在之后学习新模型的时候,希望大家也要深入到新模型相比原模型带来效果提升的关键之处。此外,本教程还介绍了PaddlePaddle模型搭建的基本流程,从dataprovider的编写、网络层的构建,到最后的训练和预测。对这个流程熟悉以后,大家就可以用自己的数据,定义自己的网络模型,并完成自己的训练和预测任务了。
...
...
02.recognize_digits/index.cn.html
浏览文件 @
644a9d36
...
...
@@ -199,18 +199,12 @@ PaddlePaddle在API中提供了自动加载[MNIST](http://yann.lecun.com/exdb/mni
加载 PaddlePaddle 的 Fluid API 包。
```python
import os
from PIL import Image
import numpy
import paddle
import paddle.fluid as fluid
from __future__ import print_function
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
```
### Program Functions 配置
...
...
@@ -288,8 +282,7 @@ def train_program():
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=predict, label=label)
return [avg_cost, acc]
return predict, [avg_cost, acc]
```
...
...
@@ -311,18 +304,21 @@ def optimizer_program():
`batch`是一个特殊的decorator,它的输入是一个reader,输出是一个batched reader。在PaddlePaddle里,一个reader每次yield一条训练数据,而一个batched reader每次yield一个minibatch。
```python
BATCH_SIZE = 64
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=
64
)
batch_size=
BATCH_SIZE
)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=
64
)
paddle.dataset.mnist.test(), batch_size=
BATCH_SIZE
)
```
### Trainer 配置
现在,我们需要
配置 `Trainer`。`Trainer` 需要接受训练程序 `train_program`, `place` 和优化器 `optimizer`
。
现在,我们需要
构建一个 `Trainer`。`Trainer` 包含一个训练程序 `train_program`, `place` 和优化器 `optimizer`,并包含训练迭代、检查训练期间测试误差以及保存所需要用来预测的模型参数
。
```python
# 该模型运行在单个CPU上
...
...
@@ -335,47 +331,115 @@ trainer = Trainer(
#### Event Handler 配置
Fluid API 在训练期间为回调函数提供了一个钩子。用户能够通过机制
监控培训进度。
我们可以在训练期间通过调用一个handler函数来
监控培训进度。
我们将在这里演示两个 `event_handler` 程序。请随意修改 Jupyter 笔记本 ,看看有什么不同。
`event_handler` 用来在训练过程中输出训练结果
```python
# Save the parameter into a directory. The Inferencer can load the parameters from it to do infer
params_dirname = "recognize_digits_network.inference.model"
lists = []
def event_handler(event):
if isinstance(event, EndStepEvent):
if event.step % 100 == 0:
# event.metrics maps with train program return arguments.
# event.metrics[0] will yeild avg_cost and event.metrics[1] will yeild acc in this example.
print("Pass %d, Batch %d, Cost %f" % (
event.step, event.epoch, event.metrics[0]))
if isinstance(event, EndEpochEvent):
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
print("Test with Epoch %d, avg_cost: %s, acc: %s" % (event.epoch, avg_cost, acc))
# save parameters
trainer.save_params(params_dirname)
lists.append((event.epoch, avg_cost, acc))
def event_handler(pass_id, batch_id, cost):
print("Pass %d, Batch %d, Cost %f" % (pass_id,batch_id, cost))
```
```python
from paddle.v2.plot import Ploter
#### 开始训练
train_title = "Train cost"
test_title = "Test cost"
cost_ploter = Ploter(train_title, test_title)
# event_handler to plot a figure
def event_handler_plot(ploter_title, step, cost):
cost_ploter.append(ploter_title, step, cost)
cost_ploter.plot()
```
既然我们设置了 `event_handler` 和 `data reader`,我们就可以开始训练模型了。
`event_handler_plot` 可以用来在训练过程中画图如下:
![png](./image/train_and_test.png)
#### 开始训练
可以加入我们设置的 `event_handler` 和 `data reader`,然后就可以开始训练模型了。
设置一些运行需要的参数,配置数据描述
`feed_order` 用于将数据目录映射到 `train_program`
创建一个反馈训练过程中误差的`train_test`
训练完成后,模型参数存入`save_dirname`中
```python
trainer.train(
num_epochs=5,
event_handler=event_handler,
reader=train_reader,
feed_order=['img', 'label'])
# 该模型运行在单个CPU上
use_cuda = False # set to True if training with GPU
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
prediction, [avg_loss, acc] = train_program()
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_loss)
PASS_NUM = 5
epochs = [epoch_id for epoch_id in range(PASS_NUM)]
save_dirname = "recognize_digits.inference.model"
def train_test(train_test_program,
train_test_feed, train_test_reader):
acc_set = []
avg_loss_set = []
for test_data in train_test_reader():
acc_np, avg_loss_np = exe.run(
program=train_test_program,
feed=train_test_feed.feed(test_data),
fetch_list=[acc, avg_loss])
acc_set.append(float(acc_np))
avg_loss_set.append(float(avg_loss_np))
# get test acc and loss
acc_val_mean = numpy.array(acc_set).mean()
avg_loss_val_mean = numpy.array(avg_loss_set).mean()
return avg_loss_val_mean, acc_val_mean
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
main_program = fluid.default_main_program()
test_program = fluid.default_main_program().clone(for_test=True)
lists = []
step = 0
for epoch_id in epochs:
for step_id, data in enumerate(train_reader()):
metrics = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_loss, acc])
if step % 100 == 0:
print("Pass %d, Batch %d, Cost %f" % (step, epoch_id, metrics[0]))
event_handler_plot(train_title, step, metrics[0])
step += 1
# test for epoch
avg_loss_val, acc_val = train_test(train_test_program=test_program,
train_test_reader=test_reader,
train_test_feed=feeder)
print("Test with Epoch %d, avg_cost: %s, acc: %s" %(epoch_id, avg_loss_val, acc_val))
event_handler_plot(test_title, step, metrics[0])
lists.append((epoch_id, avg_loss_val, acc_val))
if save_dirname is not None:
fluid.io.save_inference_model(save_dirname,
["img"], [prediction], exe,
model_filename=None,
params_filename=None)
# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print('Best pass is %s, testing Avgcost is %s' % (best[0], best[1]))
print('The classification accuracy is %.2f%%' % (float(best[2]) * 100))
```
训练过程是完全自动的,event_handler里打印的日志类似如下所示:
...
...
@@ -399,52 +463,52 @@ Test with Epoch 0, avg_cost: 0.053097883707459624, acc: 0.9822850318471338
## 应用模型
可以使用训练好的模型对手写体数字图片进行分类,下面程序展示了如何使用 `fluid.contrib.inferencer.Inferencer` 接口进行推断。
### Inference 配置
`Inference` 需要一个 `infer_func` 和 `param_path` 来设置网络和经过训练的参数。
我们可以简单地插入在此之前定义的分类器。
```python
inferencer = Inferencer(
# infer_func=softmax_regression, # uncomment for softmax regression
# infer_func=multilayer_perceptron, # uncomment for MLP
infer_func=convolutional_neural_network, # uncomment for LeNet5
param_path=params_dirname,
place=place)
```
可以使用训练好的模型对手写体数字图片进行分类,下面程序展示了如何使用训练好的模型进行推断。
### 生成预测输入数据
`infer_3.png` 是数字 3 的一个示例图像。把它变成一个 numpy 数组以匹配数据馈送格式。
```python
# Prepare the test image
import os
import numpy as np
from PIL import Image
def load_image(file):
im = Image.open(file).convert('L')
im = im.resize((28, 28), Image.ANTIALIAS)
im = n
p.array(im).reshape(1, 1, 28, 28).astype(np
.float32)
im = n
umpy.array(im).reshape(1, 1, 28, 28).astype(numpy
.float32)
im = im / 255.0 * 2.0 - 1.0
return im
cur_dir = cur_dir = os.getcwd()
img = load_image(cur_dir + '/image/infer_3.png')
tensor_
img = load_image(cur_dir + '/image/infer_3.png')
```
### 预测
现在我们准备做预测。
### Inference 创建及预测
通过`load_inference_model`来设置网络和经过训练的参数。我们可以简单地插入在此之前定义的分类器。
```python
results = inferencer.infer({'img': img})
lab = np.argsort(results) # probs and lab are the results of one batch data
print ("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
save_dirname, exe, None, None)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
lab = numpy.argsort(results)
print("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])
```
### 预测结果
如果顺利,预测结果输入如下:
`Inference result of image/infer_3.png is: 3`
## 总结
本教程的softmax回归、多层感知器和卷积神经网络是最基础的深度学习模型,后续章节中复杂的神经网络都是从它们衍生出来的,因此这几个模型对之后的学习大有裨益。同时,我们也观察到从最简单的softmax回归变换到稍复杂的卷积神经网络的时候,MNIST数据集上的识别准确率有了大幅度的提升,原因是卷积层具有局部连接和共享权重的特性。在之后学习新模型的时候,希望大家也要深入到新模型相比原模型带来效果提升的关键之处。此外,本教程还介绍了PaddlePaddle模型搭建的基本流程,从dataprovider的编写、网络层的构建,到最后的训练和预测。对这个流程熟悉以后,大家就可以用自己的数据,定义自己的网络模型,并完成自己的训练和预测任务了。
...
...
02.recognize_digits/train.py
浏览文件 @
644a9d36
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
from
PIL
import
Image
import
numpy
as
np
import
numpy
import
paddle
import
paddle.fluid
as
fluid
try
:
from
paddle.fluid.contrib.trainer
import
*
from
paddle.fluid.contrib.inferencer
import
*
except
ImportError
:
print
(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib"
,
file
=
sys
.
stderr
)
from
paddle.fluid.trainer
import
*
from
paddle.fluid.inferencer
import
*
BATCH_SIZE
=
64
PASS_NUM
=
5
def
softmax_regression
():
img
=
fluid
.
layers
.
data
(
name
=
'img'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
predict
=
fluid
.
layers
.
fc
(
input
=
img
,
size
=
10
,
act
=
'softmax'
)
return
predict
def
loss_net
(
hidden
,
label
):
prediction
=
fluid
.
layers
.
fc
(
input
=
hidden
,
size
=
10
,
act
=
'softmax'
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
acc
=
fluid
.
layers
.
accuracy
(
input
=
prediction
,
label
=
label
)
return
prediction
,
avg_loss
,
acc
def
multilayer_perceptron
():
img
=
fluid
.
layers
.
data
(
name
=
'img'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
# first fully-connected layer, using ReLu as its activation function
hidden
=
fluid
.
layers
.
fc
(
input
=
img
,
size
=
128
,
act
=
'relu'
)
# second fully-connected layer, using ReLu as its activation function
hidden
=
fluid
.
layers
.
fc
(
input
=
hidden
,
size
=
64
,
act
=
'relu'
)
# The thrid fully-connected layer, note that the hidden size should be 10,
# which is the number of unique digits
prediction
=
fluid
.
layers
.
fc
(
input
=
hidden
,
size
=
10
,
act
=
'softmax'
)
return
prediction
def
multilayer_perceptron
(
img
,
label
):
img
=
fluid
.
layers
.
fc
(
input
=
img
,
size
=
200
,
act
=
'tanh'
)
hidden
=
fluid
.
layers
.
fc
(
input
=
img
,
size
=
200
,
act
=
'tanh'
)
return
loss_net
(
hidden
,
label
)
def
convolutional_neural_network
():
img
=
fluid
.
layers
.
data
(
name
=
'img'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
# first conv pool
def
softmax_regression
(
img
,
label
):
return
loss_net
(
img
,
label
)
def
convolutional_neural_network
(
img
,
label
):
conv_pool_1
=
fluid
.
nets
.
simple_img_conv_pool
(
input
=
img
,
filter_size
=
5
,
...
...
@@ -45,7 +51,6 @@ def convolutional_neural_network():
pool_stride
=
2
,
act
=
"relu"
)
conv_pool_1
=
fluid
.
layers
.
batch_norm
(
conv_pool_1
)
# second conv pool
conv_pool_2
=
fluid
.
nets
.
simple_img_conv_pool
(
input
=
conv_pool_1
,
filter_size
=
5
,
...
...
@@ -53,99 +58,160 @@ def convolutional_neural_network():
pool_size
=
2
,
pool_stride
=
2
,
act
=
"relu"
)
# output layer with softmax activation function. size = 10 since there are only 10 possible digits.
prediction
=
fluid
.
layers
.
fc
(
input
=
conv_pool_2
,
size
=
10
,
act
=
'softmax'
)
return
prediction
return
loss_net
(
conv_pool_2
,
label
)
def
train
_program
():
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
# Here we can build the prediction network in different ways. Please
# predict = softmax_regression() # uncomment for Softmax
# predict = multilayer_perceptron() # uncomment for MLP
predict
=
convolutional_neural_network
()
# uncomment for LeNet5
def
train
(
nn_type
,
use_cuda
,
save_dirname
=
None
,
model_filename
=
None
,
params_filename
=
None
):
if
use_cuda
and
not
fluid
.
core
.
is_compiled_with_cuda
():
return
# Calculate the cost from the prediction and label.
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
acc
=
fluid
.
layers
.
accuracy
(
input
=
predict
,
label
=
label
)
return
[
avg_cost
,
acc
]
img
=
fluid
.
layers
.
data
(
name
=
'img'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
if
nn_type
==
'softmax_regression'
:
net_conf
=
softmax_regression
elif
nn_type
==
'multilayer_perceptron'
:
net_conf
=
multilayer_perceptron
else
:
net_conf
=
convolutional_neural_network
prediction
,
avg_loss
,
acc
=
net_conf
(
img
,
label
)
test_program
=
fluid
.
default_main_program
().
clone
(
for_test
=
True
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
)
optimizer
.
minimize
(
avg_loss
)
def
train_test
(
train_test_program
,
train_test_feed
,
train_test_reader
):
acc_set
=
[]
avg_loss_set
=
[]
for
test_data
in
train_test_reader
():
acc_np
,
avg_loss_np
=
exe
.
run
(
program
=
train_test_program
,
feed
=
train_test_feed
.
feed
(
test_data
),
fetch_list
=
[
acc
,
avg_loss
])
acc_set
.
append
(
float
(
acc_np
))
avg_loss_set
.
append
(
float
(
avg_loss_np
))
# get test acc and loss
acc_val_mean
=
numpy
.
array
(
acc_set
).
mean
()
avg_loss_val_mean
=
numpy
.
array
(
avg_loss_set
).
mean
()
return
avg_loss_val_mean
,
acc_val_mean
def
optimizer_program
():
return
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
)
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
def
main
():
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
mnist
.
train
(),
buf_size
=
500
),
batch_size
=
64
)
batch_size
=
BATCH_SIZE
)
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
BATCH_SIZE
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
img
,
label
],
place
=
place
)
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
64
)
use_cuda
=
False
# set to True if training with GPU
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
trainer
=
Trainer
(
train_func
=
train_program
,
place
=
place
,
optimizer_func
=
optimizer_program
)
# Save the parameter into a directory. The Inferencer can load the parameters from it to do infer
params_dirname
=
"recognize_digits_network.inference.model"
exe
.
run
(
fluid
.
default_startup_program
())
main_program
=
fluid
.
default_main_program
()
epochs
=
[
epoch_id
for
epoch_id
in
range
(
PASS_NUM
)]
lists
=
[]
def
event_handler
(
event
)
:
if
isinstance
(
event
,
EndStepEvent
):
if
event
.
step
%
100
==
0
:
# event.metrics maps with train program return arguments.
# event.metrics[0] will yeild avg_cost and event.metrics[1] will yeild acc in this example.
print
(
"Pass %d, Batch %d, Cost %f"
%
(
event
.
step
,
event
.
epoch
,
event
.
metrics
[
0
]))
if
isinstance
(
event
,
EndEpochEvent
):
avg_cost
,
acc
=
trainer
.
test
(
reader
=
test_reader
,
feed_order
=
[
'img'
,
'label'
])
print
(
"Test with Epoch %d, avg_cost: %s, acc: %s"
%
(
event
.
epoch
,
avg_cost
,
acc
))
# save parameters
trainer
.
save_params
(
params_dirname
)
lists
.
append
((
event
.
epoch
,
avg_cost
,
acc
))
# Train the model now
trainer
.
train
(
num_epochs
=
5
,
event_handler
=
event_handler
,
reader
=
train_reader
,
feed_order
=
[
'img'
,
'label'
]
)
step
=
0
for
epoch_id
in
epochs
:
for
step_id
,
data
in
enumerate
(
train_reader
()
):
metrics
=
exe
.
run
(
main_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_loss
,
acc
])
if
step
%
100
==
0
:
print
(
"Pass %d, Batch %d, Cost %f"
%
(
step
,
epoch_id
,
metrics
[
0
]))
step
+=
1
# test for epoch
avg_loss_val
,
acc_val
=
train_test
(
train_test_program
=
test_program
,
train_test_reader
=
test_reader
,
train_test_feed
=
feeder
)
print
(
"Test with Epoch %d, avg_cost: %s, acc: %s"
%
(
epoch_id
,
avg_loss_val
,
acc_val
))
lists
.
append
((
epoch_id
,
avg_loss_val
,
acc_val
))
if
save_dirname
is
not
None
:
fluid
.
io
.
save_inference_model
(
save_dirname
,
[
"img"
],
[
prediction
]
,
exe
,
model_filename
=
model_filename
,
params_filename
=
params_filename
)
# find the best pass
best
=
sorted
(
lists
,
key
=
lambda
list
:
float
(
list
[
1
]))[
0
]
print
(
'Best pass is %s, testing Avgcost is %s'
%
(
best
[
0
],
best
[
1
]))
print
(
'The classification accuracy is %.2f%%'
%
(
float
(
best
[
2
])
*
100
))
def
infer
(
use_cuda
,
save_dirname
=
None
,
model_filename
=
None
,
params_filename
=
None
):
if
save_dirname
is
None
:
return
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
def
load_image
(
file
):
im
=
Image
.
open
(
file
).
convert
(
'L'
)
im
=
im
.
resize
((
28
,
28
),
Image
.
ANTIALIAS
)
im
=
n
p
.
array
(
im
).
reshape
(
1
,
1
,
28
,
28
).
astype
(
np
.
float32
)
im
=
n
umpy
.
array
(
im
).
reshape
(
1
,
1
,
28
,
28
).
astype
(
numpy
.
float32
)
im
=
im
/
255.0
*
2.0
-
1.0
return
im
cur_dir
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
img
=
load_image
(
cur_dir
+
'/image/infer_3.png'
)
inferencer
=
Inferencer
(
# infer_func=softmax_regression, # uncomment for softmax regression
# infer_func=multilayer_perceptron, # uncomment for MLP
infer_func
=
convolutional_neural_network
,
# uncomment for LeNet5
param_path
=
params_dirname
,
place
=
place
)
results
=
inferencer
.
infer
({
'img'
:
img
})
lab
=
np
.
argsort
(
results
)
# probs and lab are the results of one batch data
print
(
"Inference result of image/infer_3.png is: %d"
%
lab
[
0
][
0
][
-
1
])
tensor_img
=
load_image
(
cur_dir
+
'/image/infer_3.png'
)
inference_scope
=
fluid
.
core
.
Scope
()
with
fluid
.
scope_guard
(
inference_scope
):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[
inference_program
,
feed_target_names
,
fetch_targets
]
=
fluid
.
io
.
load_inference_model
(
save_dirname
,
exe
,
model_filename
,
params_filename
)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
tensor_img
},
fetch_list
=
fetch_targets
)
lab
=
numpy
.
argsort
(
results
)
print
(
"Inference result of image/infer_3.png is: %d"
%
lab
[
0
][
0
][
-
1
])
def
main
(
use_cuda
,
nn_type
):
model_filename
=
None
params_filename
=
None
save_dirname
=
"recognize_digits_"
+
nn_type
+
".inference.model"
# call train() with is_local argument to run distributed train
train
(
nn_type
=
nn_type
,
use_cuda
=
use_cuda
,
save_dirname
=
save_dirname
,
model_filename
=
model_filename
,
params_filename
=
params_filename
)
infer
(
use_cuda
=
use_cuda
,
save_dirname
=
save_dirname
,
model_filename
=
model_filename
,
params_filename
=
params_filename
)
if
__name__
==
'__main__'
:
main
()
use_cuda
=
False
# predict = 'softmax_regression' # uncomment for Softmax
# predict = 'multilayer_perceptron' # uncomment for MLP
predict
=
'convolutional_neural_network'
# uncomment for LeNet5
main
(
use_cuda
=
use_cuda
,
nn_type
=
predict
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录