Commit 97fee5d0 authored by wangkuiyi, committed by GitHub

Merge pull request #121 from wangkuiyi/merge_mnist_new_program_into_english

Insert new MNIST example program from the Chinese README.md into the English version
......@@ -115,15 +115,8 @@ For more information, please refer to [Activation functions on Wikipedia](https:
## Data Preparation
### Data Download
PaddlePaddle provides a Python module, `paddle.dataset.mnist`, which downloads and caches the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). The cache is under `/home/username/.cache/paddle/dataset/mnist`.
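As a quick sanity check, you can iterate over a few instances with the same reader creator that the training code later in this tutorial uses. This is a minimal sketch; it assumes, as that training code does, that each yielded instance is an (image vector, label) pair.

```python
import paddle.v2 as paddle

# paddle.dataset.mnist.train() is a reader creator: calling it returns a reader,
# and calling the reader returns a generator over (image, label) instances.
reader = paddle.dataset.mnist.train()
for i, (image, label) in enumerate(reader()):
    print label, len(image)   # label is an integer in 0..9; image has 28*28 = 784 values
    if i == 2:
        break
```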
Alternatively, execute the following command to download the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset and unzip it, then add the paths of the training set and the test set to `train.list` and `test.list` respectively so that PaddlePaddle can read them:
```bash
./data/get_mnist_data.sh
```
After the downloaded `gzip` archives are decompressed, the following files can be found in `data/raw_data`:
| File name | Description |
|----------------------|-------------------------|
......@@ -132,283 +125,160 @@ Execute the following command to download the [MNIST](http://yann.lecun.com/exdb
|t10k-images-idx3-ubyte | Evaluation images, 10,000 |
|t10k-labels-idx1-ubyte | Evaluation labels, 10,000 |
Users can randomly select and display 10 images with the following script (refer to Fig. 1):
```bash
./load_data.py
```
### Provide Data to PaddlePaddle
We use the Python interface to provide data to the system. `mnist_provider.py` shows a complete example for training on MNIST data:
```python
import struct
import numpy as np
from paddle.trainer.PyDataProvider2 import *


# Define a py data provider
@provider(
    input_types={'pixel': dense_vector(28 * 28),
                 'label': integer_value(10)})
def process(settings, filename):  # settings is not used currently.
    # Open the image file
    with open(filename + "-images-idx3-ubyte", "rb") as f:
        # Read the first 4 header fields: magic is the data format, n is the number
        # of images, rows and cols are the numbers of rows and columns.
        magic, n, rows, cols = struct.unpack(">IIII", f.read(16))
        # Read the image data as unsigned bytes
        images = np.fromfile(
            f, 'ubyte',
            count=n * rows * cols).reshape(n, rows, cols).astype('float32')
        # Normalize data from [0, 255] to [-1, 1]
        images = images / 255.0 * 2.0 - 1.0
    # Open the label file
    with open(filename + "-labels-idx1-ubyte", "rb") as l:
        # Read the first two header fields
        magic, n = struct.unpack(">II", l.read(8))
        # Read the labels as unsigned bytes
        labels = np.fromfile(l, 'ubyte', count=n).astype("int")
    for i in xrange(n):
        yield {"pixel": images[i, :], 'label': labels[i]}
```

## Model Configuration

A PaddlePaddle program starts from importing the API package:

```python
import paddle.v2 as paddle
```
We want to use this program to demonstrate multiple kinds of models. Let us define each of them as a Python function below.

### Data Definition
In the model configuration, use `define_py_data_sources2` to define how data is read from the `dataprovider`. If this configuration is used for prediction, the data definition is not necessary.
```python
if not is_predict:
    data_dir = './data/'
    define_py_data_sources2(
        train_list=data_dir + 'train.list',
        test_list=data_dir + 'test.list',
        module='mnist_provider',
        obj='process')
### Algorithm Configuration
Set the training-related parameters:

- batch_size: use 128 samples in each training step.
- learning_rate: the step size taken in each iteration; it determines how fast the model converges.
- learning_method: use the `MomentumOptimizer` for training. The parameter 0.9 means the momentum retains 90% of the previous update velocity.
- regularization: a method to prevent overfitting. Here L2 regularization is used.
```python
settings(
batch_size=128,
learning_rate=0.1 / 128.0,
learning_method=MomentumOptimizer(0.9),
regularization=L2Regularization(0.0005 * 128))
```
### Model Architecture
#### Overview
First, obtain the image input and the reference labels from `data_layer`, and obtain the classification results (predictions) from the classifier. We provide three different classifiers here. During training we compute the loss function, which is usually cross entropy for classification problems; during prediction we can directly output the predictions.

```python
data_size = 1 * 28 * 28
label_size = 10
img = data_layer(name='pixel', size=data_size)

predict = softmax_regression(img)                 # Softmax regression
#predict = multilayer_perceptron(img)             # Multilayer perceptron
#predict = convolutional_neural_network(img)      # LeNet-5 convolutional neural network

if not is_predict:
    lbl = data_layer(name="label", size=label_size)
    inputs(img, lbl)
    outputs(classification_cost(input=predict, label=lbl))
else:
    outputs(predict)
```
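The cost computed here is the cross-entropy loss mentioned above. As a reminder (the standard textbook definition, not anything PaddlePaddle-specific), for a predicted distribution $y$ over the 10 digits and a ground-truth class $t$, the loss of one example is:

$$ L(y, t) = -\log y_t $$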
#### Softmax Regression
Softmax regression: a single fully-connected layer with a softmax activation function outputs the classification result.

```python
def softmax_regression(img):
    predict = paddle.layer.fc(input=img,
                              size=10,
                              act=paddle.activation.Softmax())
    return predict
```
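For reference, the layer above implements the textbook softmax regression mapping, where $W$ and $b$ are the layer's weight matrix and bias (standard notation, not taken from the PaddlePaddle code):

$$ y = \mathrm{softmax}(Wx + b), \qquad \mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}} $$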
#### Multilayer Perceptron
The following code implements a multilayer perceptron with two fully-connected hidden layers that use the ReLU activation function. The output layer uses the Softmax activation function and must have size 10.

```python
def multilayer_perceptron(img):
    # First fully-connected hidden layer with ReLU activation
    hidden1 = paddle.layer.fc(input=img, size=128, act=paddle.activation.Relu())
    # Second fully-connected hidden layer with ReLU activation
    hidden2 = paddle.layer.fc(input=hidden1,
                              size=64,
                              act=paddle.activation.Relu())
    # Output layer: fully-connected with softmax activation; the size must be 10
    predict = paddle.layer.fc(input=hidden2,
                              size=10,
                              act=paddle.activation.Softmax())
    return predict
```
#### Convolutional Neural Network LeNet-5
The following is the LeNet-5 network architecture: a 2D input image is first fed through two sets of convolutional and pooling layers; the result is then fed to a fully-connected layer, followed by another fully-connected layer with a softmax activation as the output layer.

```python
def convolutional_neural_network(img):
    # First convolution-pooling layer
    conv_pool_1 = paddle.networks.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        num_channel=1,
        pool_size=2,
        pool_stride=2,
        act=paddle.activation.Tanh())
    # Second convolution-pooling layer
    conv_pool_2 = paddle.networks.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        num_channel=20,
        pool_size=2,
        pool_stride=2,
        act=paddle.activation.Tanh())
    # Fully-connected layer
    fc1 = paddle.layer.fc(input=conv_pool_2,
                          size=128,
                          act=paddle.activation.Tanh())
    # Output layer: fully-connected with softmax activation; the size must be 10
    predict = paddle.layer.fc(input=fc1,
                              size=10,
                              act=paddle.activation.Softmax())
    return predict
```
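For intuition about the shapes flowing through LeNet-5, here is a quick back-of-the-envelope check (a rough sketch assuming the convolutions use no padding, since the layers above do not request any, and non-overlapping 2×2 pooling):

```python
def conv_out(size, filter_size):
    # a convolution without padding shrinks each side by filter_size - 1
    return size - filter_size + 1

def pool_out(size, pool_size, stride):
    # non-overlapping pooling with stride == pool_size
    return (size - pool_size) // stride + 1

side = 28
side = pool_out(conv_out(side, 5), 2, 2)   # after conv_pool_1: 20 feature maps of 12 x 12
side = pool_out(conv_out(side, 5), 2, 2)   # after conv_pool_2: 50 feature maps of 4 x 4
print side                                 # prints 4
```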
## Training Model
### Training Commands and Logs
1. Configure `train.sh` to execute training:

```bash
config=mnist_model.py          # select the network in mnist_model.py
output=./softmax_mnist_model
log=softmax_train.log

paddle train \
--config=$config \             # network configuration script
--dot_period=10 \              # print one `.` every `dot_period` batches
--log_period=100 \             # print a log line every 100 batches
--test_all_data_in_one_period=1 \ # whether to use all data in every test
--use_gpu=0 \                  # whether to use the GPU
--trainer_count=1 \            # number of CPUs or GPUs
--num_passes=100 \             # number of training passes (one pass uses all the data)
--save_dir=$output \           # path for saving the model
2>&1 | tee $log

python -m paddle.utils.plotcurve -i $log > plot.png
```
After configuring the parameters, execute `./train.sh`. The training log looks like the following:
```
I0117 12:52:29.628617 4538 TrainerInternal.cpp:165] Batch=100 samples=12800 AvgCost=2.63996 CurrentCost=2.63996 Eval: classification_error_evaluator=0.241172 CurrentEval: classification_error_evaluator=0.241172
.........
I0117 12:52:29.768741 4538 TrainerInternal.cpp:165] Batch=200 samples=25600 AvgCost=1.74027 CurrentCost=0.840582 Eval: classification_error_evaluator=0.185234 CurrentEval: classification_error_evaluator=0.129297
.........
I0117 12:52:29.916970 4538 TrainerInternal.cpp:165] Batch=300 samples=38400 AvgCost=1.42119 CurrentCost=0.783026 Eval: classification_error_evaluator=0.167786 CurrentEval: classification_error_evaluator=0.132891
.........
I0117 12:52:30.061213 4538 TrainerInternal.cpp:165] Batch=400 samples=51200 AvgCost=1.23965 CurrentCost=0.695054 Eval: classification_error_evaluator=0.160039 CurrentEval: classification_error_evaluator=0.136797
......I0117 12:52:30.223270 4538 TrainerInternal.cpp:181] Pass=0 Batch=469 samples=60000 AvgCost=1.1628 Eval: classification_error_evaluator=0.156233
I0117 12:52:30.366894 4538 Tester.cpp:109] Test samples=10000 cost=0.50777 Eval: classification_error_evaluator=0.0978
```
2. Use `plot_cost.py` to plot the error curve during training:

```bash
python plot_cost.py softmax_train.log
```

3. Use `evaluate.py` to select the best trained model:

```bash
python evaluate.py softmax_train.log
```

PaddlePaddle provides a special layer `layer.data` for reading data. Let us create a data layer for reading images and connect it to a classification network created by one of the three functions above. We also need a cost layer for training the model.

```python
def main():
    paddle.init(use_gpu=False, trainer_count=1)

    images = paddle.layer.data(
        name='pixel', type=paddle.data_type.dense_vector(784))
    label = paddle.layer.data(
        name='label', type=paddle.data_type.integer_value(10))

    predict = softmax_regression(images)
    #predict = multilayer_perceptron(images)         # uncomment for MLP
    #predict = convolutional_neural_network(images)  # uncomment for LeNet-5

    cost = paddle.layer.classification_cost(input=predict, label=label)
```
Now it is time to specify the training parameters. The number 0.9 in the following `Momentum` optimizer means that 90% of the current momentum comes from the momentum of the previous iteration.

```python
    parameters = paddle.parameters.create(cost)

    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.1 / 128.0,
        momentum=0.9,
        regularization=paddle.optimizer.L2Regularization(rate=0.0005 * 128))

    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 update_equation=optimizer)
```

### Training Results for Softmax Regression

<p align="center">
<img src="image/softmax_train_log_en.png" width="400px"><br/>
Fig. 7. Softmax regression error curve<br/>
</p>

Evaluation results of the model:

```text
Best pass is 00013, testing Avgcost is 0.484447
The classification accuracy is 90.01%
```
From the evaluation results, the best pass for the softmax regression model is pass-00013, with a classification accuracy of 90.01%, while the last pass, pass-00099, has an accuracy of 89.3%. From Fig. 7 we also see that the best accuracy may not appear in the last pass. This is because during training the model may already have reached a local optimum, and in the following passes it either just oscillates around it or moves to a worse local optimum.
Then we specify the training data `paddle.dataset.mnist.train()` and the test data `paddle.dataset.mnist.test()`. These two functions are *reader creators*: once called, they return a *reader*. A reader is a Python function which, once called, returns a Python generator that yields instances of data.

`shuffle` is a reader decorator. It takes a reader A as its parameter and returns a new reader B. B calls A to read `buffer_size` data instances at a time into a buffer, then shuffles the buffer and yields instances from it. If you want better-shuffled data, use a larger buffer size.

`batch` is a special decorator whose input is a reader and whose output is a *batch reader*, which yields a minibatch at a time rather than a single instance.
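To make the reader mechanics concrete, here is a minimal sketch with a hypothetical `toy_reader` (not part of the tutorial code), composed with the same `shuffle` and `batched` decorators used in the training call below:

```python
import paddle.v2 as paddle

def toy_reader():                  # a reader: a function that yields one instance at a time
    for i in range(10):
        yield [float(i)], i        # (feature vector, label)

# Compose the decorators the same way the trainer input is built below.
batched_reader = paddle.reader.batched(
    paddle.reader.shuffle(toy_reader, buf_size=4),
    batch_size=3)

for minibatch in batched_reader():
    print minibatch                # a list of up to 3 shuffled (feature, label) instances
```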
The following code defines an event handler and starts the training:

```python
    lists = []

    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print "Pass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics)
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=paddle.reader.batched(
                paddle.dataset.mnist.test(), batch_size=128))
            print "Test with Pass %d, Cost %f, %s\n" % (
                event.pass_id, result.cost, result.metrics)
            lists.append((event.pass_id, result.cost,
                          result.metrics['classification_error_evaluator']))

    trainer.train(
        reader=paddle.reader.batched(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=8192),
            batch_size=128),
        event_handler=event_handler,
        num_passes=100)
```

### Results of Multilayer Perceptron

<p align="center">
<img src="image/mlp_train_log_en.png" width="400px"><br/>
Fig. 8. Multilayer perceptron error curve<br/>
</p>

Evaluation results of the model:

```text
Best pass is 00085, testing Avgcost is 0.164746
The classification accuracy is 94.95%
```

From the evaluation results, the final training accuracy is 94.95%, which is significantly better than the softmax regression model. This is because softmax regression is too simple to fit complex data, whereas the multilayer perceptron, with its hidden layers, has a greater capacity to fit complex data.
### Training Results for Convolutional Neural Network

<p align="center">
<img src="image/cnn_train_log_en.png" width="400px"><br/>
Fig. 9. Convolutional neural network error curve<br/>
</p>

Evaluation results of the model:

```text
Best pass is 00076, testing Avgcost is 0.0244684
The classification accuracy is 99.20%
```

From the evaluation results, the best accuracy of the convolutional neural network is 99.20%. For image classification, a convolutional neural network therefore achieves better recognition results than a fully-connected network, which is related to the local connectivity and parameter sharing of convolutional layers. As Fig. 9 shows, the convolutional neural network achieves good results very early in training, which indicates that it converges faster.
## Application Model
### Prediction Commands and Results
The script `predict.py` makes predictions with trained models. For example, for the softmax regression model:

```bash
python predict.py -c mnist_model.py -d data/raw_data/ -m softmax_mnist_model/pass-00047
```

- `-c` sets the model architecture
- `-d` sets the data used for prediction
- `-m` sets the model parameters; here the best trained model is used for prediction

Follow the instructions to input an image ID for prediction. The classifier outputs the probability of each digit, the prediction with the highest probability, and the ground-truth label.
```
Input image_id [0~9999]: 3
Predicted probability of each digit:
[[  1.00000000e+00   1.60381094e-28   1.60381094e-28   1.60381094e-28
    1.60381094e-28   1.60381094e-28   1.60381094e-28   1.60381094e-28
    1.60381094e-28   1.60381094e-28]]
Predict Number: 0
Actual Number: 0
```

From this result, the classifier recognizes the digit on the third image as digit 0 with a probability close to 100%, which is consistent with the ground-truth label.

Returning to the `paddle.v2` program: during training, `trainer.train` invokes `event_handler` for certain events. This gives us a chance to print the training progress; the printed log looks like this:

```
# Pass 0, Batch 0, Cost 2.780790, {'classification_error_evaluator': 0.9453125}
# Pass 0, Batch 100, Cost 0.635356, {'classification_error_evaluator': 0.2109375}
# Pass 0, Batch 200, Cost 0.326094, {'classification_error_evaluator': 0.1328125}
# Pass 0, Batch 300, Cost 0.361920, {'classification_error_evaluator': 0.1015625}
# Pass 0, Batch 400, Cost 0.410101, {'classification_error_evaluator': 0.125}
# Test with Pass 0, Cost 0.326659, {'classification_error_evaluator': 0.09470000118017197}
```

After the training, we can check the model's prediction accuracy by selecting the pass with the lowest test cost:

```python
    # find the best pass
    best = sorted(lists, key=lambda list: float(list[1]))[0]
    print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
    print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)
```

Usually, with MNIST data, the softmax regression model reaches an accuracy of around 92.34%, the MLP about 97.66%, and the convolutional network up to around 99.20%. Convolutional layers are widely considered a great invention for image processing.
## Conclusion
This tutorial describes a few basic deep learning models: softmax regression, the multilayer perceptron, and the convolutional neural network. Subsequent tutorials will derive more sophisticated models from these, so it is crucial to understand them before moving on. As our model evolved from simple softmax regression to a slightly more complex convolutional neural network, the recognition accuracy on the MNIST dataset improved substantially, thanks to the convolutional layers' local connectivity and parameter sharing. While learning new models in the future, we encourage readers to understand the key ideas that let a new model improve on an old one. Moreover, this tutorial introduced the basic workflow of PaddlePaddle model design, from the data provider through model-layer construction to training and prediction. Readers can follow the workflow used in this MNIST handwritten digit classification example and experiment with different data and network architectures to train models for classification tasks of their own choosing.
......
......@@ -229,7 +229,11 @@ def main():
update_equation=optimizer)
```
Next, we start the training process. `paddle.dataset.mnist.train()` and `paddle.dataset.mnist.test()` serve as the training and test datasets respectively, and each training batch uses 128 instances. These two functions each return a reader: in PaddlePaddle, a reader is a Python function that returns a Python yield generator each time it is called.

`shuffle` below is a reader decorator: it takes a reader A and returns another reader B. Each time, B reads `buffer_size` training instances from A into a buffer, shuffles their order, and yields them one by one.

`batch` is a special decorator. Its input is a reader and its output is a batched reader: in PaddlePaddle a reader yields one training instance at a time, while a batched reader yields a minibatch at a time.
```python
lists = []
......@@ -248,7 +252,7 @@ def main():
result.metrics['classification_error_evaluator']))
trainer.train(
reader=paddle.reader.batched(
reader=paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=128),
......@@ -258,7 +262,7 @@ def main():
The training process is fully automatic. The log printed inside `event_handler` looks like the following:
```
# Pass 0, Batch 0, Cost 2.780790, {'classification_error_evaluator': 0.9453125}
# Pass 0, Batch 100, Cost 0.635356, {'classification_error_evaluator': 0.2109375}
# Pass 0, Batch 200, Cost 0.326094, {'classification_error_evaluator': 0.1328125}
......@@ -267,34 +271,7 @@ def main():
# Test with Pass 0, Cost 0.326659, {'classification_error_evaluator': 0.09470000118017197}
```
Finally, we select the best model and evaluate its performance.
```python
# find the best pass
best = sorted(lists, key=lambda list: float(list[1]))[0]
print 'Best pass is %s, testing Avgcost is %s' % (best[0], best[1])
print 'The classification accuracy is %.2f%%' % (100 - float(best[2]) * 100)
```
- Softmax regression model: the best classification result is reached at pass-34, with a classification accuracy of 92.34%.
```python
# Best pass is 34, testing Avgcost is 0.275004139346
# The classification accuracy is 92.34%
```
- Multilayer perceptron: the final training accuracy is 97.66%, a significant improvement over the softmax regression model. The reason is that the softmax regression model is relatively simple and cannot fit more complex data, whereas the multilayer perceptron, with its added hidden layers, has a much stronger fitting capacity.
```python
# Best pass is 85, testing Avgcost is 0.0784368447196
# The classification accuracy is 97.66%
```
- Convolutional neural network: the best classification accuracy reaches an impressive 99.20%. This shows that for image problems a convolutional neural network achieves better recognition results than an ordinary fully-connected network, which is inseparable from the local connectivity and weight sharing of convolutional layers. Moreover, the training log shows that the convolutional neural network reaches good results very early, indicating that it converges very quickly.
```python
# Best pass is 76, testing Avgcost is 0.0244684
# The classification accuracy is 99.20%
```
After training, we check the model's prediction accuracy. When training on MNIST, the softmax regression model typically reaches a classification accuracy of about 92.34%, the multilayer perceptron about 97.66%, and the convolutional neural network up to 99.20%.
## Conclusion
......