PaddlePaddle / book
Commit 3a24bdbd
Authored Nov 26, 2018 by lujun

update low level api to book, ses-1, test=develop

Parent: fa35415f

Showing 3 changed files with 322 additions and 293 deletions
01.fit_a_line/README.cn.md    +106 -99
01.fit_a_line/index.cn.html   +106 -99
01.fit_a_line/train.py        +110 -95
01.fit_a_line/README.cn.md
View file @ 3a24bdbd

@@ -103,17 +103,9 @@ $$MSE=\frac{1}{n}\sum_{i=1}^{n}{(\hat{Y_i}-Y_i)}^2$$
 import paddle
 import paddle.fluid as fluid
 import numpy
+import math
+import sys
 from __future__ import print_function
-
-try:
-    from paddle.fluid.contrib.trainer import *
-    from paddle.fluid.contrib.inferencer import *
-except ImportError:
-    print(
-        "In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
-        file=sys.stderr)
-    from paddle.fluid.trainer import *
-    from paddle.fluid.inferencer import *
 ```
 
 We load the data set through the uci_housing module: [UCI Housing Data Set](https://archive.ics.uci.edu/ml/datasets/Housing)
@@ -123,7 +115,7 @@ except ImportError:
 1. Downloading the data. The downloaded data is saved to ~/.cache/paddle/dataset/uci_housing/housing.data.
 2. [Preprocessing the data](#数据预处理).
 
-Next we define the data providers used for training and testing. A provider reads one batch of `BATCH_SIZE` samples at a time. To add some randomness, the user can define both a batch size and a buffer size; the provider then draws each batch at random from the buffer.
+Next we define the data provider used for training. A provider reads one batch of `BATCH_SIZE` samples at a time. To add some randomness, the user can define both a batch size and a buffer size; the provider then draws each batch at random from the buffer.
 
 ```python
 BATCH_SIZE = 20
@@ -132,28 +124,18 @@ train_reader = paddle.batch(
     paddle.reader.shuffle(
         paddle.dataset.uci_housing.train(), buf_size=500),
     batch_size=BATCH_SIZE)
-
-test_reader = paddle.batch(
-    paddle.reader.shuffle(
-        paddle.dataset.uci_housing.test(), buf_size=500),
-    batch_size=BATCH_SIZE)
 ```
 
 ### Configure the training program
 The training program defines the network structure of the model to be trained. For linear regression this is a single fully connected layer from input to output. More complex structures, such as convolutional and recurrent neural networks, are covered in later chapters. The training program must return the `mean loss` as its first return value, because backpropagation uses it later.
 
 ```python
-def train_program():
-    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
-
-    # feature vector of length 13
-    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
-    y_predict = fluid.layers.fc(input=x, size=1, act=None)
-
-    loss = fluid.layers.square_error_cost(input=y_predict, label=y)
-    avg_loss = fluid.layers.mean(loss)
-
-    return avg_loss
+x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+y_predict = fluid.layers.fc(input=x, size=1, act=None)
+
+cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+avg_loss = fluid.layers.mean(cost)
 ```
 
 ### Configure the optimizer function
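As a reading aid for the reader code in the hunk above, here is a minimal pure-Python sketch of the buffered-shuffle-then-batch idea the prose describes. The `shuffle` and `batch` functions below are hypothetical stand-ins written for illustration; they are not the Paddle implementations, and `toy_reader`, the `seed` argument, and the buffer size of 16 are assumptions.

```python
import random


def shuffle(reader, buf_size, seed=None):
    """Hypothetical stand-in: collect buf_size samples, shuffle them, yield them."""
    rng = random.Random(seed)

    def shuffled_reader():
        buf = []
        for sample in reader():
            buf.append(sample)
            if len(buf) >= buf_size:
                rng.shuffle(buf)
                for item in buf:
                    yield item
                buf = []
        # Flush the final, possibly smaller, buffer.
        rng.shuffle(buf)
        for item in buf:
            yield item

    return shuffled_reader


def batch(reader, batch_size):
    """Hypothetical stand-in: group consecutive samples into lists of batch_size."""

    def batched_reader():
        current = []
        for sample in reader():
            current.append(sample)
            if len(current) == batch_size:
                yield current
                current = []
        if current:  # the last batch may be smaller
            yield current

    return batched_reader


def toy_reader():
    # 50 made-up (feature, label) pairs standing in for a real data set.
    for i in range(50):
        yield (float(i), 2.0 * i)


train_reader = batch(shuffle(toy_reader, buf_size=16, seed=0), batch_size=20)
batches = list(train_reader())
print([len(b) for b in batches])
```

With a buffer smaller than the data set, samples are only mixed within each buffer-sized window, so a larger buffer gives a more thorough shuffle at the cost of memory.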
@@ -161,8 +143,8 @@ def train_program():
 In the `SGD optimizer` below, `learning_rate` is the learning rate, which affects how quickly training converges.
 
 ```python
-def optimizer_program():
-    return fluid.optimizer.SGD(learning_rate=0.001)
+sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+sgd_optimizer.minimize(avg_loss)
 ```
 
 ### Define the compute place
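To make the role of the mean squared error and the learning rate concrete, the following is a small pure-Python sketch of what a linear layer, an MSE loss, and a gradient-descent update compute together. It does not use Paddle; the function names, the two-feature toy data, and the learning rate of 0.05 are assumptions chosen for illustration, and the whole toy set is treated as a single batch.

```python
def predict(w, b, x):
    # Linear model: y_hat = w . x + b (a fully connected layer with one output).
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def mse(w, b, xs, ys):
    # MSE = (1/n) * sum((y_hat_i - y_i)^2)
    return sum((predict(w, b, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def gradient_step(w, b, xs, ys, learning_rate):
    # One update: parameters move against the gradient of the MSE.
    n = len(xs)
    grad_w = [0.0] * len(w)
    grad_b = 0.0
    for x, y in zip(xs, ys):
        err = predict(w, b, x) - y
        for j, xj in enumerate(x):
            grad_w[j] += 2.0 * err * xj / n
        grad_b += 2.0 * err / n
    new_w = [wj - learning_rate * gj for wj, gj in zip(w, grad_w)]
    new_b = b - learning_rate * grad_b
    return new_w, new_b


# Toy data generated from y = 3*x0 - 2*x1 + 1 (made up for this sketch).
xs = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0], [1.0, 2.0]]
ys = [3.0 * x0 - 2.0 * x1 + 1.0 for x0, x1 in xs]

w, b = [0.0, 0.0], 0.0
initial_loss = mse(w, b, xs, ys)
for _ in range(2000):
    w, b = gradient_step(w, b, xs, ys, learning_rate=0.05)
final_loss = mse(w, b, xs, ys)
print(initial_loss, final_loss, w, b)
```

A smaller learning rate makes each update more cautious and convergence slower; one that is too large makes the loss diverge, which is the situation the NaN check in the training loop guards against.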
@@ -173,112 +155,137 @@ use_cuda = False
 place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
 ```
 
-### Create the trainer
-
-The trainer takes a training program and a few other required arguments:
+In addition, an event handler can be defined to handle events such as `printing the training progress`:
 
 ```python
-trainer = Trainer(
-    train_func=train_program,
-    place=place,
-    optimizer_func=optimizer_program)
+# Plot data
+from paddle.v2.plot import Ploter
+
+train_title = "Train cost"
+test_title = "Test cost"
+plot_cost = Ploter(train_title, test_title)
+
+def event_handler(title, loop_step, handler_val):
+    plot_cost.append(title, loop_step, handler_val)
+    plot_cost.plot()
 ```
 
-### Start feeding data
+### Create the training process
 
-PaddlePaddle provides a reader (generator) mechanism for loading training data. A reader yields several columns of data at a time, so we need a Python list to define the order in which they are read.
+Training needs a training program and a few required parameters; we also build a function that computes the test error during training.
 
 ```python
-feed_order=['x', 'y']
+exe = fluid.Executor(place)
+num_epochs = 100
+
+# For training test cost
+def train_test(train_program, feeder):
+    exe_test = fluid.Executor(place)
+    accumulated = 1 * [0]
+    count = 0
+    test_program = train_program.clone(for_test=True)
+    for data_test in test_reader():
+        outs = exe_test.run(program=test_program,
+                            feed=feeder.feed(data_test),
+                            fetch_list=[avg_loss])
+        accumulated = [x_c[0] + x_c[1][0] for x_c in zip(accumulated, outs)]
+        count += 1
+    return [x_d / count for x_d in accumulated]
 ```
 
-In addition, an event handler can be defined to handle events such as `printing the training progress`:
+### Main training loop
+PaddlePaddle provides a reader (generator) mechanism for loading training data. A reader yields several columns of data at a time, so we need a Python list to define the order in which they are read. We build a loop that trains until the result is good enough or the number of iterations is large enough.
+If training goes well, the trained parameters can be saved to `params_dirname`.
 
 ```python
 # Specify the directory to save the parameters
 params_dirname = "fit_a_line.inference.model"
 
-train_title = "Train cost"
-test_title = "Test cost"
-
-step = 0
-
-# event_handler prints training and testing info
-def event_handler(event):
-    global step
-    if isinstance(event, EndStepEvent):
-        if step % 10 == 0:   # record a train cost every 10 batches
-            print("%s, Step %d, Cost %f" % (train_title, step, event.metrics[0]))
-
-        if step % 100 == 0:  # record a test cost every 100 batches
-            test_metrics = trainer.test(
-                reader=test_reader, feed_order=feed_order)
-            print("%s, Step %d, Cost %f" % (test_title, step, test_metrics[0]))
-            if test_metrics[0] < 10.0:
-                # If the accuracy is good enough, we can stop the training.
-                print('loss is less than 10.0, stop')
-                trainer.stop()
-        step += 1
-
-    if isinstance(event, EndEpochEvent):
-        if event.epoch % 10 == 0:
-            # We can save the trained parameters for the inferences later
-            if params_dirname is not None:
-                trainer.save_params(params_dirname)
+# main train loop.
+def train_loop(main_program):
+    feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+    feeder_test = fluid.DataFeeder(place=place, feed_list=[x, y])
+    exe.run(fluid.default_startup_program())
+
+    step = 0
+    for pass_id in range(num_epochs):
+        for data_train in train_reader():
+            avg_loss_value, = exe.run(main_program,
+                                      feed=feeder.feed(data_train),
+                                      fetch_list=[avg_loss])
+            if step % 10 == 0:  # record a train cost every 10 batches
+                event_handler(train_title, step, avg_loss_value[0])
+            if step % 100 == 0:  # record a test cost every 100 batches
+                test_metics = train_test(train_program=main_program,
+                                         feeder=feeder_test)
+                event_handler(test_title, step, test_metics[0])
+                # If the accuracy is good enough, we can stop the training.
+                if test_metics[0] < 10.0:
+                    return
+
+            step += 1
+
+            if math.isnan(float(avg_loss_value)):
+                sys.exit("got NaN loss, training failed.")
+        if params_dirname is not None:
+            # We can save the trained parameters for the inferences later
+            fluid.io.save_inference_model(params_dirname, ['x'],
+                                          [y_predict], exe)
 ```
 
 ### Start training
-We can now start training by calling `trainer.train()`.
 
 ```python
 %matplotlib inline
 # The training could take up to a few minutes.
-trainer.train(
-    reader=train_reader,
-    num_epochs=100,
-    event_handler=event_handler,
-    feed_order=feed_order)
+train_loop(fluid.default_main_program())
 ```
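 
 ## Inference
-Provide an `inference_program` and a `params_dirname` to initialize the inferencer. `params_dirname` is where our parameters are stored.
+We need to build a program that uses the trained parameters for inference; the trained parameters are located in `params_dirname`.
 
-### Set up the inference program
-Similar to `trainer.train`, the inferencer needs an inference program to make predictions. We can slightly modify our training program to include the predicted value.
+### Prepare the inference environment
+Similar to the training process, the inferencer needs an inference program to make predictions. We can slightly modify our training program to include the predicted value.
 
 ```python
-def inference_program():
-    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
-    y_predict = fluid.layers.fc(input=x, size=1, act=None)
-    return y_predict
+infer_exe = fluid.Executor(place)
+inference_scope = fluid.core.Scope()
 ```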
 
 ### Infer
-The inferencer reads the trained model from `params_dirname` and makes predictions on data it has never seen.
+Through fluid.io.load_inference_model, the inferencer reads the trained model from `params_dirname` and makes predictions on data it has never seen.
 
 ```python
-inferencer = Inferencer(
-    infer_func=inference_program, param_path=params_dirname, place=place)
-
-batch_size = 10
-test_reader = paddle.batch(paddle.dataset.uci_housing.test(),batch_size=batch_size)
-test_data = next(test_reader())
-test_x = numpy.array([data[0] for data in test_data]).astype("float32")
-test_y = numpy.array([data[1] for data in test_data]).astype("float32")
-
-results = inferencer.infer({'x': test_x})
-
-print("infer results: (House Price)")
-for idx, val in enumerate(results[0]):
-    print("%d: %.2f" % (idx, val))
-
-print("\nground truth:")
-for idx, val in enumerate(test_y):
-    print("%d: %.2f" % (idx, val))
+with fluid.scope_guard(inference_scope):
+    [inference_program, feed_target_names,
+     fetch_targets] = fluid.io.load_inference_model(params_dirname, exe)
+    batch_size = 10
+
+    infer_reader = paddle.batch(
+        paddle.dataset.uci_housing.test(), batch_size=batch_size)
+
+    infer_data = next(infer_reader())
+    infer_feat = numpy.array(
+        [data[0] for data in infer_data]).astype("float32")
+    infer_label = numpy.array(
+        [data[1] for data in infer_data]).astype("float32")
+
+    assert feed_target_names[0] == 'x'
+    results = infer_exe.run(inference_program,
+                            feed={feed_target_names[0]: numpy.array(infer_feat)},
+                            fetch_list=fetch_targets)
+
+    print("infer results: (House Price)")
+    for idx, val in enumerate(results[0]):
+        print("%d: %.2f" % (idx, val))
+
+    print("\nground truth:")
+    for idx, val in enumerate(infer_label):
+        print("%d: %.2f" % (idx, val))
 ```
 
 ## Summary
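The accumulation idiom inside `train_test` is compact enough to be easy to misread, so here is a minimal pure-Python sketch of the same arithmetic. The helper name `average_metrics` and the nested lists standing in for per-batch fetch results are assumptions; no Paddle executor is involved.

```python
def average_metrics(batch_outputs, num_metrics=1):
    """Average each fetched metric over all batches, as train_test does."""
    accumulated = num_metrics * [0]
    count = 0
    for outs in batch_outputs:
        # outs holds one array-like per fetched metric; [0] takes its scalar.
        accumulated = [x_c[0] + x_c[1][0] for x_c in zip(accumulated, outs)]
        count += 1
    return [x_d / count for x_d in accumulated]


# Three made-up batches, each returning one fetched value (the mean loss).
fake_outputs = [[[4.0]], [[2.0]], [[3.0]]]
result = average_metrics(fake_outputs)
print(result)  # [3.0]
```

Because this is a mean of per-batch means, every batch is weighted equally; when the last batch is smaller than the others, the result differs slightly from the mean over all individual samples.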
01.fit_a_line/index.cn.html
View file @ 3a24bdbd

@@ -145,17 +145,9 @@ $$MSE=\frac{1}{n}\sum_{i=1}^{n}{(\hat{Y_i}-Y_i)}^2$$
@@ -165,7 +157,7 @@ except ImportError:
@@ -174,28 +166,18 @@ train_reader = paddle.batch(
@@ -203,8 +185,8 @@ def train_program():
@@ -215,112 +197,137 @@ use_cuda = False

These five hunks repeat, line for line, the changes shown for 01.fit_a_line/README.cn.md (same +106 -99); only the line offsets differ, since this file carries the same Markdown text inside an HTML page.
01.fit_a_line/train.py
View file @ 3a24bdbd

@@ -13,122 +13,137 @@
 # limitations under the License.
 
 from __future__ import print_function
 
 import paddle
 import paddle.fluid as fluid
-import sys
-
-try:
-    from paddle.fluid.contrib.trainer import *
-    from paddle.fluid.contrib.inferencer import *
-except ImportError:
-    print(
-        "In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
-        file=sys.stderr)
-    from paddle.fluid.trainer import *
-    from paddle.fluid.inferencer import *
-
 import numpy
+import math
+import sys
 
-BATCH_SIZE = 20
-
-train_reader = paddle.batch(
-    paddle.reader.shuffle(paddle.dataset.uci_housing.train(), buf_size=500),
-    batch_size=BATCH_SIZE)
-
-test_reader = paddle.batch(
-    paddle.reader.shuffle(paddle.dataset.uci_housing.test(), buf_size=500),
-    batch_size=BATCH_SIZE)
-
-
-def train_program():
-    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+
+# event_handler prints training and testing info
+def event_handler(title, loop_step, handler_val):
+    print("%s, Step %d, Cost %f" % (title, loop_step, handler_val))
+
+
+def main():
+    batch_size = 20
+    train_reader = paddle.batch(
+        paddle.reader.shuffle(paddle.dataset.uci_housing.train(), buf_size=500),
+        batch_size=batch_size)
+    test_reader = paddle.batch(
+        paddle.reader.shuffle(paddle.dataset.uci_housing.test(), buf_size=500),
+        batch_size=batch_size)
 
     # feature vector of length 13
     x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
     y_predict = fluid.layers.fc(input=x, size=1, act=None)
 
-    loss = fluid.layers.square_error_cost(input=y_predict, label=y)
-    avg_loss = fluid.layers.mean(loss)
-
-    return avg_loss
-
-
-def optimizer_program():
-    return fluid.optimizer.SGD(learning_rate=0.001)
-
-
-# can use CPU or GPU
-use_cuda = False
-place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-
-trainer = Trainer(
-    train_func=train_program, place=place, optimizer_func=optimizer_program)
-
-feed_order = ['x', 'y']
-
-# Specify the directory to save the parameters
-params_dirname = "fit_a_line.inference.model"
-
-train_title = "Train cost"
-test_title = "Test cost"
-
-step = 0
-
-
-# event_handler prints training and testing info
-def event_handler(event):
-    global step
-    if isinstance(event, EndStepEvent):
-        if step % 10 == 0:  # record a train cost every 10 batches
-            print("%s, Step %d, Cost %f" %
-                  (train_title, step, event.metrics[0]))
-        if step % 100 == 0:  # record a test cost every 100 batches
-            test_metrics = trainer.test(
-                reader=test_reader, feed_order=feed_order)
-            print("%s, Step %d, Cost %f" %
-                  (test_title, step, test_metrics[0]))
-            if test_metrics[0] < 10.0:
-                # If the accuracy is good enough, we can stop the training.
-                print('loss is less than 10.0, stop')
-                trainer.stop()
-        step += 1
-
-    if isinstance(event, EndEpochEvent):
-        if event.epoch % 10 == 0:
-            # We can save the trained parameters for the inferences later
-            if params_dirname is not None:
-                trainer.save_params(params_dirname)
-
-
-# The training could take up to a few minutes.
-trainer.train(
-    reader=train_reader,
-    num_epochs=100,
-    event_handler=event_handler,
-    feed_order=feed_order)
-
-
-def inference_program():
-    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
-    y_predict = fluid.layers.fc(input=x, size=1, act=None)
-    return y_predict
-
-
-inferencer = Inferencer(
-    infer_func=inference_program, param_path=params_dirname, place=place)
-
-batch_size = 10
-test_reader = paddle.batch(
-    paddle.dataset.uci_housing.test(), batch_size=batch_size)
-test_data = next(test_reader())
-test_x = numpy.array([data[0] for data in test_data]).astype("float32")
-test_y = numpy.array([data[1] for data in test_data]).astype("float32")
-
-results = inferencer.infer({'x': test_x})
-
-print("infer results: (House Price)")
-for idx, val in enumerate(results[0]):
-    print("%d: %.2f" % (idx, val))
-
-print("\nground truth:")
-for idx, val in enumerate(test_y):
-    print("%d: %.2f" % (idx, val))
+    main_program = fluid.default_main_program()
+    star_program = fluid.default_startup_program()
+
+    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+    avg_loss = fluid.layers.mean(cost)
+
+    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+    sgd_optimizer.minimize(avg_loss)
+
+    test_program = main_program.clone(for_test=True)
+
+    # can use CPU or GPU
+    use_cuda = False
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    # Specify the directory to save the parameters
+    params_dirname = "fit_a_line.inference.model"
+    num_epochs = 100
+
+    # For training test cost
+    def train_test(program, feeder):
+        exe_test = fluid.Executor(place)
+        accumulated = 1 * [0]
+        count = 0
+        for data_test in test_reader():
+            outs = exe_test.run(
+                program=program,
+                feed=feeder.feed(data_test),
+                fetch_list=[avg_loss])
+            accumulated = [
+                x_c[0] + x_c[1][0] for x_c in zip(accumulated, outs)
+            ]
+            count += 1
+        return [x_d / count for x_d in accumulated]
+
+    # main train loop.
+    def train_loop():
+        feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+        feeder_test = fluid.DataFeeder(place=place, feed_list=[x, y])
+        exe.run(star_program)
+
+        train_title = "Train cost"
+        test_title = "Test cost"
+
+        step = 0
+        for pass_id in range(num_epochs):
+            for data_train in train_reader():
+                avg_loss_value, = exe.run(
+                    main_program,
+                    feed=feeder.feed(data_train),
+                    fetch_list=[avg_loss])
+                if step % 10 == 0:  # record a train cost every 10 batches
+                    event_handler(train_title, step, avg_loss_value[0])
+
+                if step % 100 == 0:  # record a test cost every 100 batches
+                    test_metics = train_test(
+                        program=test_program, feeder=feeder_test)
+                    event_handler(test_title, step, test_metics[0])
+                    # If the accuracy is good enough, we can stop the training.
+                    if test_metics[0] < 10.0:
+                        return
+
+                step += 1
+
+                if math.isnan(float(avg_loss_value)):
+                    sys.exit("got NaN loss, training failed.")
+            if params_dirname is not None:
+                # We can save the trained parameters for the inferences later
+                fluid.io.save_inference_model(params_dirname, ['x'],
+                                              [y_predict], exe)
+
+    train_loop()
+
+    infer_exe = fluid.Executor(place)
+    inference_scope = fluid.core.Scope()
+
+    # infer
+    with fluid.scope_guard(inference_scope):
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(params_dirname, exe)
+        batch_size = 10
+
+        infer_reader = paddle.batch(
+            paddle.dataset.uci_housing.test(), batch_size=batch_size)
+
+        infer_data = next(infer_reader())
+        infer_feat = numpy.array(
+            [data[0] for data in infer_data]).astype("float32")
+        infer_label = numpy.array(
+            [data[1] for data in infer_data]).astype("float32")
+
+        assert feed_target_names[0] == 'x'
+        results = infer_exe.run(
+            inference_program,
+            feed={feed_target_names[0]: numpy.array(infer_feat)},
+            fetch_list=fetch_targets)
+
+        print("infer results: (House Price)")
+        for idx, val in enumerate(results[0]):
+            print("%d: %.2f" % (idx, val))
+
+        print("\nground truth:")
+        for idx, val in enumerate(infer_label):
+            print("%d: %.2f" % (idx, val))
+
+
+if __name__ == '__main__':
+    main()
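The control flow of the new `train_loop` (log every 10 steps, evaluate every 100 steps, stop early once the test cost drops below 10.0, abort on a NaN loss) can be exercised without Paddle. The sketch below is an assumption-laden stand-in: `run_batch`, `evaluate`, and `log` are hypothetical callables, the loss curve `500 / (step + 1)` and the 21 batches per epoch are made up, and it raises `ValueError` where train.py calls `sys.exit`.

```python
import math


def train_loop(batches_per_epoch, num_epochs, run_batch, evaluate, log,
               stop_below=10.0):
    """Mirror the step bookkeeping of train.py's loop with pluggable callables."""
    step = 0
    for epoch in range(num_epochs):
        for _ in range(batches_per_epoch):
            loss = run_batch(step)
            if step % 10 == 0:  # record a train cost every 10 batches
                log("Train cost", step, loss)
            if step % 100 == 0:  # record a test cost every 100 batches
                test_loss = evaluate(step)
                log("Test cost", step, test_loss)
                if test_loss < stop_below:
                    return "converged", step
            step += 1
            if math.isnan(float(loss)):
                raise ValueError("got NaN loss, training failed.")
    return "finished", step


records = []
status, last_step = train_loop(
    batches_per_epoch=21,
    num_epochs=100,
    run_batch=lambda step: 500.0 / (step + 1),   # made-up decreasing loss
    evaluate=lambda step: 500.0 / (step + 1),    # made-up test cost
    log=lambda title, step, value: records.append((title, step, value)),
)
print(status, last_step, len(records))
```

Since the test cost is only checked on steps divisible by 100, the loop can run up to 99 steps past the point where the threshold was first met; that is the trade-off for evaluating rarely.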