LightPredictor.md 1.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
## LightPredictor

```c++
class LightPredictor
```

`LightPredictor`是Paddle-Lite的预测器,由`create_paddle_predictor`根据`MobileConfig`进行创建。用户可以根据LightPredictor提供的接口设置输入数据、执行模型预测、获取输出以及获得当前使用lib的版本信息等。

示例:

```python
from __future__ import print_function
from paddlelite.lite import *

# 1. 设置MobileConfig
config = MobileConfig()
config.set_model_dir(args.model_dir)

# 2. 创建LightPredictor
predictor = create_paddle_predictor(config)

# 3. 设置输入数据
input_tensor = predictor.get_input(0)
input_tensor.resize([1, 3, 224, 224])
input_tensor.set_float_data([1.] * 3 * 224 * 224)

# 4. 运行模型
predictor.run()

# 5. 获取输出数据
output_tensor = predictor.get_output(0)
print(output_tensor.shape())
print(output_tensor.float_data()[:10])
```

### `get_input(index)`

获取输入Tensor,用来设置模型的输入数据。

参数:

- `index(int)` - 输入Tensor的索引

返回:第`index`个输入`Tensor`

返回类型:`Tensor`



### `get_output(index)`

获取输出Tensor,用来获取模型的输出结果。

参数:

- `index(int)` - 输出Tensor的索引

返回:第`index`个输出`Tensor`

返回类型:`Tensor`



### `run()`

执行模型预测,需要在***设置输入数据后***调用。

参数:

- `None`

返回:`None`

返回类型:`None`



### `get_version()`

用于获取当前lib使用的代码版本。若代码有相应tag则返回tag信息,如`v2.0-beta`;否则返回代码的`branch(commitid)`,如`develop(7e44619)`

参数:

- `None`

返回:当前lib使用的代码版本信息

返回类型:`str`