提交 3a24bdbd 编写于 作者: L lujun

update low level api to book, ses-1, test=develop

上级 fa35415f
......@@ -103,17 +103,9 @@ $$MSE=\frac{1}{n}\sum_{i=1}^{n}{(\hat{Y_i}-Y_i)}^2$$
import paddle
import paddle.fluid as fluid
import numpy
import math
import sys
from __future__ import print_function
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
```
我们通过uci_housing模块引入了数据集合[UCI Housing Data Set](https://archive.ics.uci.edu/ml/datasets/Housing)
......@@ -123,7 +115,7 @@ except ImportError:
1. 数据下载的过程。下载数据保存在~/.cache/paddle/dataset/uci_housing/housing.data。
2. [数据预处理](#数据预处理)的过程。
接下来我们定义了用于训练和测试的数据提供器。提供器每次读入一个大小为`BATCH_SIZE`的数据批次。如果用户希望加一些随机性,她可以同时定义一个批次大小和一个缓存大小。这样的话,每次数据提供器会从缓存中随机读取批次大小那么多的数据。
接下来我们定义了用于训练的数据提供器。提供器每次读入一个大小为`BATCH_SIZE`的数据批次。如果用户希望加一些随机性,她可以同时定义一个批次大小和一个缓存大小。这样的话,每次数据提供器会从缓存中随机读取批次大小那么多的数据。
```python
BATCH_SIZE = 20
......@@ -132,28 +124,18 @@ train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=BATCH_SIZE)
```
### 配置训练程序
训练程序的目的是定义一个训练模型的网络结构。对于线性回归来讲,它就是一个从输入到输出的简单的全连接层。更加复杂的结果,比如卷积神经网络,递归神经网络等会在随后的章节中介绍。训练程序必须返回`平均损失`作为第一个返回值,因为它会被后面反向传播算法所用到。
```python
def train_program():
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# feature vector of length 13
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
loss = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(loss)
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return avg_loss
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(cost)
```
### Optimizer Function 配置
......@@ -161,8 +143,8 @@ def train_program():
在下面的 `SGD optimizer``learning_rate` 是训练的速度,与网络的训练收敛速度有关系。
```python
def optimizer_program():
return fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_loss)
```
### 定义运算场所
......@@ -173,112 +155,137 @@ use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
```
### 创建训练器
训练器会读入一个训练程序和一些必要的其他参数:
除此之外,还可以定义一个事件响应器来处理类似`打印训练进程`的事件:
```python
trainer = Trainer(
train_func=train_program,
place=place,
optimizer_func=optimizer_program)
# Plot data
from paddle.v2.plot import Ploter
train_title = "Train cost"
test_title = "Test cost"
plot_cost = Ploter(train_title, test_title)
def event_handler(title, loop_step, handler_val):
plot_cost.append(title, loop_step, handler_val)
plot_cost.plot()
```
### 开始提供数据
PaddlePaddle提供了读取数据者发生器机制来读取训练数据。读取数据者会一次提供多列数据,因此我们需要一个Python的list来定义读取顺序
### 创建训练过程
训练需要有一个训练程序和一些必要参数,并构建了一个获取训练过程中测试误差的函数
```python
feed_order=['x', 'y']
exe = fluid.Executor(place)
num_epochs = 100
# For training test cost
def train_test(train_program, feeder):
exe_test = fluid.Executor(place)
accumulated = 1 * [0]
count = 0
test_program = train_program.clone(for_test=True)
for data_test in test_reader():
outs = exe_test.run(program=test_program,
feed=feeder.feed(data_test),
fetch_list=[avg_loss])
accumulated = [x_c[0] + x_c[1][0] for x_c in zip(accumulated, outs)]
count += 1
return [x_d / count for x_d in accumulated]
```
除此之外,可以定义一个事件响应器来处理类似`打印训练进程`的事件:
### 训练主循环
PaddlePaddle提供了读取数据者发生器机制来读取训练数据。读取数据者会一次提供多列数据,因此我们需要一个Python的list来定义读取顺序。我们构建一个循环来进行训练,直到训练结果足够好或者循环次数足够多。
如果训练顺利,可以把训练参数保存到`params_dirname`
```python
# Specify the directory to save the parameters
params_dirname = "fit_a_line.inference.model"
# main train loop.
def train_loop(main_program):
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
feeder_test = fluid.DataFeeder(place=place, feed_list=[x, y])
exe.run(fluid.default_startup_program())
step = 0
for pass_id in range(num_epochs):
for data_train in train_reader():
avg_loss_value, = exe.run(main_program,
feed=feeder.feed(data_train),
fetch_list=[avg_loss])
if step % 10 == 0: # record a train cost every 10 batches
event_handler(train_title, step, avg_loss_value[0])
if step % 100 == 0: # record a test cost every 100 batches
test_metics = train_test(train_program=main_program,
feeder=feeder_test)
event_handler(test_title, step, test_metics[0])
# If the accuracy is good enough, we can stop the training.
if test_metics[0] < 10.0:
return
train_title = "Train cost"
test_title = "Test cost"
step = 0
# event_handler prints training and testing info
def event_handler(event):
global step
if isinstance(event, EndStepEvent):
if step % 10 == 0: # record a train cost every 10 batches
print("%s, Step %d, Cost %f" % (train_title, step, event.metrics[0]))
if step % 100 == 0: # record a test cost every 100 batches
test_metrics = trainer.test(
reader=test_reader, feed_order=feed_order)
print("%s, Step %d, Cost %f" % (test_title, step, test_metrics[0]))
step += 1
if test_metrics[0] < 10.0:
# If the accuracy is good enough, we can stop the training.
print('loss is less than 10.0, stop')
trainer.stop()
step += 1
if isinstance(event, EndEpochEvent):
if event.epoch % 10 == 0:
# We can save the trained parameters for the inferences later
if params_dirname is not None:
trainer.save_params(params_dirname)
if math.isnan(float(avg_loss_value)):
sys.exit("got NaN loss, training failed.")
if params_dirname is not None:
# We can save the trained parameters for the inferences later
fluid.io.save_inference_model(params_dirname, ['x'],
[y_predict], exe)
```
### 开始训练
我们现在可以通过调用`trainer.train()`来开始训练
```python
%matplotlib inline
# The training could take up to a few minutes.
trainer.train(
reader=train_reader,
num_epochs=100,
event_handler=event_handler,
feed_order=feed_order)
train_loop(fluid.default_main_program())
```
## 预测
提供一个`inference_program`和一个`params_dirname`来初始化预测器。`params_dirname`用来存储我们的参数。
### 设定预测程序
类似于`trainer.train`,预测器需要一个预测程序来做预测。我们可以稍加修改我们的训练程序来把预测值包含进来。
需要构建一个使用训练好的参数来进行预测的程序,训练好的参数位置在`params_dirname`
### 准备预测环境
类似于训练过程,预测器需要一个预测程序来做预测。我们可以稍加修改我们的训练程序来把预测值包含进来。
```python
def inference_program():
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return y_predict
infer_exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
```
### 预测
预测器会从`params_dirname`中读取已经训练好的模型,来对从未遇见过的数据进行预测。
通过fluid.io.load_inference_model,预测器会从`params_dirname`中读取已经训练好的模型,来对从未遇见过的数据进行预测。
```python
inferencer = Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place)
batch_size = 10
test_reader = paddle.batch(paddle.dataset.uci_housing.test(),batch_size=batch_size)
test_data = next(test_reader())
test_x = numpy.array([data[0] for data in test_data]).astype("float32")
test_y = numpy.array([data[1] for data in test_data]).astype("float32")
results = inferencer.infer({'x': test_x})
print("infer results: (House Price)")
for idx, val in enumerate(results[0]):
print("%d: %.2f" % (idx, val))
print("\nground truth:")
for idx, val in enumerate(test_y):
print("%d: %.2f" % (idx, val))
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(params_dirname, exe)
batch_size = 10
infer_reader = paddle.batch(
paddle.dataset.uci_housing.test(), batch_size=batch_size)
infer_data = next(infer_reader())
infer_feat = numpy.array(
[data[0] for data in infer_data]).astype("float32")
infer_label = numpy.array(
[data[1] for data in infer_data]).astype("float32")
assert feed_target_names[0] == 'x'
results = infer_exe.run(inference_program,
feed={feed_target_names[0]: numpy.array(infer_feat)},
fetch_list=fetch_targets)
print("infer results: (House Price)")
for idx, val in enumerate(results[0]):
print("%d: %.2f" % (idx, val))
print("\nground truth:")
for idx, val in enumerate(infer_label):
print("%d: %.2f" % (idx, val))
```
## 总结
......
......@@ -145,17 +145,9 @@ $$MSE=\frac{1}{n}\sum_{i=1}^{n}{(\hat{Y_i}-Y_i)}^2$$
import paddle
import paddle.fluid as fluid
import numpy
import math
import sys
from __future__ import print_function
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
```
我们通过uci_housing模块引入了数据集合[UCI Housing Data Set](https://archive.ics.uci.edu/ml/datasets/Housing)
......@@ -165,7 +157,7 @@ except ImportError:
1. 数据下载的过程。下载数据保存在~/.cache/paddle/dataset/uci_housing/housing.data。
2. [数据预处理](#数据预处理)的过程。
接下来我们定义了用于训练和测试的数据提供器。提供器每次读入一个大小为`BATCH_SIZE`的数据批次。如果用户希望加一些随机性,她可以同时定义一个批次大小和一个缓存大小。这样的话,每次数据提供器会从缓存中随机读取批次大小那么多的数据。
接下来我们定义了用于训练的数据提供器。提供器每次读入一个大小为`BATCH_SIZE`的数据批次。如果用户希望加一些随机性,她可以同时定义一个批次大小和一个缓存大小。这样的话,每次数据提供器会从缓存中随机读取批次大小那么多的数据。
```python
BATCH_SIZE = 20
......@@ -174,28 +166,18 @@ train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=BATCH_SIZE)
```
### 配置训练程序
训练程序的目的是定义一个训练模型的网络结构。对于线性回归来讲,它就是一个从输入到输出的简单的全连接层。更加复杂的结果,比如卷积神经网络,递归神经网络等会在随后的章节中介绍。训练程序必须返回`平均损失`作为第一个返回值,因为它会被后面反向传播算法所用到。
```python
def train_program():
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# feature vector of length 13
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
loss = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(loss)
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return avg_loss
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(cost)
```
### Optimizer Function 配置
......@@ -203,8 +185,8 @@ def train_program():
在下面的 `SGD optimizer`,`learning_rate` 是训练的速度,与网络的训练收敛速度有关系。
```python
def optimizer_program():
return fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_loss)
```
### 定义运算场所
......@@ -215,112 +197,137 @@ use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
```
### 创建训练器
训练器会读入一个训练程序和一些必要的其他参数:
除此之外,还可以定义一个事件响应器来处理类似`打印训练进程`的事件:
```python
trainer = Trainer(
train_func=train_program,
place=place,
optimizer_func=optimizer_program)
# Plot data
from paddle.v2.plot import Ploter
train_title = "Train cost"
test_title = "Test cost"
plot_cost = Ploter(train_title, test_title)
def event_handler(title, loop_step, handler_val):
plot_cost.append(title, loop_step, handler_val)
plot_cost.plot()
```
### 开始提供数据
PaddlePaddle提供了读取数据者发生器机制来读取训练数据。读取数据者会一次提供多列数据,因此我们需要一个Python的list来定义读取顺序
### 创建训练过程
训练需要有一个训练程序和一些必要参数,并构建了一个获取训练过程中测试误差的函数
```python
feed_order=['x', 'y']
exe = fluid.Executor(place)
num_epochs = 100
# For training test cost
def train_test(train_program, feeder):
exe_test = fluid.Executor(place)
accumulated = 1 * [0]
count = 0
test_program = train_program.clone(for_test=True)
for data_test in test_reader():
outs = exe_test.run(program=test_program,
feed=feeder.feed(data_test),
fetch_list=[avg_loss])
accumulated = [x_c[0] + x_c[1][0] for x_c in zip(accumulated, outs)]
count += 1
return [x_d / count for x_d in accumulated]
```
除此之外,可以定义一个事件响应器来处理类似`打印训练进程`的事件:
### 训练主循环
PaddlePaddle提供了读取数据者发生器机制来读取训练数据。读取数据者会一次提供多列数据,因此我们需要一个Python的list来定义读取顺序。我们构建一个循环来进行训练,直到训练结果足够好或者循环次数足够多。
如果训练顺利,可以把训练参数保存到`params_dirname`。
```python
# Specify the directory to save the parameters
params_dirname = "fit_a_line.inference.model"
# main train loop.
def train_loop(main_program):
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
feeder_test = fluid.DataFeeder(place=place, feed_list=[x, y])
exe.run(fluid.default_startup_program())
step = 0
for pass_id in range(num_epochs):
for data_train in train_reader():
avg_loss_value, = exe.run(main_program,
feed=feeder.feed(data_train),
fetch_list=[avg_loss])
if step % 10 == 0: # record a train cost every 10 batches
event_handler(train_title, step, avg_loss_value[0])
if step % 100 == 0: # record a test cost every 100 batches
test_metics = train_test(train_program=main_program,
feeder=feeder_test)
event_handler(test_title, step, test_metics[0])
# If the accuracy is good enough, we can stop the training.
if test_metics[0] < 10.0:
return
train_title = "Train cost"
test_title = "Test cost"
step = 0
# event_handler prints training and testing info
def event_handler(event):
global step
if isinstance(event, EndStepEvent):
if step % 10 == 0: # record a train cost every 10 batches
print("%s, Step %d, Cost %f" % (train_title, step, event.metrics[0]))
if step % 100 == 0: # record a test cost every 100 batches
test_metrics = trainer.test(
reader=test_reader, feed_order=feed_order)
print("%s, Step %d, Cost %f" % (test_title, step, test_metrics[0]))
step += 1
if test_metrics[0] < 10.0:
# If the accuracy is good enough, we can stop the training.
print('loss is less than 10.0, stop')
trainer.stop()
step += 1
if isinstance(event, EndEpochEvent):
if event.epoch % 10 == 0:
# We can save the trained parameters for the inferences later
if params_dirname is not None:
trainer.save_params(params_dirname)
if math.isnan(float(avg_loss_value)):
sys.exit("got NaN loss, training failed.")
if params_dirname is not None:
# We can save the trained parameters for the inferences later
fluid.io.save_inference_model(params_dirname, ['x'],
[y_predict], exe)
```
### 开始训练
我们现在可以通过调用`trainer.train()`来开始训练
```python
%matplotlib inline
# The training could take up to a few minutes.
trainer.train(
reader=train_reader,
num_epochs=100,
event_handler=event_handler,
feed_order=feed_order)
train_loop(fluid.default_main_program())
```
## 预测
提供一个`inference_program`和一个`params_dirname`来初始化预测器。`params_dirname`用来存储我们的参数
### 设定预测程序
类似于`trainer.train`,预测器需要一个预测程序来做预测我们可以稍加修改我们的训练程序来把预测值包含进来
需要构建一个使用训练好的参数来进行预测的程序训练好的参数位置在`params_dirname`。
### 准备预测环境
类似于训练过程预测器需要一个预测程序来做预测我们可以稍加修改我们的训练程序来把预测值包含进来
```python
def inference_program():
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return y_predict
infer_exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
```
### 预测
预测器会从`params_dirname`中读取已经训练好的模型来对从未遇见过的数据进行预测
通过fluid.io.load_inference_model预测器会从`params_dirname`中读取已经训练好的模型来对从未遇见过的数据进行预测
```python
inferencer = Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place)
batch_size = 10
test_reader = paddle.batch(paddle.dataset.uci_housing.test(),batch_size=batch_size)
test_data = next(test_reader())
test_x = numpy.array([data[0] for data in test_data]).astype("float32")
test_y = numpy.array([data[1] for data in test_data]).astype("float32")
results = inferencer.infer({'x': test_x})
print("infer results: (House Price)")
for idx, val in enumerate(results[0]):
print("%d: %.2f" % (idx, val))
print("\nground truth:")
for idx, val in enumerate(test_y):
print("%d: %.2f" % (idx, val))
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(params_dirname, exe)
batch_size = 10
infer_reader = paddle.batch(
paddle.dataset.uci_housing.test(), batch_size=batch_size)
infer_data = next(infer_reader())
infer_feat = numpy.array(
[data[0] for data in infer_data]).astype("float32")
infer_label = numpy.array(
[data[1] for data in infer_data]).astype("float32")
assert feed_target_names[0] == 'x'
results = infer_exe.run(inference_program,
feed={feed_target_names[0]: numpy.array(infer_feat)},
fetch_list=fetch_targets)
print("infer results: (House Price)")
for idx, val in enumerate(results[0]):
print("%d: %.2f" % (idx, val))
print("\nground truth:")
for idx, val in enumerate(infer_label):
print("%d: %.2f" % (idx, val))
```
## 总结
......
......@@ -13,122 +13,137 @@
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import sys
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
import numpy
import math
import sys
BATCH_SIZE = 20
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=BATCH_SIZE)
# event_handler prints training and testing info
def event_handler(title, loop_step, handler_val):
print("%s, Step %d, Cost %f" % (title, loop_step, handler_val))
test_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=BATCH_SIZE)
def main():
def train_program():
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
batch_size = 20
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=batch_size)
test_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.uci_housing.test(), buf_size=500),
batch_size=batch_size)
# feature vector of length 13
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
loss = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(loss)
return avg_loss
def optimizer_program():
return fluid.optimizer.SGD(learning_rate=0.001)
# can use CPU or GPU
use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
trainer = Trainer(
train_func=train_program, place=place, optimizer_func=optimizer_program)
feed_order = ['x', 'y']
# Specify the directory to save the parameters
params_dirname = "fit_a_line.inference.model"
train_title = "Train cost"
test_title = "Test cost"
step = 0
# event_handler prints training and testing info
def event_handler(event):
global step
if isinstance(event, EndStepEvent):
if step % 10 == 0: # record a train cost every 10 batches
print("%s, Step %d, Cost %f" %
(train_title, step, event.metrics[0]))
if step % 100 == 0: # record a test cost every 100 batches
test_metrics = trainer.test(
reader=test_reader, feed_order=feed_order)
print("%s, Step %d, Cost %f" % (test_title, step, test_metrics[0]))
if test_metrics[0] < 10.0:
# If the accuracy is good enough, we can stop the training.
print('loss is less than 10.0, stop')
trainer.stop()
step += 1
if isinstance(event, EndEpochEvent):
if event.epoch % 10 == 0:
main_program = fluid.default_main_program()
star_program = fluid.default_startup_program()
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_loss)
test_program = main_program.clone(for_test=True)
# can use CPU or GPU
use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
# Specify the directory to save the parameters
params_dirname = "fit_a_line.inference.model"
num_epochs = 100
# For training test cost
def train_test(program, feeder):
exe_test = fluid.Executor(place)
accumulated = 1 * [0]
count = 0
for data_test in test_reader():
outs = exe_test.run(
program=program,
feed=feeder.feed(data_test),
fetch_list=[avg_loss])
accumulated = [x_c[0] + x_c[1][0] for x_c in zip(accumulated, outs)]
count += 1
return [x_d / count for x_d in accumulated]
# main train loop.
def train_loop():
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
feeder_test = fluid.DataFeeder(place=place, feed_list=[x, y])
exe.run(star_program)
train_title = "Train cost"
test_title = "Test cost"
step = 0
for pass_id in range(num_epochs):
for data_train in train_reader():
avg_loss_value, = exe.run(
main_program,
feed=feeder.feed(data_train),
fetch_list=[avg_loss])
if step % 10 == 0: # record a train cost every 10 batches
event_handler(train_title, step, avg_loss_value[0])
test_metics = train_test(
program=test_program, feeder=feeder_test)
event_handler(test_title, step, test_metics[0])
# If the accuracy is good enough, we can stop the training.
if test_metics[0] < 10.0:
return
step += 1
if math.isnan(float(avg_loss_value)):
sys.exit("got NaN loss, training failed.")
if params_dirname is not None:
# We can save the trained parameters for the inferences later
if params_dirname is not None:
trainer.save_params(params_dirname)
fluid.io.save_inference_model(params_dirname, ['x'], [y_predict],
exe)
train_loop()
# The training could take up to a few minutes.
trainer.train(
reader=train_reader,
num_epochs=100,
event_handler=event_handler,
feed_order=feed_order)
infer_exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
# infer
with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(params_dirname, exe)
batch_size = 10
def inference_program():
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return y_predict
infer_reader = paddle.batch(
paddle.dataset.uci_housing.test(), batch_size=batch_size)
infer_data = next(infer_reader())
infer_feat = numpy.array(
[data[0] for data in infer_data]).astype("float32")
infer_label = numpy.array(
[data[1] for data in infer_data]).astype("float32")
inferencer = Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place)
assert feed_target_names[0] == 'x'
results = infer_exe.run(
inference_program,
feed={feed_target_names[0]: numpy.array(infer_feat)},
fetch_list=fetch_targets)
batch_size = 10
test_reader = paddle.batch(
paddle.dataset.uci_housing.test(), batch_size=batch_size)
test_data = next(test_reader())
test_x = numpy.array([data[0] for data in test_data]).astype("float32")
test_y = numpy.array([data[1] for data in test_data]).astype("float32")
print("infer results: (House Price)")
for idx, val in enumerate(results[0]):
print("%d: %.2f" % (idx, val))
results = inferencer.infer({'x': test_x})
print("\nground truth:")
for idx, val in enumerate(infer_label):
print("%d: %.2f" % (idx, val))
print("infer results: (House Price)")
for idx, val in enumerate(results[0]):
print("%d: %.2f" % (idx, val))
print("\nground truth:")
for idx, val in enumerate(test_y):
print("%d: %.2f" % (idx, val))
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册