python_infer_cn.md 4.1 KB
Newer Older
1 2 3 4
# Python 预测 API介绍

Fluid提供了高度优化的[C++预测库](./native_infer.html),为了方便使用,我们也提供了C++预测库对应的Python接口,两者含义完全相同,下面是详细的使用说明

F
flame 已提交
5 6

## PaddleTensor
7

F
flame 已提交
8
`PaddleTensor`是预测库输入和输出的数据结构,包括以下字段
9

F
flame 已提交
10 11 12 13 14 15
* `name`(str): 指定输入的名称
* `shape`(tuple|list): Tensor的shape
* `data`(PaddleBuf): Tensor的数据,存储在`PaddleBuf`中,
* `dtype`(PaddleDType): Tensor的类型

## PaddleBuf
16

F
flame 已提交
17
`PaddleBuf`定义了`PaddleTensor`的存储结构,创建`PaddleBuf`:
18

F
flame 已提交
19 20 21 22 23 24
``` python
int64_buf = PaddleBuf([1, 2, 3, 4])
float_buf = PaddleBuf([1., 2., 3., 4.])
```

`PadleBuf`包括以下方法
25

F
flame 已提交
26 27 28 29 30 31 32 33
* `resize`: 重新分配内存,单位为byte
* `reset`: 重新设置数据
* `empty`: buffer是否为空
* `float_data`: 将数据转为float型的list返回
* `int64_data`: 将数据转为int64型的list返回
* `length`: 内存大小,单位为byte

## PaddleDType
34

F
flame 已提交
35
`PaddleDType`定义了`PaddleTensor`的类型,包括
36

F
flame 已提交
37 38 39 40
* `PaddleDType.INT64`: 64位整型
* `PaddleDType.FLOAT32`: 32位浮点型

## AnalysisConfig
41 42 43

`AnalysisConfig`是创建预测引擎的配置,主要包括以下方法  

F
flame 已提交
44 45 46 47 48 49 50 51 52 53 54
* `set_model`: 设置模型的路径
* `model_dir`: 返回模型路径
* `enable_use_gpu`: 设置GPU显存(单位M)和ID
* `disable_gpu`: 禁用GPU
* `gpu_device_id`: 返回使用的GPU ID
* `switch_ir_optim`: IR优化(默认开启)
* `enable_tensorrt_engine`: 启用TensorRT
* `enable_mkldnn`: 启用MKLDNN


## PaddlePredictor
55 56

`PaddlePredictor`是运行预测的引擎,下面是创建和使用的说明   
F
flame 已提交
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88

``` python
# 创建预测引擎
config = AnalysisConfig(model_dir)
config.enable_use_gpu(200, 0) # 200M显存, 设备id为0
config.enable_tensorrt_engine() # 打开TensorRT

predictor = create_paddle_predictor(config)

# 设置输入
x = fluid.core.PaddleTensor()
# x.name = ...
# x.shape = ...
# x.data = ...
# x.dtype = ...

y = fluid.core.PaddleTensor()
# y.name = ...
# y.shape = ...
# y.data = ...
# y.dtype = ...


# 运行预测引擎得到结果,返回值是一个PaddleTensor的list
results = predictor.run([x, y])

# 获得 results,并应用到自己的应用中
```

**Python API 相关接口与 C++ API 完全对应,可以对照查阅**

## 完整使用示例
89

F
flame 已提交
90 91 92
下面是一个完整的resnet50预测示例

下载[resnet50模型](http://paddle-inference-dist.bj.bcebos.com/resnet50_model.tar.gz)并解压,运行如下命令将会调用预测引擎
93

F
flame 已提交
94 95 96 97 98
``` bash
python resnet50_infer.py --model_dir model --prog_file model --params_file params --batch_size 2
```

`resnet50_infer.py` 的内容是
99

F
flame 已提交
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
``` python
import argparse
import numpy as np

from paddle.fluid.core import PaddleBuf
from paddle.fluid.core import PaddleDType
from paddle.fluid.core import PaddleTensor
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor


def main():
    args = parse_args()

    # Set config
    config = AnalysisConfig(args.model_dir)
    config.disable_gpu()

    # Create PaddlePredictor
    predictor = create_paddle_predictor(config)

    # Set inputs
    inputs = fake_input(args.batch_size)

    # Infer
    outputs = predictor.run(inputs)

    # parse outputs
    output = outputs[0]
    print(output.name)
    output_data = output.data.float_data()
    assert len(output_data) == 512 * args.batch_size
    for i in range(args.batch_size):
        print(np.argmax(output_data[i * 512:(i + 1) * 512]))


def fake_input(batch_size):
    image = PaddleTensor()
    image.name = "data"
    image.shape = [batch_size, 3, 318, 318]
    image.dtype = PaddleDType.FLOAT32
    image.data = PaddleBuf(
        np.random.randn(*image.shape).flatten().astype("float32").tolist())
    return [image]


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, help="model dir")
    parser.add_argument("--prog_file", type=str, help="program filename")
    parser.add_argument("--params_file", type=str, help="parameter filename")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")

    return parser.parse_args()


if __name__ == "__main__":
    main()    
```