提交 644a9d36 编写于 作者: L lujun

update to low level api--02 recognize digits,test=develop

上级 fa35415f
......@@ -157,18 +157,12 @@ PaddlePaddle在API中提供了自动加载[MNIST](http://yann.lecun.com/exdb/mni
加载 PaddlePaddle 的 Fluid API 包。
```python
import os
from PIL import Image
import numpy
import paddle
import paddle.fluid as fluid
from __future__ import print_function
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
```
### Program Functions 配置
......@@ -246,8 +240,7 @@ def train_program():
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=predict, label=label)
return [avg_cost, acc]
return predict, [avg_cost, acc]
```
......@@ -269,18 +262,21 @@ def optimizer_program():
`batch`是一个特殊的decorator,它的输入是一个reader,输出是一个batched reader。在PaddlePaddle里,一个reader每次yield一条训练数据,而一个batched reader每次yield一个minibatch。
```python
BATCH_SIZE = 64
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=64)
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=64)
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
```
### Trainer 配置
现在,我们需要配置 `Trainer``Trainer` 需要接受训练程序 `train_program`, `place` 和优化器 `optimizer`
现在,我们需要构建一个 `Trainer``Trainer` 包含一个训练程序 `train_program`, `place` 和优化器 `optimizer`,并包含训练迭代、检查训练期间测试误差以及保存所需要用来预测的模型参数
```python
# 该模型运行在单个CPU上
......@@ -293,47 +289,115 @@ trainer = Trainer(
#### Event Handler 配置
Fluid API 在训练期间为回调函数提供了一个钩子。用户能够通过机制监控培训进度。
我们可以在训练期间通过调用一个handler函数来监控培训进度。
我们将在这里演示两个 `event_handler` 程序。请随意修改 Jupyter 笔记本 ,看看有什么不同。
`event_handler` 用来在训练过程中输出训练结果
```python
# Save the parameter into a directory. The Inferencer can load the parameters from it to do infer
params_dirname = "recognize_digits_network.inference.model"
lists = []
def event_handler(event):
if isinstance(event, EndStepEvent):
if event.step % 100 == 0:
# event.metrics maps with train program return arguments.
# event.metrics[0] will yeild avg_cost and event.metrics[1] will yeild acc in this example.
print("Pass %d, Batch %d, Cost %f" % (
event.step, event.epoch, event.metrics[0]))
if isinstance(event, EndEpochEvent):
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
print("Test with Epoch %d, avg_cost: %s, acc: %s" % (event.epoch, avg_cost, acc))
# save parameters
trainer.save_params(params_dirname)
lists.append((event.epoch, avg_cost, acc))
def event_handler(pass_id, batch_id, cost):
print("Pass %d, Batch %d, Cost %f" % (pass_id,batch_id, cost))
```
```python
from paddle.v2.plot import Ploter
#### 开始训练
train_title = "Train cost"
test_title = "Test cost"
cost_ploter = Ploter(train_title, test_title)
# event_handler to plot a figure
def event_handler_plot(ploter_title, step, cost):
cost_ploter.append(ploter_title, step, cost)
cost_ploter.plot()
```
既然我们设置了 `event_handler``data reader`,我们就可以开始训练模型了。
`event_handler_plot` 可以用来在训练过程中画图如下:
![png](./image/train_and_test.png)
#### 开始训练
可以加入我们设置的 `event_handler``data reader`,然后就可以开始训练模型了。
设置一些运行需要的参数,配置数据描述
`feed_order` 用于将数据目录映射到 `train_program`
创建一个反馈训练过程中误差的`train_test`
训练完成后,模型参数存入`save_dirname`
```python
trainer.train(
num_epochs=5,
event_handler=event_handler,
reader=train_reader,
feed_order=['img', 'label'])
# 该模型运行在单个CPU上
use_cuda = False # set to True if training with GPU
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
prediction, [avg_loss, acc] = train_program()
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_loss)
PASS_NUM = 5
epochs = [epoch_id for epoch_id in range(PASS_NUM)]
save_dirname = "recognize_digits.inference.model"
def train_test(train_test_program,
train_test_feed, train_test_reader):
acc_set = []
avg_loss_set = []
for test_data in train_test_reader():
acc_np, avg_loss_np = exe.run(
program=train_test_program,
feed=train_test_feed.feed(test_data),
fetch_list=[acc, avg_loss])
acc_set.append(float(acc_np))
avg_loss_set.append(float(avg_loss_np))
# get test acc and loss
acc_val_mean = numpy.array(acc_set).mean()
avg_loss_val_mean = numpy.array(avg_loss_set).mean()
return avg_loss_val_mean, acc_val_mean
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
main_program = fluid.default_main_program()
test_program = fluid.default_main_program().clone(for_test=True)
lists = []
step = 0
for epoch_id in epochs:
for step_id, data in enumerate(train_reader()):
metrics = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_loss, acc])
if step % 100 == 0:
print("Pass %d, Batch %d, Cost %f" % (step, epoch_id, metrics[0]))
event_handler_plot(train_title, step, metrics[0])
step += 1
# test for epoch
avg_loss_val, acc_val = train_test(train_test_program=test_program,
train_test_reader=test_reader,
train_test_feed=feeder)
print("Test with Epoch %d, avg_cost: %s, acc: %s" %(epoch_id, avg_loss_val, acc_val))
event_handler_plot(test_title, step, metrics[0])
lists.append((epoch_id, avg_loss_val, acc_val))
if save_dirname is not None:
fluid.io.save_inference_model(save_dirname,
["img"], [prediction], exe,
model_filename=None,
params_filename=None)
# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print('Best pass is %s, testing Avgcost is %s' % (best[0], best[1]))
print('The classification accuracy is %.2f%%' % (float(best[2]) * 100))
```
训练过程是完全自动的,event_handler里打印的日志类似如下所示:
......@@ -357,52 +421,52 @@ Test with Epoch 0, avg_cost: 0.053097883707459624, acc: 0.9822850318471338
## 应用模型
可以使用训练好的模型对手写体数字图片进行分类,下面程序展示了如何使用 `fluid.contrib.inferencer.Inferencer` 接口进行推断。
### Inference 配置
`Inference` 需要一个 `infer_func``param_path` 来设置网络和经过训练的参数。
我们可以简单地插入在此之前定义的分类器。
```python
inferencer = Inferencer(
# infer_func=softmax_regression, # uncomment for softmax regression
# infer_func=multilayer_perceptron, # uncomment for MLP
infer_func=convolutional_neural_network, # uncomment for LeNet5
param_path=params_dirname,
place=place)
```
可以使用训练好的模型对手写体数字图片进行分类,下面程序展示了如何使用训练好的模型进行推断。
### 生成预测输入数据
`infer_3.png` 是数字 3 的一个示例图像。把它变成一个 numpy 数组以匹配数据馈送格式。
```python
# Prepare the test image
import os
import numpy as np
from PIL import Image
def load_image(file):
im = Image.open(file).convert('L')
im = im.resize((28, 28), Image.ANTIALIAS)
im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
im = numpy.array(im).reshape(1, 1, 28, 28).astype(numpy.float32)
im = im / 255.0 * 2.0 - 1.0
return im
cur_dir = cur_dir = os.getcwd()
img = load_image(cur_dir + '/image/infer_3.png')
tensor_img = load_image(cur_dir + '/image/infer_3.png')
```
### 预测
现在我们准备做预测。
### Inference 创建及预测
通过`load_inference_model`来设置网络和经过训练的参数。我们可以简单地插入在此之前定义的分类器。
```python
results = inferencer.infer({'img': img})
lab = np.argsort(results) # probs and lab are the results of one batch data
print ("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
save_dirname, exe, None, None)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
lab = numpy.argsort(results)
print("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])
```
### 预测结果
如果顺利,预测结果输入如下:
`Inference result of image/infer_3.png is: 3`
## 总结
本教程的softmax回归、多层感知器和卷积神经网络是最基础的深度学习模型,后续章节中复杂的神经网络都是从它们衍生出来的,因此这几个模型对之后的学习大有裨益。同时,我们也观察到从最简单的softmax回归变换到稍复杂的卷积神经网络的时候,MNIST数据集上的识别准确率有了大幅度的提升,原因是卷积层具有局部连接和共享权重的特性。在之后学习新模型的时候,希望大家也要深入到新模型相比原模型带来效果提升的关键之处。此外,本教程还介绍了PaddlePaddle模型搭建的基本流程,从dataprovider的编写、网络层的构建,到最后的训练和预测。对这个流程熟悉以后,大家就可以用自己的数据,定义自己的网络模型,并完成自己的训练和预测任务了。
......
......@@ -199,18 +199,12 @@ PaddlePaddle在API中提供了自动加载[MNIST](http://yann.lecun.com/exdb/mni
加载 PaddlePaddle 的 Fluid API 包。
```python
import os
from PIL import Image
import numpy
import paddle
import paddle.fluid as fluid
from __future__ import print_function
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
```
### Program Functions 配置
......@@ -288,8 +282,7 @@ def train_program():
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=predict, label=label)
return [avg_cost, acc]
return predict, [avg_cost, acc]
```
......@@ -311,18 +304,21 @@ def optimizer_program():
`batch`是一个特殊的decorator,它的输入是一个reader,输出是一个batched reader。在PaddlePaddle里,一个reader每次yield一条训练数据,而一个batched reader每次yield一个minibatch。
```python
BATCH_SIZE = 64
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=64)
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=64)
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
```
### Trainer 配置
现在,我们需要配置 `Trainer`。`Trainer` 需要接受训练程序 `train_program`, `place` 和优化器 `optimizer`
现在,我们需要构建一个 `Trainer`。`Trainer` 包含一个训练程序 `train_program`, `place` 和优化器 `optimizer`,并包含训练迭代、检查训练期间测试误差以及保存所需要用来预测的模型参数
```python
# 该模型运行在单个CPU上
......@@ -335,47 +331,115 @@ trainer = Trainer(
#### Event Handler 配置
Fluid API 在训练期间为回调函数提供了一个钩子。用户能够通过机制监控培训进度。
我们可以在训练期间通过调用一个handler函数来监控培训进度。
我们将在这里演示两个 `event_handler` 程序。请随意修改 Jupyter 笔记本 ,看看有什么不同。
`event_handler` 用来在训练过程中输出训练结果
```python
# Save the parameter into a directory. The Inferencer can load the parameters from it to do infer
params_dirname = "recognize_digits_network.inference.model"
lists = []
def event_handler(event):
if isinstance(event, EndStepEvent):
if event.step % 100 == 0:
# event.metrics maps with train program return arguments.
# event.metrics[0] will yeild avg_cost and event.metrics[1] will yeild acc in this example.
print("Pass %d, Batch %d, Cost %f" % (
event.step, event.epoch, event.metrics[0]))
if isinstance(event, EndEpochEvent):
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
print("Test with Epoch %d, avg_cost: %s, acc: %s" % (event.epoch, avg_cost, acc))
# save parameters
trainer.save_params(params_dirname)
lists.append((event.epoch, avg_cost, acc))
def event_handler(pass_id, batch_id, cost):
print("Pass %d, Batch %d, Cost %f" % (pass_id,batch_id, cost))
```
```python
from paddle.v2.plot import Ploter
#### 开始训练
train_title = "Train cost"
test_title = "Test cost"
cost_ploter = Ploter(train_title, test_title)
# event_handler to plot a figure
def event_handler_plot(ploter_title, step, cost):
cost_ploter.append(ploter_title, step, cost)
cost_ploter.plot()
```
既然我们设置了 `event_handler` 和 `data reader`,我们就可以开始训练模型了。
`event_handler_plot` 可以用来在训练过程中画图如下:
![png](./image/train_and_test.png)
#### 开始训练
可以加入我们设置的 `event_handler` 和 `data reader`,然后就可以开始训练模型了。
设置一些运行需要的参数,配置数据描述
`feed_order` 用于将数据目录映射到 `train_program`
创建一个反馈训练过程中误差的`train_test`
训练完成后,模型参数存入`save_dirname`中
```python
trainer.train(
num_epochs=5,
event_handler=event_handler,
reader=train_reader,
feed_order=['img', 'label'])
# 该模型运行在单个CPU上
use_cuda = False # set to True if training with GPU
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
prediction, [avg_loss, acc] = train_program()
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_loss)
PASS_NUM = 5
epochs = [epoch_id for epoch_id in range(PASS_NUM)]
save_dirname = "recognize_digits.inference.model"
def train_test(train_test_program,
train_test_feed, train_test_reader):
acc_set = []
avg_loss_set = []
for test_data in train_test_reader():
acc_np, avg_loss_np = exe.run(
program=train_test_program,
feed=train_test_feed.feed(test_data),
fetch_list=[acc, avg_loss])
acc_set.append(float(acc_np))
avg_loss_set.append(float(avg_loss_np))
# get test acc and loss
acc_val_mean = numpy.array(acc_set).mean()
avg_loss_val_mean = numpy.array(avg_loss_set).mean()
return avg_loss_val_mean, acc_val_mean
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
main_program = fluid.default_main_program()
test_program = fluid.default_main_program().clone(for_test=True)
lists = []
step = 0
for epoch_id in epochs:
for step_id, data in enumerate(train_reader()):
metrics = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_loss, acc])
if step % 100 == 0:
print("Pass %d, Batch %d, Cost %f" % (step, epoch_id, metrics[0]))
event_handler_plot(train_title, step, metrics[0])
step += 1
# test for epoch
avg_loss_val, acc_val = train_test(train_test_program=test_program,
train_test_reader=test_reader,
train_test_feed=feeder)
print("Test with Epoch %d, avg_cost: %s, acc: %s" %(epoch_id, avg_loss_val, acc_val))
event_handler_plot(test_title, step, metrics[0])
lists.append((epoch_id, avg_loss_val, acc_val))
if save_dirname is not None:
fluid.io.save_inference_model(save_dirname,
["img"], [prediction], exe,
model_filename=None,
params_filename=None)
# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print('Best pass is %s, testing Avgcost is %s' % (best[0], best[1]))
print('The classification accuracy is %.2f%%' % (float(best[2]) * 100))
```
训练过程是完全自动的,event_handler里打印的日志类似如下所示:
......@@ -399,52 +463,52 @@ Test with Epoch 0, avg_cost: 0.053097883707459624, acc: 0.9822850318471338
## 应用模型
可以使用训练好的模型对手写体数字图片进行分类,下面程序展示了如何使用 `fluid.contrib.inferencer.Inferencer` 接口进行推断。
### Inference 配置
`Inference` 需要一个 `infer_func` 和 `param_path` 来设置网络和经过训练的参数。
我们可以简单地插入在此之前定义的分类器。
```python
inferencer = Inferencer(
# infer_func=softmax_regression, # uncomment for softmax regression
# infer_func=multilayer_perceptron, # uncomment for MLP
infer_func=convolutional_neural_network, # uncomment for LeNet5
param_path=params_dirname,
place=place)
```
可以使用训练好的模型对手写体数字图片进行分类,下面程序展示了如何使用训练好的模型进行推断。
### 生成预测输入数据
`infer_3.png` 是数字 3 的一个示例图像。把它变成一个 numpy 数组以匹配数据馈送格式。
```python
# Prepare the test image
import os
import numpy as np
from PIL import Image
def load_image(file):
im = Image.open(file).convert('L')
im = im.resize((28, 28), Image.ANTIALIAS)
im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
im = numpy.array(im).reshape(1, 1, 28, 28).astype(numpy.float32)
im = im / 255.0 * 2.0 - 1.0
return im
cur_dir = cur_dir = os.getcwd()
img = load_image(cur_dir + '/image/infer_3.png')
tensor_img = load_image(cur_dir + '/image/infer_3.png')
```
### 预测
现在我们准备做预测。
### Inference 创建及预测
通过`load_inference_model`来设置网络和经过训练的参数。我们可以简单地插入在此之前定义的分类器。
```python
results = inferencer.infer({'img': img})
lab = np.argsort(results) # probs and lab are the results of one batch data
print ("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
save_dirname, exe, None, None)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
lab = numpy.argsort(results)
print("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])
```
### 预测结果
如果顺利,预测结果输入如下:
`Inference result of image/infer_3.png is: 3`
## 总结
本教程的softmax回归、多层感知器和卷积神经网络是最基础的深度学习模型,后续章节中复杂的神经网络都是从它们衍生出来的,因此这几个模型对之后的学习大有裨益。同时,我们也观察到从最简单的softmax回归变换到稍复杂的卷积神经网络的时候,MNIST数据集上的识别准确率有了大幅度的提升,原因是卷积层具有局部连接和共享权重的特性。在之后学习新模型的时候,希望大家也要深入到新模型相比原模型带来效果提升的关键之处。此外,本教程还介绍了PaddlePaddle模型搭建的基本流程,从dataprovider的编写、网络层的构建,到最后的训练和预测。对这个流程熟悉以后,大家就可以用自己的数据,定义自己的网络模型,并完成自己的训练和预测任务了。
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
from PIL import Image
import numpy as np
import numpy
import paddle
import paddle.fluid as fluid
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
BATCH_SIZE = 64
PASS_NUM = 5
def softmax_regression():
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
predict = fluid.layers.fc(input=img, size=10, act='softmax')
return predict
def loss_net(hidden, label):
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
acc = fluid.layers.accuracy(input=prediction, label=label)
return prediction, avg_loss, acc
def multilayer_perceptron():
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
# first fully-connected layer, using ReLu as its activation function
hidden = fluid.layers.fc(input=img, size=128, act='relu')
# second fully-connected layer, using ReLu as its activation function
hidden = fluid.layers.fc(input=hidden, size=64, act='relu')
# The thrid fully-connected layer, note that the hidden size should be 10,
# which is the number of unique digits
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
return prediction
def multilayer_perceptron(img, label):
img = fluid.layers.fc(input=img, size=200, act='tanh')
hidden = fluid.layers.fc(input=img, size=200, act='tanh')
return loss_net(hidden, label)
def convolutional_neural_network():
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
# first conv pool
def softmax_regression(img, label):
return loss_net(img, label)
def convolutional_neural_network(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
......@@ -45,7 +51,6 @@ def convolutional_neural_network():
pool_stride=2,
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
# second conv pool
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
......@@ -53,99 +58,160 @@ def convolutional_neural_network():
pool_size=2,
pool_stride=2,
act="relu")
# output layer with softmax activation function. size = 10 since there are only 10 possible digits.
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
return prediction
return loss_net(conv_pool_2, label)
def train_program():
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Here we can build the prediction network in different ways. Please
# predict = softmax_regression() # uncomment for Softmax
# predict = multilayer_perceptron() # uncomment for MLP
predict = convolutional_neural_network() # uncomment for LeNet5
def train(nn_type,
use_cuda,
save_dirname=None,
model_filename=None,
params_filename=None):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
# Calculate the cost from the prediction and label.
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=predict, label=label)
return [avg_cost, acc]
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
if nn_type == 'softmax_regression':
net_conf = softmax_regression
elif nn_type == 'multilayer_perceptron':
net_conf = multilayer_perceptron
else:
net_conf = convolutional_neural_network
prediction, avg_loss, acc = net_conf(img, label)
test_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_loss)
def train_test(train_test_program, train_test_feed, train_test_reader):
acc_set = []
avg_loss_set = []
for test_data in train_test_reader():
acc_np, avg_loss_np = exe.run(
program=train_test_program,
feed=train_test_feed.feed(test_data),
fetch_list=[acc, avg_loss])
acc_set.append(float(acc_np))
avg_loss_set.append(float(avg_loss_np))
# get test acc and loss
acc_val_mean = numpy.array(acc_set).mean()
avg_loss_val_mean = numpy.array(avg_loss_set).mean()
return avg_loss_val_mean, acc_val_mean
def optimizer_program():
return fluid.optimizer.Adam(learning_rate=0.001)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
def main():
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
batch_size=64)
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64)
use_cuda = False # set to True if training with GPU
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
trainer = Trainer(
train_func=train_program, place=place, optimizer_func=optimizer_program)
# Save the parameter into a directory. The Inferencer can load the parameters from it to do infer
params_dirname = "recognize_digits_network.inference.model"
exe.run(fluid.default_startup_program())
main_program = fluid.default_main_program()
epochs = [epoch_id for epoch_id in range(PASS_NUM)]
lists = []
def event_handler(event):
if isinstance(event, EndStepEvent):
if event.step % 100 == 0:
# event.metrics maps with train program return arguments.
# event.metrics[0] will yeild avg_cost and event.metrics[1] will yeild acc in this example.
print("Pass %d, Batch %d, Cost %f" % (event.step, event.epoch,
event.metrics[0]))
if isinstance(event, EndEpochEvent):
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
print("Test with Epoch %d, avg_cost: %s, acc: %s" %
(event.epoch, avg_cost, acc))
# save parameters
trainer.save_params(params_dirname)
lists.append((event.epoch, avg_cost, acc))
# Train the model now
trainer.train(
num_epochs=5,
event_handler=event_handler,
reader=train_reader,
feed_order=['img', 'label'])
step = 0
for epoch_id in epochs:
for step_id, data in enumerate(train_reader()):
metrics = exe.run(
main_program,
feed=feeder.feed(data),
fetch_list=[avg_loss, acc])
if step % 100 == 0:
print("Pass %d, Batch %d, Cost %f" % (step, epoch_id,
metrics[0]))
step += 1
# test for epoch
avg_loss_val, acc_val = train_test(
train_test_program=test_program,
train_test_reader=test_reader,
train_test_feed=feeder)
print("Test with Epoch %d, avg_cost: %s, acc: %s" %
(epoch_id, avg_loss_val, acc_val))
lists.append((epoch_id, avg_loss_val, acc_val))
if save_dirname is not None:
fluid.io.save_inference_model(
save_dirname, ["img"], [prediction],
exe,
model_filename=model_filename,
params_filename=params_filename)
# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print('Best pass is %s, testing Avgcost is %s' % (best[0], best[1]))
print('The classification accuracy is %.2f%%' % (float(best[2]) * 100))
def infer(use_cuda,
save_dirname=None,
model_filename=None,
params_filename=None):
if save_dirname is None:
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
def load_image(file):
im = Image.open(file).convert('L')
im = im.resize((28, 28), Image.ANTIALIAS)
im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
im = numpy.array(im).reshape(1, 1, 28, 28).astype(numpy.float32)
im = im / 255.0 * 2.0 - 1.0
return im
cur_dir = os.path.dirname(os.path.realpath(__file__))
img = load_image(cur_dir + '/image/infer_3.png')
inferencer = Inferencer(
# infer_func=softmax_regression, # uncomment for softmax regression
# infer_func=multilayer_perceptron, # uncomment for MLP
infer_func=convolutional_neural_network, # uncomment for LeNet5
param_path=params_dirname,
place=place)
results = inferencer.infer({'img': img})
lab = np.argsort(results) # probs and lab are the results of one batch data
print("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])
tensor_img = load_image(cur_dir + '/image/infer_3.png')
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
save_dirname, exe, model_filename, params_filename)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results = exe.run(
inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
lab = numpy.argsort(results)
print("Inference result of image/infer_3.png is: %d" % lab[0][0][-1])
def main(use_cuda, nn_type):
model_filename = None
params_filename = None
save_dirname = "recognize_digits_" + nn_type + ".inference.model"
# call train() with is_local argument to run distributed train
train(
nn_type=nn_type,
use_cuda=use_cuda,
save_dirname=save_dirname,
model_filename=model_filename,
params_filename=params_filename)
infer(
use_cuda=use_cuda,
save_dirname=save_dirname,
model_filename=model_filename,
params_filename=params_filename)
if __name__ == '__main__':
main()
use_cuda = False
# predict = 'softmax_regression' # uncomment for Softmax
# predict = 'multilayer_perceptron' # uncomment for MLP
predict = 'convolutional_neural_network' # uncomment for LeNet5
main(use_cuda=use_cuda, nn_type=predict)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册