未验证 提交 742ea426 编写于 作者: D daminglu 提交者: GitHub

Re-write Chapter 1 in Book to use new Fluid API (#524)

上级 3cac9f3f
......@@ -54,8 +54,9 @@ After setting up our model, there are several major steps to go through to train
Our program starts with importing necessary packages:
```python
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing
import paddle
import paddle.fluid as fluid
import numpy
```
We encapsulated the [UCI Housing Data Set](https://archive.ics.uci.edu/ml/datasets/Housing) in our Python module `uci_housing`. This module can
......@@ -116,49 +117,58 @@ When training complex models, we usually have one more split: the validation set
`fit_a_line/trainer.py` demonstrates the training using [PaddlePaddle](http://paddlepaddle.org).
### Initialize PaddlePaddle
### Datafeeder Configuration
```python
paddle.init(use_gpu=False, trainer_count=1)
```
We first define data feeders for test and train. The feeder reads a `BATCH_SIZE` of data each time and feed them to the training/testing process. Users can shuffle a batch out of a `buf_size` in order to make the data random.
### Model Configuration
```python
BATCH_SIZE = 20
Linear regression is essentially a fully-connected layer with linear activation:
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=BATCH_SIZE)
```python
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x,
size=1,
act=paddle.activation.Linear())
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.square_error_cost(input=y_predict, label=y)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=BATCH_SIZE)
```
### Save Topology
### Train Program Configuration
The train_program must return the avg_loss as its first returned parameter and then use the inference_program to setup the train_program
```python
# Save the inference topology to protobuf.
inference_topology = paddle.topology.Topology(layers=y_predict)
with open("inference_topology.pkl", 'wb') as f:
inference_topology.serialize_for_inference(f)
def train_program():
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# feature vector of length 13
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
loss = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(loss)
return avg_loss
```
### Create Parameters
### Specify Place
Specify your training environment, you should specify if the training is on CPU or GPU.
```python
parameters = paddle.parameters.create(cost)
use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
```
### Create Trainer
The trainer will take the train_program.
```python
optimizer = paddle.optimizer.Momentum(momentum=0)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer)
trainer = fluid.Trainer(
train_func=train_program,
place=place,
optimizer=fluid.optimizer.SGD(learning_rate=0.001))
```
### Feeding Data
......@@ -168,105 +178,92 @@ PaddlePaddle provides the
for loading the training data. A reader may return multiple columns, and we need a Python dictionary to specify the mapping from column index to data layers.
```python
feeding={'x': 0, 'y': 1}
feed_order=['x', 'y']
```
Moreover, an event handler is provided to print the training progress:
```python
# event_handler to print training and testing info
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost)
if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.batch(
uci_housing.test(), batch_size=2),
feeding=feeding)
print "Test %d, Cost %f" % (event.pass_id, result.cost)
```
# Specify the directory path to save the parameters
params_folder = "fit_a_line.inference.model"
```python
# event_handler to plot training and testing info
# Plot data
from paddle.v2.plot import Ploter
train_title = "Train cost"
test_title = "Test cost"
plot_cost = Ploter(train_title, test_title)
step = 0
def event_handler_plot(event):
# event_handler to print training and testing info
def event_handler(event):
global step
if isinstance(event, paddle.event.EndIteration):
if step % 10 == 0: # every 10 batches, record a train cost
plot_cost.append(train_title, step, event.cost)
if isinstance(event, fluid.EndStepEvent):
if step % 100 == 0: # every 100 batches, record a test cost
result = trainer.test(
reader=paddle.batch(
uci_housing.test(), batch_size=2),
feeding=feeding)
plot_cost.append(test_title, step, result.cost)
test_metrics = trainer.test(
reader=test_reader, feed_order=feed_order)
if step % 100 == 0: # every 100 batches, update cost plot
print(test_metrics[0])
plot_cost.append(test_title, step, test_metrics[0])
plot_cost.plot()
step += 1
if test_metrics[0] < 10.0:
# If the accuracy is good enough, we can stop the training.
print('loss is less than 10.0, stop')
trainer.stop()
if step >= 2000:
# Or if it has been running for enough steps
print('has been running for 2000 steps, stop')
trainer.stop()
if isinstance(event, paddle.event.EndPass):
if event.pass_id % 10 == 0:
with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
trainer.save_parameter_to_tar(f)
# We can save the trained parameters for the inferences later
if params_folder is not None:
trainer.save_params(params_folder)
step += 1
```
### Start Training
```python
%matplotlib inline
# The training could take up to a few minutes.
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
uci_housing.train(), buf_size=500),
batch_size=2),
feeding=feeding,
event_handler=event_handler_plot,
num_passes=30)
reader=train_reader,
num_epochs=100,
event_handler=event_handler,
feed_order=feed_order)
```
![png](./image/train_and_test.png)
### Apply model
### Inference
Initialize the Inferencer with the inference_program and the params_folder, which is where we saved our params
#### 1. generate testing data
#### Setup the Inference Program.
Similar to the trainer.train, the Inferencer needs to take an inference_program to do inferring.
Prune the train_program to only have the y_predict.
```python
test_data_creator = paddle.dataset.uci_housing.test()
test_data = []
test_label = []
for item in test_data_creator():
test_data.append((item[0],))
test_label.append(item[1])
if len(test_data) == 5:
break
def inference_program():
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return y_predict
```
#### 2. inference
```python
# load parameters from tar file.
# users can remove the comments and change the model name
# with open('params_pass_20.tar', 'r') as f:
# parameters = paddle.parameters.Parameters.from_tar(f)
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path=params_folder, place=place)
probs = paddle.infer(
output_layer=y_predict, parameters=parameters, input=test_data)
batch_size = 10
tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")
for i in xrange(len(probs)):
print "label=" + str(test_label[i][0]) + ", predict=" + str(probs[i][0])
results = inferencer.infer({'x': tensor_x})
print("infer results: ", results[0])
```
## Summary
......
......@@ -96,8 +96,9 @@ After setting up our model, there are several major steps to go through to train
Our program starts with importing necessary packages:
```python
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing
import paddle
import paddle.fluid as fluid
import numpy
```
We encapsulated the [UCI Housing Data Set](https://archive.ics.uci.edu/ml/datasets/Housing) in our Python module `uci_housing`. This module can
......@@ -158,49 +159,58 @@ When training complex models, we usually have one more split: the validation set
`fit_a_line/trainer.py` demonstrates the training using [PaddlePaddle](http://paddlepaddle.org).
### Initialize PaddlePaddle
### Datafeeder Configuration
```python
paddle.init(use_gpu=False, trainer_count=1)
```
We first define data feeders for test and train. The feeder reads a `BATCH_SIZE` of data each time and feed them to the training/testing process. Users can shuffle a batch out of a `buf_size` in order to make the data random.
### Model Configuration
```python
BATCH_SIZE = 20
Linear regression is essentially a fully-connected layer with linear activation:
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=BATCH_SIZE)
```python
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x,
size=1,
act=paddle.activation.Linear())
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.square_error_cost(input=y_predict, label=y)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=BATCH_SIZE)
```
### Save Topology
### Train Program Configuration
The train_program must return the avg_loss as its first returned parameter and then use the inference_program to setup the train_program
```python
# Save the inference topology to protobuf.
inference_topology = paddle.topology.Topology(layers=y_predict)
with open("inference_topology.pkl", 'wb') as f:
inference_topology.serialize_for_inference(f)
def train_program():
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# feature vector of length 13
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
loss = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(loss)
return avg_loss
```
### Create Parameters
### Specify Place
Specify your training environment, you should specify if the training is on CPU or GPU.
```python
parameters = paddle.parameters.create(cost)
use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
```
### Create Trainer
The trainer will take the train_program.
```python
optimizer = paddle.optimizer.Momentum(momentum=0)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer)
trainer = fluid.Trainer(
train_func=train_program,
place=place,
optimizer=fluid.optimizer.SGD(learning_rate=0.001))
```
### Feeding Data
......@@ -210,105 +220,92 @@ PaddlePaddle provides the
for loading the training data. A reader may return multiple columns, and we need a Python dictionary to specify the mapping from column index to data layers.
```python
feeding={'x': 0, 'y': 1}
feed_order=['x', 'y']
```
Moreover, an event handler is provided to print the training progress:
```python
# event_handler to print training and testing info
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost)
if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.batch(
uci_housing.test(), batch_size=2),
feeding=feeding)
print "Test %d, Cost %f" % (event.pass_id, result.cost)
```
# Specify the directory path to save the parameters
params_folder = "fit_a_line.inference.model"
```python
# event_handler to plot training and testing info
# Plot data
from paddle.v2.plot import Ploter
train_title = "Train cost"
test_title = "Test cost"
plot_cost = Ploter(train_title, test_title)
step = 0
def event_handler_plot(event):
# event_handler to print training and testing info
def event_handler(event):
global step
if isinstance(event, paddle.event.EndIteration):
if step % 10 == 0: # every 10 batches, record a train cost
plot_cost.append(train_title, step, event.cost)
if isinstance(event, fluid.EndStepEvent):
if step % 100 == 0: # every 100 batches, record a test cost
result = trainer.test(
reader=paddle.batch(
uci_housing.test(), batch_size=2),
feeding=feeding)
plot_cost.append(test_title, step, result.cost)
test_metrics = trainer.test(
reader=test_reader, feed_order=feed_order)
if step % 100 == 0: # every 100 batches, update cost plot
print(test_metrics[0])
plot_cost.append(test_title, step, test_metrics[0])
plot_cost.plot()
step += 1
if test_metrics[0] < 10.0:
# If the accuracy is good enough, we can stop the training.
print('loss is less than 10.0, stop')
trainer.stop()
if step >= 2000:
# Or if it has been running for enough steps
print('has been running for 2000 steps, stop')
trainer.stop()
if isinstance(event, paddle.event.EndPass):
if event.pass_id % 10 == 0:
with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
trainer.save_parameter_to_tar(f)
# We can save the trained parameters for the inferences later
if params_folder is not None:
trainer.save_params(params_folder)
step += 1
```
### Start Training
```python
%matplotlib inline
# The training could take up to a few minutes.
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
uci_housing.train(), buf_size=500),
batch_size=2),
feeding=feeding,
event_handler=event_handler_plot,
num_passes=30)
reader=train_reader,
num_epochs=100,
event_handler=event_handler,
feed_order=feed_order)
```
![png](./image/train_and_test.png)
### Apply model
### Inference
Initialize the Inferencer with the inference_program and the params_folder, which is where we saved our params
#### 1. generate testing data
#### Setup the Inference Program.
Similar to the trainer.train, the Inferencer needs to take an inference_program to do inferring.
Prune the train_program to only have the y_predict.
```python
test_data_creator = paddle.dataset.uci_housing.test()
test_data = []
test_label = []
for item in test_data_creator():
test_data.append((item[0],))
test_label.append(item[1])
if len(test_data) == 5:
break
def inference_program():
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return y_predict
```
#### 2. inference
```python
# load parameters from tar file.
# users can remove the comments and change the model name
# with open('params_pass_20.tar', 'r') as f:
# parameters = paddle.parameters.Parameters.from_tar(f)
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path=params_folder, place=place)
probs = paddle.infer(
output_layer=y_predict, parameters=parameters, input=test_data)
batch_size = 10
tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")
for i in xrange(len(probs)):
print "label=" + str(test_label[i][0]) + ", predict=" + str(probs[i][0])
results = inferencer.infer({'x': tensor_x})
print("infer results: ", results[0])
```
## Summary
......
import os
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing
with_gpu = os.getenv('WITH_GPU', '0') != '0'
def main():
# init
paddle.init(use_gpu=with_gpu, trainer_count=1)
# network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear())
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.square_error_cost(input=y_predict, label=y)
# Save the inference topology to protobuf.
inference_topology = paddle.topology.Topology(layers=y_predict)
with open("inference_topology.pkl", 'wb') as f:
inference_topology.serialize_for_inference(f)
# create parameters
parameters = paddle.parameters.create(cost)
# create optimizer
optimizer = paddle.optimizer.Momentum(momentum=0)
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=optimizer)
feeding = {'x': 0, 'y': 1}
# event_handler to print training and testing info
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost)
if isinstance(event, paddle.event.EndPass):
if event.pass_id % 10 == 0:
with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
trainer.save_parameter_to_tar(f)
result = trainer.test(
reader=paddle.batch(uci_housing.test(), batch_size=2),
feeding=feeding)
print "Test %d, Cost %f" % (event.pass_id, result.cost)
# training
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(uci_housing.train(), buf_size=500),
batch_size=2),
feeding=feeding,
event_handler=event_handler,
num_passes=30)
# inference
test_data_creator = paddle.dataset.uci_housing.test()
test_data = []
test_label = []
for item in test_data_creator():
test_data.append((item[0], ))
test_label.append(item[1])
if len(test_data) == 5:
break
# load parameters from tar file.
# users can remove the comments and change the model name
# with open('params_pass_20.tar', 'r') as f:
# parameters = paddle.parameters.Parameters.from_tar(f)
probs = paddle.infer(
output_layer=y_predict, parameters=parameters, input=test_data)
for i in xrange(len(probs)):
print "label=" + str(test_label[i][0]) + ", predict=" + str(probs[i][0])
if __name__ == '__main__':
main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy
BATCH_SIZE = 20
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=BATCH_SIZE)
def train_program():
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# feature vector of length 13
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
loss = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(loss)
return avg_loss
# can use CPU or GPU
use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
trainer = fluid.Trainer(
train_func=train_program,
place=place,
optimizer=fluid.optimizer.SGD(learning_rate=0.001))
feed_order = ['x', 'y']
# Specify the directory path to save the parameters
params_folder = "fit_a_line.inference.model"
# Plot data
from paddle.v2.plot import Ploter
train_title = "Train cost"
test_title = "Test cost"
plot_cost = Ploter(train_title, test_title)
step = 0
# event_handler to print training and testing info
def event_handler(event):
global step
if isinstance(event, fluid.EndStepEvent):
if step % 100 == 0: # every 100 batches, record a test cost
test_metrics = trainer.test(
reader=test_reader, feed_order=feed_order)
print(test_metrics[0])
plot_cost.append(test_title, step, test_metrics[0])
plot_cost.plot()
if test_metrics[0] < 10.0:
# If the accuracy is good enough, we can stop the training.
print('loss is less than 10.0, stop')
trainer.stop()
if step >= 2000:
# Or if it has been running for enough steps
print('has been running for 2000 steps, stop')
trainer.stop()
# We can save the trained parameters for the inferences later
if params_folder is not None:
trainer.save_params(params_folder)
step += 1
# The training could take up to a few minutes.
trainer.train(
reader=train_reader,
num_epochs=100,
event_handler=event_handler,
feed_order=feed_order)
def inference_program():
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return y_predict
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path=params_folder, place=place)
batch_size = 10
tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")
results = inferencer.infer({'x': tensor_x})
print("infer results: ", results[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册