swig_py_paddle_cn.rst 2.4 KB
Newer Older
L
Luo Tao 已提交
1 2
.. _api_swig_py_paddle:

H
hedaoyuan 已提交
3 4 5
基于Python的预测
================

L
Luo Tao 已提交
6 7
预测流程
--------
H
hedaoyuan 已提交
8 9 10 11 12 13

PaddlePaddle使用swig对常用的预测接口进行了封装,通过编译会生成py_paddle软件包,安装该软件包就可以在python环境下实现模型预测。可以使用python的 ``help()`` 函数查询软件包相关API说明。

基于Python的模型预测,主要包括以下五个步骤。

1. 初始化PaddlePaddle环境
14 15 16

   在程序开始阶段,通过调用 ``swig_paddle.initPaddle()`` 并传入相应的命令行参数初始化PaddlePaddle。

H
hedaoyuan 已提交
17
2. 解析模型配置文件
18 19 20
   
   初始化之后,可以通过调用 ``parse_config()`` 解析训练模型时用的配置文件。注意预测数据通常不包含label, 同时预测网络通常直接输出最后一层的结果而不是像训练网络一样再接一层cost layer,所以一般需要对训练用的模型配置文件稍作相应修改才能在预测时使用。

H
hedaoyuan 已提交
21
3. 构造paddle.GradientMachine
22 23 24
  
   通过调用 ``swig_paddle.GradientMachine.createFromConfigproto()`` 传入上一步解析出来的模型配置就可以创建一个 ``GradientMachine``。

H
hedaoyuan 已提交
25
4. 准备预测数据
26 27 28
  
   swig_paddle中的预测接口的参数是自定义的C++数据类型,py_paddle里面提供了一个工具类 ``DataProviderConverter`` 可以用于接收和PyDataProvider2一样的输入数据并转换成预测接口所需的数据类型。

H
hedaoyuan 已提交
29
5. 模型预测
30 31
  
   通过调用 ``forwardTest()`` 传入预测数据,直接返回计算结果。
H
hedaoyuan 已提交
32 33


L
Luo Tao 已提交
34 35
预测Demo
--------
H
hedaoyuan 已提交
36 37 38

如下是一段使用mnist model来实现手写识别的预测代码。完整的代码见 ``src_root/doc/ui/predict/predict_sample.py`` 。mnist model可以通过 ``src_root\demo\mnist`` 目录下的demo训练出来。

L
liaogang 已提交
39
..  literalinclude:: src/predict_sample.py
H
hedaoyuan 已提交
40 41
    :language: python
    :lines: 15-18,121-136
H
hedaoyuan 已提交
42 43 44


Demo预测输出如下,其中value即为softmax层的输出。由于TEST_DATA包含两条预测数据,所以输出的value包含两个向量 。
Z
zhangjinchao01 已提交
45 46 47

..  code-block:: text

48 49
    [{'id': None, 'value': array(
      [[  5.53018653e-09,   1.12194102e-05,   1.96644767e-09,
Z
zhangjinchao01 已提交
50 51 52 53 54 55 56 57
          1.43630644e-02,   1.51111044e-13,   9.85625684e-01,
          2.08823112e-10,   2.32777140e-08,   2.00186201e-09,
          1.15501715e-08],
       [  9.99982715e-01,   1.27787406e-10,   1.72296313e-05,
          1.49316648e-09,   1.36540484e-11,   6.93137714e-10,
          2.70634608e-08,   3.48565123e-08,   5.25639710e-09,
          4.48684503e-08]], dtype=float32)}]

H
hedaoyuan 已提交
58