未验证 提交 13fb9138 编写于 作者: Z zhouzj 提交者: GitHub

fix a bug and update docs. (#996)

上级 0706349b
...@@ -7,9 +7,10 @@ ...@@ -7,9 +7,10 @@
## 1. 准备环境 ## 1. 准备环境
### 1.1 版本要求 ### 1.1 版本要求
```bash ```bash
python>=3.7 python==3.7\3.9
PaddleSlim>=2.3.0 PaddleSlim>=2.3.0
``` ```
> 注:在 macOS 环境下要求 python==3.9; linux 环境下要求 python==3.7\3.9
### 1.2 安装 PaddleSlim ### 1.2 安装 PaddleSlim
* 通过 pip install 的方式进行安装: * 通过 pip install 的方式进行安装:
```bash ```bash
...@@ -40,7 +41,7 @@ tar -xf mobilenetv1.tar ...@@ -40,7 +41,7 @@ tar -xf mobilenetv1.tar
from paddleslim.analysis import TableLatencyPredictor from paddleslim.analysis import TableLatencyPredictor
predictor = TableLatencyPredictor(table_file='SD710') predictor = TableLatencyPredictor(table_file='SD710')
latency = predictor.predict(model_file='mobilenetv1_fp32.pdmodel', param_file='mobilenetv1_fp32.pdiparams, data_type='fp32') latency = predictor.predict(model_file='mobilenetv1_fp32.pdmodel', param_file='mobilenetv1_fp32.pdiparams', data_type='fp32')
print('predicted latency = {}ms'.format(latency)) print('predicted latency = {}ms'.format(latency))
``` ```
通过设置 table_file 来指定硬件信息,当前支持“SD625”、“SD710”、“SD845”三款骁龙芯片。 通过设置 table_file 来指定硬件信息,当前支持“SD625”、“SD710”、“SD845”三款骁龙芯片。
...@@ -61,7 +62,7 @@ print('predicted latency = {}ms'.format(latency)) ...@@ -61,7 +62,7 @@ print('predicted latency = {}ms'.format(latency))
from paddleslim.analysis import TableLatencyPredictor from paddleslim.analysis import TableLatencyPredictor
predictor = TableLatencyPredictor(table_file='SD710') predictor = TableLatencyPredictor(table_file='SD710')
predictor.predict(model_file='mobilenetv1_int8.pdmodel', param_file='mobilenetv1_int8.pdiparams, data_type='int8') predictor.predict(model_file='mobilenetv1_int8.pdmodel', param_file='mobilenetv1_int8.pdiparams', data_type='int8')
``` ```
## 4. 预估效果 ## 4. 预估效果
......
...@@ -248,7 +248,7 @@ def download_predictor(op_dir, op): ...@@ -248,7 +248,7 @@ def download_predictor(op_dir, op):
def load_predictor(op_type, op_dir, data_type='fp32'): def load_predictor(op_type, op_dir, data_type='fp32'):
op = op_type op = op_type
if 'conv2d' in op_type: if 'conv2d' in op_type:
op = 'conv2d_' + data_type op = f'{op_type}_{data_type}'
elif 'matmul' in op_type: elif 'matmul' in op_type:
op = 'matmul' op = 'matmul'
......
...@@ -214,7 +214,7 @@ def get_features_from_paramkey(param_key, op_type, data_type): ...@@ -214,7 +214,7 @@ def get_features_from_paramkey(param_key, op_type, data_type):
for i in range(len(X)): for i in range(len(X)):
if X[i] == '': if X[i] == '':
continue continue
features[i - 1] = int(X[i]) features[i] = int(X[i])
for i in range(0, len(Y)): for i in range(0, len(Y)):
if len(Y) == 4 and i == 0: if len(Y) == 4 and i == 0:
continue continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册