Unverified commit 36fe7383, authored by Jeff Wang, committed by GitHub

Merge pull request #529 from jetfuel/recognize_digits_new_api_2

Recognize digits new api 2
......@@ -174,9 +174,8 @@ Let us create a data layer for reading images and connect it to the classificati
```python
def softmax_regression():
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
predict = paddle.layer.fc(input=img,
size=10,
act=paddle.activation.Softmax())
predict = fluid.layers.fc(
input=img, size=10, act='softmax')
return predict
```
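To make the computation behind `fluid.layers.fc(..., act='softmax')` concrete, here is a minimal NumPy sketch of a softmax regression forward pass. It is an illustration only, not the Fluid implementation: the helper names (`softmax`, `softmax_regression_forward`), the random weights, and the batch size of 4 are assumptions made for this example.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def softmax_regression_forward(images, weights, bias):
    # Flatten (batch, 1, 28, 28) images to (batch, 784), apply one
    # fully-connected layer, then normalize the 10 scores with softmax.
    flat = images.reshape(images.shape[0], -1)
    return softmax(flat.dot(weights) + bias)


# Hypothetical inputs and parameters, used only to exercise the functions.
rng = np.random.RandomState(0)
images = rng.uniform(-1.0, 1.0, size=(4, 1, 28, 28)).astype("float32")
weights = rng.normal(scale=0.01, size=(784, 10)).astype("float32")
bias = np.zeros(10, dtype="float32")

probs = softmax_regression_forward(images, weights, bias)
print(probs.shape)
```

Each row of `probs` holds ten non-negative values that sum to one, which is the property the classifier relies on when a cross-entropy cost is applied afterwards.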
......@@ -221,14 +220,19 @@ def convolutional_neural_network():
```
#### Train Program Configuration
Then we need to set up the `train_program`. It first takes the prediction from the classifier. During training, it calculates the `avg_cost` from the prediction.
Then we need to set up the `train_program`. It first takes the prediction from the classifier.
During training, it calculates the `avg_cost` from the prediction.
**NOTE:** A train program should return a list, and the first returned item has to be `avg_cost`.
The trainer always implicitly uses it to calculate the gradient.
Please feel free to modify the code to test different results between `softmax regression`, `mlp`, and `convolutional neural network` classifier.
```python
def train_program():
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# predict = softmax_regression(images) # uncomment for Softmax
# predict = softmax_regression() # uncomment for Softmax
# predict = multilayer_perceptron() # uncomment for MLP
predict = convolutional_neural_network() # uncomment for LeNet5
......@@ -236,6 +240,8 @@ def train_program():
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=predict, label=label)
# The first item needs to be avg_cost.
return [avg_cost, acc]
```
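The three quantities computed in `train_program` (per-sample cross-entropy, its mean, and accuracy) can be sketched in plain NumPy as follows. This is a hedged illustration of the arithmetic, not of the Fluid operators themselves: the helper names and the tiny 3-class example are assumptions, and the labels here are a flat 1-D array rather than the `[batch, 1]` shape used by the data layer.

```python
import numpy as np


def cross_entropy(probs, labels):
    # Negative log-probability assigned to the true class of each sample.
    picked = probs[np.arange(len(labels)), labels]
    return -np.log(picked)


def accuracy(probs, labels):
    # Fraction of samples whose highest-probability class equals the label.
    return float(np.mean(np.argmax(probs, axis=1) == labels))


# Hypothetical predictions for four samples over three classes.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
labels = np.array([0, 1, 2, 1])

cost = cross_entropy(probs, labels)
avg_cost = float(cost.mean())
acc = accuracy(probs, labels)

# Mirror the return convention of train_program: avg_cost comes first.
metrics = [avg_cost, acc]
print(metrics)
```

Keeping `avg_cost` as the first item matches the ordering the note above requires, since that is the value the trainer differentiates.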
......@@ -273,48 +279,103 @@ trainer = fluid.Trainer(
#### Event Handler
The Fluid API provides a callback hook during training. Users can monitor training progress through this mechanism.
We will demonstrate two event handlers here. Please feel free to modify them in the Jupyter notebook to see the differences.
`event_handler` is used to print text output during training.
```python
# Save the parameters into a directory. The Inferencer can load them from it to run inference.
params_dirname = "recognize_digits_network.inference.model"
lists = []
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
if event.step % 100 == 0:
# event.metrics maps to the values returned by the train program.
# event.metrics[0] will yield avg_cost and event.metrics[1] will yield acc in this example.
print "Pass %d, Batch %d, Cost %f" % (
event.step, event.epoch, event.metrics[0])
if isinstance(event, fluid.EndEpochEvent):
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
print("avg_cost: %s, acc: %s" % (avg_cost, acc))
print("Test with Epoch %d, avg_cost: %s, acc: %s" % (event.epoch, avg_cost, acc))
# save parameters
trainer.save_params(params_dirname)
lists.append((event.epoch, avg_cost, acc))
```
`event_handler_plot` is used to plot a figure like below:
![png](./image/train_and_test.png)
```python
from paddle.v2.plot import Ploter
train_title = "Train cost"
test_title = "Test cost"
cost_ploter = Ploter(train_title, test_title)
step = 0
lists = []
# event_handler to plot a figure
def event_handler_plot(event):
global step
if isinstance(event, fluid.EndStepEvent):
if step % 100 == 0:
# event.metrics maps to the values returned by the train program.
# event.metrics[0] will yield avg_cost and event.metrics[1] will yield acc in this example.
cost_ploter.append(train_title, step, event.metrics[0])
cost_ploter.plot()
step += 1
if isinstance(event, fluid.EndEpochEvent):
# save parameters
trainer.save_params(params_dirname)
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
cost_ploter.append(test_title, step, avg_cost)
lists.append((event.epoch, avg_cost, acc))
```
#### Start training
Now that we have set up the event handler and the reader, we can start training the model. `feed_order` is used to map the data dict to the `train_program`.
```python
# Train the model now
trainer.train(
num_epochs=1,
event_handler=event_handler,
num_epochs=5,
event_handler=event_handler_plot,
reader=train_reader,
feed_order=['img', 'label'])
```
During training, `trainer.train` invokes `event_handler` for certain events. This gives us a chance to print the training progress.
```
# Pass 0, Batch 0, Cost 2.780790, {'classification_error_evaluator': 0.9453125}
# Pass 0, Batch 100, Cost 0.635356, {'classification_error_evaluator': 0.2109375}
# Pass 0, Batch 200, Cost 0.326094, {'classification_error_evaluator': 0.1328125}
# Pass 0, Batch 300, Cost 0.361920, {'classification_error_evaluator': 0.1015625}
# Pass 0, Batch 400, Cost 0.410101, {'classification_error_evaluator': 0.125}
# Test with Pass 0, Cost 0.326659, {'classification_error_evaluator': 0.09470000118017197}
```
```
Pass 0, Batch 0, Cost 0.125650
Pass 100, Batch 0, Cost 0.161387
Pass 200, Batch 0, Cost 0.040036
Pass 300, Batch 0, Cost 0.023391
Pass 400, Batch 0, Cost 0.005856
Pass 500, Batch 0, Cost 0.003315
Pass 600, Batch 0, Cost 0.009977
Pass 700, Batch 0, Cost 0.020959
Pass 800, Batch 0, Cost 0.105560
Pass 900, Batch 0, Cost 0.239809
Test with Epoch 0, avg_cost: 0.053097883707459624, acc: 0.9822850318471338
```
After the training, we can check the model's prediction accuracy.
```
```python
# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)
print 'The classification accuracy is %.2f%%' % (float(best[2]) * 100)
```
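The accuracy print changes from `100 - float(best[2]) * 100` to `float(best[2]) * 100` because the third element of each record in `lists` is now the accuracy returned by `trainer.test`, not the classification error stored by the old API. A small sketch with made-up records shows the selection logic. The numbers below are hypothetical and are not results from this commit.

```python
# Each record is (epoch, avg_cost, acc), as appended by the event handler.
# These values are invented purely to illustrate the selection.
lists = [
    (0, 0.061, 0.981),
    (1, 0.043, 0.987),
    (2, 0.049, 0.985),
]

# Choose the record with the lowest test cost.
best = min(lists, key=lambda record: float(record[1]))

print("Best pass is %s, testing Avgcost is %s" % (best[0], best[1]))
# best[2] is already an accuracy in [0, 1], so it is scaled, not inverted.
print("The classification accuracy is %.2f%%" % (float(best[2]) * 100))
```

Using `min` with a key gives the same record as sorting and taking the first element, without sorting the whole list.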
Usually, with MNIST data, the softmax regression model achieves an accuracy around 92.34%, the MLP 97.66%, and the convolution network around 99.20%. Convolution layers have been widely considered a great invention for image processing.
......@@ -323,23 +384,48 @@ Usually, with MNIST data, the softmax regression model achieves an accuracy arou
After training, users can use the trained model to classify images. The following code shows how to run inference on MNIST images through `fluid.Inferencer`.
### Create Inferencer
The `Inferencer` takes an `infer_func` and a `param_path` to set up the network and the trained parameters.
We can simply plug in the classifier defined earlier here.
```python
inferencer = fluid.Inferencer(
# infer_func=softmax_regression, # uncomment for softmax regression
# infer_func=multilayer_perceptron, # uncomment for MLP
infer_func=convolutional_neural_network, # uncomment for LeNet5
infer_func=convolutional_neural_network, # uncomment for LeNet5
param_path=params_dirname,
place=place)
```
#### Generate input data for inferring
batch_size = 1
import numpy
tensor_img = numpy.random.uniform(-1.0, 1.0,
[batch_size, 1, 28, 28]).astype("float32")
`infer_3.png` is an example image of the digit `3`. Turn it into a NumPy array to match the data feeder format.
results = inferencer.infer({'img': tensor_img})
```python
# Prepare the test image
import os
import numpy as np
from PIL import Image
def load_image(file):
im = Image.open(file).convert('L')
im = im.resize((28, 28), Image.ANTIALIAS)
im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
im = im / 255.0 * 2.0 - 1.0
return im
cur_dir = os.getcwd()
img = load_image(cur_dir + '/image/infer_3.png')
```
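The last two steps of `load_image` reshape the grayscale pixels into a batch of one and rescale them from `[0, 255]` to `[-1, 1]`. The sketch below isolates that arithmetic on a synthetic array so it can be checked without an image file. The helper name `normalize_pixels` and the synthetic input are assumptions for this example.

```python
import numpy as np


def normalize_pixels(gray):
    # Reshape a 28x28 grayscale array to (batch=1, channel=1, 28, 28)
    # and map pixel values from [0, 255] to [-1, 1].
    arr = np.asarray(gray, dtype=np.float32).reshape(1, 1, 28, 28)
    return arr / 255.0 * 2.0 - 1.0


# Synthetic image: all black except for one white pixel.
gray = np.zeros((28, 28), dtype=np.uint8)
gray[0, 0] = 255

batch = normalize_pixels(gray)
print(batch.shape, batch.min(), batch.max())
```

A pixel value of 0 maps to -1 and 255 maps to 1, matching the `[1, 28, 28]` shape and `float32` dtype declared for the `img` data layer.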
print("infer results: ", results[0])
### Inference
Now we are ready to do inference.
```python
results = inferencer.infer({'img': img})
lab = np.argsort(results) # probs and lab are the results of one batch data
print "Label of image/infer_3.png is: %d" % lab[0][0][-1]
```
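The indexing `lab[0][0][-1]` works because `np.argsort` sorts in ascending order, so the last index along the final axis belongs to the largest probability. The sketch below assumes that `inferencer.infer` returns a list holding one array of shape `(batch, 10)`; that shape is inferred from the indexing above rather than confirmed, and the probabilities are invented.

```python
import numpy as np

# Hypothetical stand-in for the value returned by inferencer.infer:
# a list with one (1, 10) array of class probabilities.
results = [np.array([[0.01, 0.02, 0.03, 0.80, 0.02,
                      0.03, 0.02, 0.03, 0.02, 0.02]])]

# argsort converts the list to a (1, 1, 10) array and sorts the last axis
# in ascending order of probability.
lab = np.argsort(results)

# The last entry is therefore the most probable digit.
label = lab[0][0][-1]
print("Predicted label: %d" % label)
```

For a single prediction, `np.argmax(results[0][0])` yields the same label more directly. `argsort` is useful when the full ranking of classes is also wanted.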
......
......@@ -216,9 +216,8 @@ Let us create a data layer for reading images and connect it to the classificati
```python
def softmax_regression():
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
predict = paddle.layer.fc(input=img,
size=10,
act=paddle.activation.Softmax())
predict = fluid.layers.fc(
input=img, size=10, act='softmax')
return predict
```
......@@ -263,14 +262,19 @@ def convolutional_neural_network():
```
#### Train Program Configuration
Then we need to set up the `train_program`. It first takes the prediction from the classifier. During training, it calculates the `avg_cost` from the prediction.
Then we need to set up the `train_program`. It first takes the prediction from the classifier.
During training, it calculates the `avg_cost` from the prediction.
**NOTE:** A train program should return a list, and the first returned item has to be `avg_cost`.
The trainer always implicitly uses it to calculate the gradient.
Please feel free to modify the code to test different results between `softmax regression`, `mlp`, and `convolutional neural network` classifier.
```python
def train_program():
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# predict = softmax_regression(images) # uncomment for Softmax
# predict = softmax_regression() # uncomment for Softmax
# predict = multilayer_perceptron() # uncomment for MLP
predict = convolutional_neural_network() # uncomment for LeNet5
......@@ -278,6 +282,8 @@ def train_program():
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=predict, label=label)
# The first item needs to be avg_cost.
return [avg_cost, acc]
```
......@@ -315,48 +321,103 @@ trainer = fluid.Trainer(
#### Event Handler
The Fluid API provides a callback hook during training. Users can monitor training progress through this mechanism.
We will demonstrate two event handlers here. Please feel free to modify them in the Jupyter notebook to see the differences.
`event_handler` is used to print text output during training.
```python
# Save the parameters into a directory. The Inferencer can load them from it to run inference.
params_dirname = "recognize_digits_network.inference.model"
lists = []
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
if event.step % 100 == 0:
# event.metrics maps to the values returned by the train program.
# event.metrics[0] will yield avg_cost and event.metrics[1] will yield acc in this example.
print "Pass %d, Batch %d, Cost %f" % (
event.step, event.epoch, event.metrics[0])
if isinstance(event, fluid.EndEpochEvent):
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
print("avg_cost: %s, acc: %s" % (avg_cost, acc))
print("Test with Epoch %d, avg_cost: %s, acc: %s" % (event.epoch, avg_cost, acc))
# save parameters
trainer.save_params(params_dirname)
lists.append((event.epoch, avg_cost, acc))
```
`event_handler_plot` is used to plot a figure like below:
![png](./image/train_and_test.png)
```python
from paddle.v2.plot import Ploter
train_title = "Train cost"
test_title = "Test cost"
cost_ploter = Ploter(train_title, test_title)
step = 0
lists = []
# event_handler to plot a figure
def event_handler_plot(event):
global step
if isinstance(event, fluid.EndStepEvent):
if step % 100 == 0:
# event.metrics maps to the values returned by the train program.
# event.metrics[0] will yield avg_cost and event.metrics[1] will yield acc in this example.
cost_ploter.append(train_title, step, event.metrics[0])
cost_ploter.plot()
step += 1
if isinstance(event, fluid.EndEpochEvent):
# save parameters
trainer.save_params(params_dirname)
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
cost_ploter.append(test_title, step, avg_cost)
lists.append((event.epoch, avg_cost, acc))
```
#### Start training
Now that we have set up the event handler and the reader, we can start training the model. `feed_order` is used to map the data dict to the `train_program`.
```python
# Train the model now
trainer.train(
num_epochs=1,
event_handler=event_handler,
num_epochs=5,
event_handler=event_handler_plot,
reader=train_reader,
feed_order=['img', 'label'])
```
During training, `trainer.train` invokes `event_handler` for certain events. This gives us a chance to print the training progress.
```
# Pass 0, Batch 0, Cost 2.780790, {'classification_error_evaluator': 0.9453125}
# Pass 0, Batch 100, Cost 0.635356, {'classification_error_evaluator': 0.2109375}
# Pass 0, Batch 200, Cost 0.326094, {'classification_error_evaluator': 0.1328125}
# Pass 0, Batch 300, Cost 0.361920, {'classification_error_evaluator': 0.1015625}
# Pass 0, Batch 400, Cost 0.410101, {'classification_error_evaluator': 0.125}
# Test with Pass 0, Cost 0.326659, {'classification_error_evaluator': 0.09470000118017197}
```
```
Pass 0, Batch 0, Cost 0.125650
Pass 100, Batch 0, Cost 0.161387
Pass 200, Batch 0, Cost 0.040036
Pass 300, Batch 0, Cost 0.023391
Pass 400, Batch 0, Cost 0.005856
Pass 500, Batch 0, Cost 0.003315
Pass 600, Batch 0, Cost 0.009977
Pass 700, Batch 0, Cost 0.020959
Pass 800, Batch 0, Cost 0.105560
Pass 900, Batch 0, Cost 0.239809
Test with Epoch 0, avg_cost: 0.053097883707459624, acc: 0.9822850318471338
```
After the training, we can check the model's prediction accuracy.
```
```python
# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)
print 'The classification accuracy is %.2f%%' % (float(best[2]) * 100)
```
Usually, with MNIST data, the softmax regression model achieves an accuracy around 92.34%, the MLP 97.66%, and the convolution network around 99.20%. Convolution layers have been widely considered a great invention for image processing.
......@@ -365,23 +426,48 @@ Usually, with MNIST data, the softmax regression model achieves an accuracy arou
After training, users can use the trained model to classify images. The following code shows how to run inference on MNIST images through `fluid.Inferencer`.
### Create Inferencer
The `Inferencer` takes an `infer_func` and a `param_path` to set up the network and the trained parameters.
We can simply plug in the classifier defined earlier here.
```python
inferencer = fluid.Inferencer(
# infer_func=softmax_regression, # uncomment for softmax regression
# infer_func=multilayer_perceptron, # uncomment for MLP
infer_func=convolutional_neural_network, # uncomment for LeNet5
infer_func=convolutional_neural_network, # uncomment for LeNet5
param_path=params_dirname,
place=place)
```
#### Generate input data for inferring
batch_size = 1
import numpy
tensor_img = numpy.random.uniform(-1.0, 1.0,
[batch_size, 1, 28, 28]).astype("float32")
`infer_3.png` is an example image of the digit `3`. Turn it into a NumPy array to match the data feeder format.
results = inferencer.infer({'img': tensor_img})
```python
# Prepare the test image
import os
import numpy as np
from PIL import Image
def load_image(file):
im = Image.open(file).convert('L')
im = im.resize((28, 28), Image.ANTIALIAS)
im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
im = im / 255.0 * 2.0 - 1.0
return im
cur_dir = os.getcwd()
img = load_image(cur_dir + '/image/infer_3.png')
```
print("infer results: ", results[0])
### Inference
Now we are ready to do inference.
```python
results = inferencer.infer({'img': img})
lab = np.argsort(results) # probs and lab are the results of one batch data
print "Label of image/infer_3.png is: %d" % lab[0][0][-1]
```
......
import os
from PIL import Image
import numpy as np
import paddle.v2 as paddle
import paddle
import paddle.fluid as fluid
with_gpu = os.getenv('WITH_GPU', '0') != '0'
def softmax_regression(img):
predict = paddle.layer.fc(
input=img, size=10, act=paddle.activation.Softmax())
def softmax_regression():
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
predict = fluid.layers.fc(input=img, size=10, act='softmax')
return predict
def multilayer_perceptron(img):
# The first fully-connected layer
hidden1 = paddle.layer.fc(input=img, size=128, act=paddle.activation.Relu())
# The second fully-connected layer and the according activation function
hidden2 = paddle.layer.fc(
input=hidden1, size=64, act=paddle.activation.Relu())
def multilayer_perceptron():
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
# first fully-connected layer, using ReLu as its activation function
hidden = fluid.layers.fc(input=img, size=128, act='relu')
# second fully-connected layer, using ReLu as its activation function
hidden = fluid.layers.fc(input=hidden, size=64, act='relu')
# The third fully-connected layer, note that the hidden size should be 10,
# which is the number of unique digits
predict = paddle.layer.fc(
input=hidden2, size=10, act=paddle.activation.Softmax())
return predict
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
return prediction
def convolutional_neural_network(img):
# first conv layer
conv_pool_1 = paddle.networks.simple_img_conv_pool(
def convolutional_neural_network():
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
# first conv pool
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
num_channel=1,
pool_size=2,
pool_stride=2,
act=paddle.activation.Relu())
# second conv layer
conv_pool_2 = paddle.networks.simple_img_conv_pool(
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
# second conv pool
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
num_channel=20,
pool_size=2,
pool_stride=2,
act=paddle.activation.Relu())
# fully-connected layer
predict = paddle.layer.fc(
input=conv_pool_2, size=10, act=paddle.activation.Softmax())
return predict
act="relu")
# output layer with softmax activation function. size = 10 since there are only 10 possible digits.
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
return prediction
def main():
paddle.init(use_gpu=with_gpu, trainer_count=1)
# define network topology
images = paddle.layer.data(
name='pixel', type=paddle.data_type.dense_vector(784))
label = paddle.layer.data(
name='label', type=paddle.data_type.integer_value(10))
def train_program():
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Here we can build the prediction network in different ways. Please
# choose one by uncommenting the corresponding line.
# predict = softmax_regression(images)
# predict = multilayer_perceptron(images)
predict = convolutional_neural_network(images)
# predict = softmax_regression() # uncomment for Softmax
# predict = multilayer_perceptron() # uncomment for MLP
predict = convolutional_neural_network() # uncomment for LeNet5
# Calculate the cost from the prediction and label.
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
acc = fluid.layers.accuracy(input=predict, label=label)
return [avg_cost, acc]
cost = paddle.layer.classification_cost(input=predict, label=label)
def main():
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
batch_size=64)
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64)
parameters = paddle.parameters.create(cost)
use_cuda = os.getenv('WITH_GPU', '0') != '0'
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1 / 128.0,
momentum=0.9,
regularization=paddle.optimizer.L2Regularization(rate=0.0005 * 128))
trainer = fluid.Trainer(
train_func=train_program, place=place, optimizer=optimizer)
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=optimizer)
# Save the parameters into a directory. The Inferencer can load them from it to run inference.
params_dirname = "recognize_digits_network.inference.model"
lists = []
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
if isinstance(event, paddle.event.EndPass):
# save parameters
with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
trainer.save_parameter_to_tar(f)
if isinstance(event, fluid.EndStepEvent):
if event.step % 100 == 0:
# event.metrics maps to the values returned by the train program.
# event.metrics[0] will yield avg_cost and event.metrics[1] will yield acc in this example.
print "Pass %d, Batch %d, Cost %f" % (event.step, event.epoch,
event.metrics[0])
if isinstance(event, fluid.EndEpochEvent):
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
result = trainer.test(reader=paddle.batch(
paddle.dataset.mnist.test(), batch_size=128))
print "Test with Pass %d, Cost %f, %s\n" % (
event.pass_id, result.cost, result.metrics)
lists.append((event.pass_id, result.cost,
result.metrics['classification_error_evaluator']))
print("Test with Epoch %d, avg_cost: %s, acc: %s" %
(event.epoch, avg_cost, acc))
# save parameters
trainer.save_params(params_dirname)
lists.append((event.epoch, avg_cost, acc))
# Train the model now
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=8192),
batch_size=128),
num_epochs=5,
event_handler=event_handler,
num_passes=5)
reader=train_reader,
feed_order=['img', 'label'])
# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)
print 'The classification accuracy is %.2f%%' % (float(best[2]) * 100)
def load_image(file):
im = Image.open(file).convert('L')
im = im.resize((28, 28), Image.ANTIALIAS)
im = np.array(im).astype(np.float32).flatten()
im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
im = im / 255.0 * 2.0 - 1.0
return im
test_data = []
cur_dir = os.path.dirname(os.path.realpath(__file__))
test_data.append((load_image(cur_dir + '/image/infer_3.png'), ))
probs = paddle.infer(
output_layer=predict, parameters=parameters, input=test_data)
lab = np.argsort(-probs) # probs and lab are the results of one batch data
print "Label of image/infer_3.png is: %d" % lab[0][0]
img = load_image(cur_dir + '/image/infer_3.png')
inferencer = fluid.Inferencer(
# infer_func=softmax_regression, # uncomment for softmax regression
# infer_func=multilayer_perceptron, # uncomment for MLP
infer_func=convolutional_neural_network, # uncomment for LeNet5
param_path=params_dirname,
place=place)
results = inferencer.infer({'img': img})
lab = np.argsort(results) # probs and lab are the results of one batch data
print "Label of image/infer_3.png is: %d" % lab[0][0][-1]
if __name__ == '__main__':
......