Unverified commit b22b286a, authored by lujun, committed by GitHub

Merge pull request #649 from junjun315/02-stuff

update to low level api--02 recognize digits
PaddlePaddle's API provides a module that automatically loads the [MNIST](http://yann.lecun.com/exdb/mnist/) data.

Load PaddlePaddle's Fluid API package:

```python
from __future__ import print_function

import os
from PIL import Image
import numpy
import paddle
import paddle.fluid as fluid
```
### Program Functions Configuration

Inside `train_program()`, the loss and accuracy are computed from the prediction, and the prediction itself is also returned so that it can later be saved for inference:

```python
def train_program():
    # ...
    cost = fluid.layers.cross_entropy(input=predict, label=label)
    avg_cost = fluid.layers.mean(cost)
    acc = fluid.layers.accuracy(input=predict, label=label)
    return predict, [avg_cost, acc]
```
`batch` is a special decorator whose input is a reader and whose output is a batched reader. In PaddlePaddle, a reader yields one training sample at a time, while a batched reader yields a minibatch at a time.

```python
BATCH_SIZE = 64

train_reader = paddle.batch(
    paddle.reader.shuffle(
        paddle.dataset.mnist.train(), buf_size=500),
    batch_size=BATCH_SIZE)
test_reader = paddle.batch(
    paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
```
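To make the reader contract concrete, here is a minimal sketch (not part of the original tutorial) that hand-writes a sample-level reader and wraps it with `paddle.batch`; the synthetic data and the name `random_image_reader` are illustrative only.

```python
import numpy
import paddle

def random_image_reader():
    # A reader is just a generator: it yields one (image, label) sample at a time.
    rng = numpy.random.RandomState(0)
    for _ in range(256):
        image = rng.uniform(-1.0, 1.0, 784).astype('float32')  # flattened 28x28
        label = int(rng.randint(0, 10))
        yield image, label

# paddle.batch wraps the sample-level reader into a batched reader that
# yields lists of batch_size samples per iteration.
batched_reader = paddle.batch(random_image_reader, batch_size=64)

for minibatch in batched_reader():
    print(len(minibatch), minibatch[0][0].shape, minibatch[0][1])
    break
```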
### Trainer Configuration

Now we need to build a `Trainer`. The `Trainer` contains a training program `train_program`, a `place`, and an optimizer `optimizer`, and it covers the training iterations, checking the test error during training, and saving the model parameters that will later be needed for prediction.

```python
# This model runs on a single CPU
use_cuda = False  # set to True if training with GPU
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
```
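As a small aside, a defensive way to pick the `place` is to fall back to CPU when the installed Paddle build has no CUDA support, mirroring the `fluid.core.is_compiled_with_cuda()` check used in the full training script later in this change; the helper name `pick_place` is ours, not part of the tutorial.

```python
import paddle.fluid as fluid

def pick_place(use_cuda):
    # is_compiled_with_cuda() tells whether this Paddle build can use a GPU
    # at all; otherwise fall back to the CPU place.
    if use_cuda and fluid.core.is_compiled_with_cuda():
        return fluid.CUDAPlace(0)
    return fluid.CPUPlace()

place = pick_place(use_cuda=False)
```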
#### Event Handler Configuration

We can monitor the training progress during training by calling a handler function.

We will demonstrate two `event_handler` programs here. Feel free to modify the Jupyter Notebook and see what the differences are.

`event_handler` prints the training result during training:

```python
def event_handler(pass_id, batch_id, cost):
    print("Pass %d, Batch %d, Cost %f" % (pass_id, batch_id, cost))
```
```python
from paddle.v2.plot import Ploter

train_prompt = "Train cost"
test_prompt = "Test cost"
cost_ploter = Ploter(train_prompt, test_prompt)

# event_handler to plot a figure
def event_handler_plot(ploter_title, step, cost):
    cost_ploter.append(ploter_title, step, cost)
    cost_ploter.plot()
```

`event_handler_plot` can be used to plot a figure during training, like this:

![png](./image/train_and_test.png)
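If you run outside a notebook and want to avoid the plotting dependency, a handler can simply record the raw numbers instead; a minimal sketch (the `recorded_costs` list and the function name are illustrative, not part of the tutorial):

```python
recorded_costs = []

def event_handler_record(prompt, step, cost):
    # Same signature as event_handler_plot, but only keeps the numbers
    # so they can be inspected or plotted later.
    recorded_costs.append((prompt, step, float(cost)))
```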
#### Start Training

With the `event_handler` and the `data reader` we set up, we can now start training the model.

Set a few parameters needed for the run and configure the data description.

`feed_order` is used to map the data entries to `train_program`.

Create a `train_test` function that reports the test error during training.

After training completes, the model parameters are stored in `save_dirname`.
```python
# This model runs on a single CPU
use_cuda = False  # set to True if training with GPU
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

prediction, [avg_loss, acc] = train_program()

img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_loss)

PASS_NUM = 5
epochs = [epoch_id for epoch_id in range(PASS_NUM)]

save_dirname = "recognize_digits.inference.model"

def train_test(train_test_program,
               train_test_feed, train_test_reader):
    acc_set = []
    avg_loss_set = []
    for test_data in train_test_reader():
        acc_np, avg_loss_np = exe.run(
            program=train_test_program,
            feed=train_test_feed.feed(test_data),
            fetch_list=[acc, avg_loss])
        acc_set.append(float(acc_np))
        avg_loss_set.append(float(avg_loss_np))
    # get test acc and loss
    acc_val_mean = numpy.array(acc_set).mean()
    avg_loss_val_mean = numpy.array(avg_loss_set).mean()
    return avg_loss_val_mean, acc_val_mean

exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

main_program = fluid.default_main_program()
test_program = fluid.default_main_program().clone(for_test=True)

lists = []
step = 0
for epoch_id in epochs:
    for step_id, data in enumerate(train_reader()):
        metrics = exe.run(main_program,
                          feed=feeder.feed(data),
                          fetch_list=[avg_loss, acc])
        if step % 100 == 0:
            print("Pass %d, Batch %d, Cost %f" % (step, epoch_id, metrics[0]))
            event_handler_plot(train_prompt, step, metrics[0])
        step += 1

    # test for epoch
    avg_loss_val, acc_val = train_test(train_test_program=test_program,
                                       train_test_reader=test_reader,
                                       train_test_feed=feeder)

    print("Test with Epoch %d, avg_cost: %s, acc: %s" % (epoch_id, avg_loss_val, acc_val))
    event_handler_plot(test_prompt, step, metrics[0])

    lists.append((epoch_id, avg_loss_val, acc_val))

    if save_dirname is not None:
        fluid.io.save_inference_model(save_dirname,
                                      ["img"], [prediction], exe,
                                      model_filename=None,
                                      params_filename=None)

# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print('Best pass is %s, testing Avgcost is %s' % (best[0], best[1]))
print('The classification accuracy is %.2f%%' % (float(best[2]) * 100))
```
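The mapping from a minibatch to the program's named inputs happens in `feeder.feed(...)`, which turns a list of `(image, label)` samples into a feed dict keyed by the data layer names. A small sketch of that step in isolation, assuming the `feeder` and `train_reader` defined above:

```python
# Take one minibatch from the batched reader and convert it into the
# feed dict that exe.run() expects.
one_batch = next(train_reader())
feed_dict = feeder.feed(one_batch)
print(feed_dict.keys())  # expected to contain 'img' and 'label'
```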
The training process is fully automatic. The log printed by `event_handler` looks like the following:

```
Test with Epoch 0, avg_cost: 0.053097883707459624, acc: 0.9822850318471338
```
## Apply the Model

We can use the trained model to classify handwritten digit images. The program below shows how to use the trained model for inference.
### Generating Input Data for Inference

`infer_3.png` is an example image of the digit 3. Turn it into a numpy array to match the data feeding format.

```python
def load_image(file):
    im = Image.open(file).convert('L')
    im = im.resize((28, 28), Image.ANTIALIAS)
    im = numpy.array(im).reshape(1, 1, 28, 28).astype(numpy.float32)
    im = im / 255.0 * 2.0 - 1.0
    return im

cur_dir = os.getcwd()
tensor_img = load_image(cur_dir + '/image/infer_3.png')
```
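As a quick sanity check (illustrative only, not part of the tutorial), the array returned by `load_image` should be float32, shaped `(1, 1, 28, 28)`, and scaled into `[-1, 1]` to match what the network was trained on:

```python
sample = load_image(cur_dir + '/image/infer_3.png')
assert sample.shape == (1, 1, 28, 28)
assert sample.dtype == numpy.float32
assert sample.min() >= -1.0 and sample.max() <= 1.0
```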
### Inference Creation and Prediction

Use `load_inference_model` to set up the network and the trained parameters. We can simply plug in the classifier defined earlier.

Now we are ready to make a prediction.

```python
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
    # Use fluid.io.load_inference_model to obtain the inference program desc,
    # the feed_target_names (the names of variables that will be fed
    # data using feed operators), and the fetch_targets (variables that
    # we want to obtain data from using fetch operators).
    [inference_program, feed_target_names,
     fetch_targets] = fluid.io.load_inference_model(
         save_dirname, exe, None, None)

    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
    # and results will contain a list of data corresponding to fetch_targets.
    results = exe.run(inference_program,
                      feed={feed_target_names[0]: tensor_img},
                      fetch_list=fetch_targets)
    lab = numpy.argsort(results)
    print("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])
```

### Prediction Result

If everything goes well, the prediction output is:

`Inference result of image/infer_3.png is: 3`
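Why `lab[0][0][-1]` is the predicted digit: the fetched result is the 10-way softmax output, and `numpy.argsort` orders class indices from lowest to highest probability, so the last index is the most probable class. A tiny illustration with a made-up probability vector:

```python
import numpy

probs = numpy.array([[0.02, 0.01, 0.05, 0.70, 0.02, 0.05, 0.05, 0.04, 0.03, 0.03]])
order = numpy.argsort(probs)  # class indices sorted by ascending probability
print(order[0][-1])           # -> 3, the most probable digit
```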
## Summary

The softmax regression, multilayer perceptron, and convolutional neural network covered in this tutorial are the most basic deep learning models; the more complex neural networks in later chapters are all derived from them, so these models are very helpful for what follows. We also observed that when moving from the simplest softmax regression to the slightly more complex convolutional neural network, the recognition accuracy on the MNIST dataset improves substantially, because convolutional layers are locally connected and share weights. When you study new models later, we hope you will also dig into the key changes that make a new model outperform the old one. In addition, this tutorial introduced the basic workflow of building a PaddlePaddle model, from writing the data provider and constructing the network layers to the final training and inference. Once you are familiar with this workflow, you can use your own data, define your own network model, and complete your own training and inference tasks.
The updated standalone training script from the same pull request (lines that the diff does not show are marked with `# ...`):

```python
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import os
from PIL import Image
import numpy
import paddle
import paddle.fluid as fluid

BATCH_SIZE = 64
PASS_NUM = 5


def loss_net(hidden, label):
    prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_loss = fluid.layers.mean(loss)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    return prediction, avg_loss, acc


def multilayer_perceptron(img, label):
    img = fluid.layers.fc(input=img, size=200, act='tanh')
    hidden = fluid.layers.fc(input=img, size=200, act='tanh')
    return loss_net(hidden, label)


def softmax_regression(img, label):
    return loss_net(img, label)


def convolutional_neural_network(img, label):
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=img,
        filter_size=5,
        # ... (arguments not shown in the diff)
        pool_stride=2,
        act="relu")
    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        # ... (arguments not shown in the diff)
        pool_size=2,
        pool_stride=2,
        act="relu")
    return loss_net(conv_pool_2, label)


def train(nn_type,
          use_cuda,
          save_dirname=None,
          model_filename=None,
          params_filename=None):
    if use_cuda and not fluid.core.is_compiled_with_cuda():
        return
    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    if nn_type == 'softmax_regression':
        net_conf = softmax_regression
    elif nn_type == 'multilayer_perceptron':
        net_conf = multilayer_perceptron
    else:
        net_conf = convolutional_neural_network

    prediction, avg_loss, acc = net_conf(img, label)

    test_program = fluid.default_main_program().clone(for_test=True)

    optimizer = fluid.optimizer.Adam(learning_rate=0.001)
    optimizer.minimize(avg_loss)

    def train_test(train_test_program, train_test_feed, train_test_reader):
        acc_set = []
        avg_loss_set = []
        for test_data in train_test_reader():
            acc_np, avg_loss_np = exe.run(
                program=train_test_program,
                feed=train_test_feed.feed(test_data),
                fetch_list=[acc, avg_loss])
            acc_set.append(float(acc_np))
            avg_loss_set.append(float(avg_loss_np))
        # get test acc and loss
        acc_val_mean = numpy.array(acc_set).mean()
        avg_loss_val_mean = numpy.array(avg_loss_set).mean()
        return avg_loss_val_mean, acc_val_mean

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

    exe = fluid.Executor(place)

    train_reader = paddle.batch(
        paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
        batch_size=BATCH_SIZE)
    test_reader = paddle.batch(
        paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
    feeder = fluid.DataFeeder(feed_list=[img, label], place=place)

    exe.run(fluid.default_startup_program())
    main_program = fluid.default_main_program()
    epochs = [epoch_id for epoch_id in range(PASS_NUM)]

    lists = []
    step = 0
    for epoch_id in epochs:
        for step_id, data in enumerate(train_reader()):
            metrics = exe.run(
                main_program,
                feed=feeder.feed(data),
                fetch_list=[avg_loss, acc])
            if step % 100 == 0:
                print("Pass %d, Batch %d, Cost %f" % (step, epoch_id,
                                                      metrics[0]))
            step += 1
        # test for epoch
        avg_loss_val, acc_val = train_test(
            train_test_program=test_program,
            train_test_reader=test_reader,
            train_test_feed=feeder)

        print("Test with Epoch %d, avg_cost: %s, acc: %s" %
              (epoch_id, avg_loss_val, acc_val))
        lists.append((epoch_id, avg_loss_val, acc_val))
        if save_dirname is not None:
            fluid.io.save_inference_model(
                save_dirname, ["img"], [prediction],
                exe,
                model_filename=model_filename,
                params_filename=params_filename)

    # find the best pass
    best = sorted(lists, key=lambda list: float(list[1]))[0]
    print('Best pass is %s, testing Avgcost is %s' % (best[0], best[1]))
    print('The classification accuracy is %.2f%%' % (float(best[2]) * 100))


def infer(use_cuda,
          save_dirname=None,
          model_filename=None,
          params_filename=None):
    if save_dirname is None:
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    def load_image(file):
        im = Image.open(file).convert('L')
        im = im.resize((28, 28), Image.ANTIALIAS)
        im = numpy.array(im).reshape(1, 1, 28, 28).astype(numpy.float32)
        im = im / 255.0 * 2.0 - 1.0
        return im

    cur_dir = os.path.dirname(os.path.realpath(__file__))
    tensor_img = load_image(cur_dir + '/image/infer_3.png')

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        # Use fluid.io.load_inference_model to obtain the inference program desc,
        # the feed_target_names (the names of variables that will be fed
        # data using feed operators), and the fetch_targets (variables that
        # we want to obtain data from using fetch operators).
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(
             save_dirname, exe, model_filename, params_filename)

        # Construct feed as a dictionary of {feed_target_name: feed_target_data}
        # and results will contain a list of data corresponding to fetch_targets.
        results = exe.run(
            inference_program,
            feed={feed_target_names[0]: tensor_img},
            fetch_list=fetch_targets)
        lab = numpy.argsort(results)
        print("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])


def main(use_cuda, nn_type):
    model_filename = None
    params_filename = None
    save_dirname = "recognize_digits_" + nn_type + ".inference.model"

    # call train() with is_local argument to run distributed train
    train(
        nn_type=nn_type,
        use_cuda=use_cuda,
        save_dirname=save_dirname,
        model_filename=model_filename,
        params_filename=params_filename)
    infer(
        use_cuda=use_cuda,
        save_dirname=save_dirname,
        model_filename=model_filename,
        params_filename=params_filename)


if __name__ == '__main__':
    use_cuda = False
    # predict = 'softmax_regression'  # uncomment for Softmax
    # predict = 'multilayer_perceptron'  # uncomment for MLP
    predict = 'convolutional_neural_network'  # uncomment for LeNet5
    main(use_cuda=use_cuda, nn_type=predict)
```