Unverified commit f89fc93c, authored by zhouzj, committed by GitHub

analysis: add op predictors for unknown operators (#933)

* add op predictors for unknown operators

* modify the download path of op predictors and opt tools

* add automatic download mode for opt tools

* modify predictor's test files

* modify predictor's test file

* modify predictor's test file

* modify predictor's test files

* add requirements for installing sklearn

* modify predictor's test files

* modify predictor's test files

* modify predictor's test files

* modify predictor's test files

* modify predictor's test files

* Modify api design

* update LatencyPredictor's doc.

* update LatencyPredictor's test files

* update LatencyPredictor's test files

* update LatencyPredictor's test files

* update LatencyPredictor's test files

* update LatencyPredictor's doc.

* update LatencyPredictor's doc.

* update LatencyPredictor's doc.

* update some syntax

* add LatencyPredictor's api doc.

* update LatencyPredictor's doc.

* update LatencyPredictor's api doc.

* adjust 'predictor_state' api

* update LatencyPredictor's test files
Co-authored-by: minghaoBD <79566150+minghaoBD@users.noreply.github.com>
Parent 7b49bed0
LatencyPredictor
================

.. toctree::
   :maxdepth: 1

   predictor_api.rst
Latency Predictor
=================

TableLatencyPredictor
---------------------

.. py:class:: paddleslim.TableLatencyPredictor(table_file)

`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/analysis/latency_predictor.py>`_

The latency predictor estimates the inference latency of a model on a specific hardware device. It can quickly estimate the latency under a variety of deployment environments and settings without deploying the model to a real device.

**Parameters:**

- **table_file(str)** - Specifies the target hardware; one of "SD625", "SD710", "SD845". Alternatively, the path of an existing latency table.

**Returns:** An instance of the TableLatencyPredictor class.

**Example:**
.. code-block:: python

   import paddle
   from paddleslim import TableLatencyPredictor

   predictor = TableLatencyPredictor(table_file='SD710')
.. py:method:: paddleslim.TableLatencyPredictor.predict(model_file, param_file, data_type, threads, input_shape)

Estimates the latency of the model on the specified hardware device.

**Parameters:**

- **model_file(str)** - Path of the inference model's model file.
- **param_file(str)** - Path of the inference model's parameter file.
- **data_type(str)** - Data type of the inference model: 'fp32' or 'int8'.
- **threads(int)** - Number of threads for which the latency is estimated. Currently only 4 threads are supported; more options will be added later.
- **input_shape(list)** - Input shape for models with variable-length input. For now this parameter cannot control the model input; set an exact input shape when saving the inference model.

**Returns:**

- **latency(float)** - The latency of the inference model on the specified device.

**Example:**
.. code-block:: python

   import paddle
   from paddleslim import TableLatencyPredictor
   from paddle.vision.models import mobilenet_v1
   from paddle.static import InputSpec

   predictor = TableLatencyPredictor(table_file='SD710')

   model = mobilenet_v1()
   x_spec = InputSpec(shape=[1, 3, 224, 224], dtype='float32', name='inputs')
   static_model = paddle.jit.to_static(model, input_spec=[x_spec])
   paddle.jit.save(static_model, 'mobilenet_v1')

   latency = predictor.predict(model_file='mobilenet_v1.pdmodel',
                               param_file='mobilenet_v1.pdiparams',
                               data_type='fp32')
   print("predicted latency:", latency)
# LatencyPredictor Tutorial
The latency predictor (LatencyPredictor) estimates the inference latency of a model on a specific hardware device, so latency under various deployment environments and settings can be estimated quickly, without deploying the model to a real device. It is built on Paddle-Lite: the latency table is stored as key-value pairs, where each key encodes a fused op produced by Paddle-Lite's graph optimization and each value is the latency measured on the target hardware. Currently it:

* supports any model that can be deployed with Paddle Lite;
* supports latency estimation on ARM CPUs.
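For illustration, a latency table is a pickled Python dict; each key is an op signature in the format produced by `get_key_from_op` (see `paddleslim/analysis/parse_ops.py`), and each value is a measured latency in milliseconds. The entry below is hypothetical:

```python
# A hypothetical latency-table entry: the key encodes a fused conv2d op's
# parameters, the value is its measured latency (ms) on the target device.
table_dict = {
    'conv2d in=(1, 3, 224, 224) weight=(32, 3, 3, 3) out=(1, 32, 112, 112) '
    'pad=1 stride=2 group=1 dilation=1 quant=None bit_length=None': 0.432,
}
```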
## 1. Prepare the Environment
### 1.1 Version Requirements
```bash
python>=3.7
PaddleSlim>=2.3.0
```
### 1.2 Install PaddleSlim
* Install via pip:
```bash
pip install paddleslim -i https://pypi.tuna.tsinghua.edu.cn/simple
```
* Or install the latest PaddleSlim from source:
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python3.7 -m pip install -r requirements.txt  # install dependencies from requirements.txt
python3.7 setup.py install
```
## 2. Quick Start
### 2.1 Prepare the Inference Model
The latency predictor works by reading the inference model files (\*.pdmodel, \*.pdiparams). Taking MobileNetV1 as an example, download the inference model files from [here](https://bj.bcebos.com/v1/paddlemodels/PaddleSlim/analysis/mobilenetv1.tar).
```bash
wget https://bj.bcebos.com/v1/paddlemodels/PaddleSlim/analysis/mobilenetv1.tar
tar -xf mobilenetv1.tar
```
When using a custom model, save it as an inference model as described in the [paddle.jit.save documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save); a sketch follows below.
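A minimal sketch of exporting a model with a fixed input shape (mirroring the LatencyPredictor API example):

```python
import paddle
from paddle.vision.models import mobilenet_v1
from paddle.static import InputSpec

model = mobilenet_v1()
# Set an exact input shape when saving; variable-length input is not supported yet.
x_spec = InputSpec(shape=[1, 3, 224, 224], dtype='float32', name='inputs')
static_model = paddle.jit.to_static(model, input_spec=[x_spec])
# Writes mobilenet_v1.pdmodel and mobilenet_v1.pdiparams.
paddle.jit.save(static_model, 'mobilenet_v1')
```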
### 2.2 Estimate the Inference Latency
Construct a TableLatencyPredictor instance and call its predict function to estimate the latency of the inference model.
```python
import paddleslim

predictor = paddleslim.TableLatencyPredictor(table_file='SD710')
latency = predictor.predict(model_file='mobilenetv1_fp32.pdmodel',
                            param_file='mobilenetv1_fp32.pdiparams',
                            data_type='fp32')
print('predicted latency = {}ms'.format(latency))
```
Set table_file to specify the target hardware; the Snapdragon chips "SD625", "SD710", and "SD845" are currently supported.
> Note 1: Set an exact input shape when saving the inference model;
>
> Note 2: Variable-length input is not supported yet; support will be added later.
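`predict` also accepts a `threads` argument; 4 is the default and currently the only supported value:

```python
# threads defaults to 4; other values are rejected for now.
latency = predictor.predict(model_file='mobilenetv1_fp32.pdmodel',
                            param_file='mobilenetv1_fp32.pdiparams',
                            data_type='fp32',
                            threads=4)
```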
## 3. More Features
### 3.1 Multiple Estimation Modes
There are two ways the predictor estimates model latency (see the sketch below):

* Table lookup: look up the latency of each operator (op) of the inference model in an existing latency table and sum them into the model-level latency. This is fast and accurate for models the table already covers, but it cannot handle new models;
* Op predictors: op-level predictors complement the latency table by estimating ops that the table does not cover, enabling latency estimation for arbitrary models.
> The op predictors only estimate latency at batchsize=1 and support the SD625 and SD710 devices; op predictors for more devices and batch sizes will be added later.
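Internally, `predict` walks the graph of the Paddle-Lite-optimized model, first consulting the table and then falling back to the op predictors. A simplified sketch of that loop, adapted from `TableLatencyPredictor.predict` (names shortened for readability):

```python
# Simplified from TableLatencyPredictor.predict in latency_predictor.py.
latency = 0.0
for op in graph.ops():
    param_key = get_key_from_op(op)   # table key encoding the op's parameters
    if param_key == '':
        continue                      # ops such as feed/fetch contribute no time
    if param_key in table_dict:
        latency += table_dict[param_key]   # fast path: exact table hit
    elif predictor_state:
        # fall back to the op-level predictor (available on SD625/SD710)
        latency += op_predictor(op.type(), param_key, data_type)
```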
### 3.2 Estimating INT8 Models
The latency predictor also supports INT8 quantized models: provide the inference model files saved after INT8 quantization and set data_type='int8' when calling predict, as shown below:
```python
import paddleslim

predictor = paddleslim.TableLatencyPredictor(table_file='SD710')
latency = predictor.predict(model_file='mobilenetv1_int8.pdmodel',
                            param_file='mobilenetv1_int8.pdiparams',
                            data_type='int8')
print('predicted latency = {}ms'.format(latency))
```
## 4. Estimation Accuracy
On SD625, SD710, and similar devices the latency tables were measured with 4 threads (threads=4) and power_mode 0, covering the mobile models in PaddleClas and PaddleDetection; other thread counts will be supported later. The table below shows estimates for typical classification and detection models on SD710; the estimation error is below 10% in all cases.
**Table 1: SD710 prediction results**
| Model | Predict(ms) | Real(ms) | Error(%) |
|:-----:|:----------------------------:|:---------------------:|:--------------------------:|
| MobileNetV1_x0_25| 3.856 | 4.082 | 5.552 |
| MobileNetV1_x0_5| 11.456 | 11.804 | 2.948 |
| MobileNetV1 | 39.107 | 39.448 | 0.865 |
| MobileNetV2_x0_5| 9.905 | 10.470 | 5.395 |
| MobileNetV2 | 26.666 | 27.542 | 3.183 |
| MobileNetV2_x2_0 | 86.281 | 86.824 | 0.625 |
| MobileNetV3_large_x0_35 | 6.428 | 6.911 | 6.984 |
| MobileNetV3_large_x1_0 | 21.566 | 23.108 | 6.673 |
| MobileNetV3_large_x1_25 | 32.888 | 33.641 | 2.236 |
| GhostNet_x0_5 | 8.294 | 9.182 | 9.675 |
| GhostNet_x1_0 | 18.603 | 19.916 | 6.594 |
| GhostNet_x1_3 | 26.896 | 28.0525 | 4.120 |
| ShuffleNetV2_x1_0 | 13.199 | 14.476 | 8.825 |
| ShuffleNetV2_x1_5 | 23.066 | 25.082 | 8.038 |
| ShuffleNetV2_x2_0 | 41.379 | 43.868 | 5.674 |
| ppyolo_mbv3_large_coco | 70.055 | 72.063 | 2.787 |
| ppyolo_tiny_650e_coco | 43.808 | 45.3393 | 3.377 |
| picodet_l_320_coco | 92.603 | 92.926 | 0.347 |
| picodet_m_320_coco | 69.176 | 65.778 | 4.911 |
| picodet_s_320_coco | 38.874 | 36.999 | 4.823 |
@@ -15,7 +15,8 @@ from .flops import flops, dygraph_flops
from .model_size import model_size
from .latency import LatencyEvaluator, TableLatencyEvaluator
from .latency_predictor import LatencyPredictor, TableLatencyPredictor
from .parse_ops import get_key_from_op
from ._utils import save_cls_model, save_det_model, save_seg_model
__all__ = [
'flops',
......
@@ -14,242 +14,46 @@
import os
import numpy as np
import pickle
import paddle
import paddleslim
import subprocess
import sklearn
__all__ = [
"get_key_from_op", "save_cls_model", "save_det_model", "save_seg_model"
"save_cls_model", "save_det_model", "save_seg_model", "nearest_interpolate",
"opt_model", "load_predictor"
]
def opt_model(opt="paddle_lite_opt",
model_file='',
param_file='',
optimize_out_type='protobuf',
valid_targets='arm'):
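    """Convert an inference model to an optimized model with the Paddle-Lite
    opt tool and return the path of the optimized model file."""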
    assert os.path.exists(model_file) and os.path.exists(
        param_file), f'{model_file} or {param_file} does not exist.'
save_dir = f'./opt_models_tmp/{os.getpid()}'
if not os.path.exists(save_dir):
os.makedirs(save_dir)
assert optimize_out_type in ['protobuf', 'naive_buffer']
if optimize_out_type == 'protobuf':
model_out = os.path.join(save_dir, 'pbmodel')
else:
model_out = os.path.join(save_dir, 'model')
cmd = f'{opt} --model_file={model_file} --param_file={param_file} --optimize_out_type={optimize_out_type} --optimize_out={model_out} --valid_targets={valid_targets}'
print(f'commands:{cmd}')
m = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
out = m.communicate()
print(out, 'opt done!')
if optimize_out_type == 'protobuf':
model_out = os.path.join(model_out, 'model')
    else:
        model_out = model_out + '.nb'

    return model_out
def sample_generator(input_shape, batch_num):
@@ -399,3 +203,57 @@ def save_seg_model(model, input_shape, save_dir, data_type):
param_file = f'{save_dir}.pdiparams'
return model_file, param_file
def nearest_interpolate(features, data):
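    """Return the latency of the row in `data` whose feature vector is nearest
    (L2 distance) to `features`; the last column of `data` holds the latency."""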
def distance(x, y):
x = np.array(x)
y = np.array(y)
return np.sqrt(np.sum(np.square(x - y)))
if len(data) <= 0:
return None
data_features = data[:, 0:-1]
latency = data[:, -1]
idx = 0
dist = distance(features, data_features[0])
for i in range(1, len(data_features)):
cur_dist = distance(features, data_features[i])
if cur_dist < dist:
idx = i
dist = cur_dist
return latency[idx]
def download_predictor(op_dir, op):
    """Download an op predictor's model file.

    Args:
        op_dir(str): the download directory for the op predictor; in practice it encodes the hardware information.
        op(str): the op type.
Returns:
op_path: The path of the file.
"""
if not os.path.exists(op_dir):
os.makedirs(op_dir)
op_path = os.path.join(op_dir, op + '_predictor.pkl')
if not os.path.exists(op_path):
subprocess.call(
f'wget -P {op_dir} https://paddlemodels.bj.bcebos.com/PaddleSlim/analysis/{op_path}',
shell=True)
return op_path
def load_predictor(op_type, op_dir, data_type='fp32'):
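    """Load the pickled predictor for `op_type`, downloading it first if needed."""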
op = op_type
if 'conv2d' in op_type:
op = 'conv2d_' + data_type
elif 'matmul' in op_type:
op = 'matmul'
    op_path = download_predictor(op_dir, op)
with open(op_path, 'rb') as f:
model = pickle.load(f)
return model
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
from .parse_ops import get_key_from_op
__all__ = ["get_data_from_tables", "get_features_from_paramkey"]
def cal_flops_params(op_type, cin, cout, kernel=1, h=1, w=1):
# cin: weight[1]
if 'conv' in op_type:
params = cout * (kernel * kernel * cin + 1)
flops = 2 * kernel * kernel * h * w * cin * cout
return flops, params
elif "fc" in op_type:
flops = 2 * cin * cout
params = (cin + 1) * cout
return flops, params
def get_data_from_tables(table_dict, op_type, data_type='fp32'):
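    """Gather feature rows (with the measured latency appended as the last
    column) from the latency table for every key of the given op type."""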
data = []
for param_key in table_dict:
cur_type = param_key.split()[0]
if op_type == cur_type:
features = get_features_from_paramkey(param_key, op_type, data_type)
            if features is None:
continue
features.append(table_dict[param_key])
data.append(features)
return np.array(data)
def get_features_from_paramkey(param_key, op_type, data_type):
"""Get op's parameters according to the key of latency table
"""
features = []
if 'conv2d' in op_type:
flag_quant = 'quant=None' if data_type == 'fp32' else 'quant=True'
if flag_quant not in param_key:
return None
weight = re.search(r'weight=(\(\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
outputs = re.search(r'out=(\(-*\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
cout = int(weight[0])
cin = int(weight[1])
kernel = int(weight[2])
out_h = int(outputs[2])
out_w = int(outputs[3])
stride = int(re.search(r'stride=\d*', param_key).group().split('=')[1])
group = int(re.search(r'group=\d*', param_key).group().split('=')[1])
pad = int(re.search(r'pad=\d', param_key).group().split('=')[1])
flops, params = cal_flops_params('conv', cin, cout, kernel, out_h,
out_w)
if data_type == 'fp32':
inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
in_c = int(inputs[1])
in_h = int(inputs[2])
in_w = int(inputs[3])
features = [
in_c, cout, kernel, group, stride, pad, in_h * in_w,
out_h * out_w
]
else:
features = [
cin, cout, kernel, group, stride, pad, out_h * out_w, flops,
params
]
elif 'matmul' in op_type:
X = re.search(r'X=(\(-*\d*, \d*\))',
param_key).group().split('=')[-1].strip('('
')').split(', ')
Y = re.search(r'Y=(\(\d*, \d*\))',
param_key).group().split('=')[-1].strip('('
')').split(', ')
a = int(X[0])
b = int(Y[0])
c = int(Y[1])
flops, params = cal_flops_params('fc', b, c)
features = [b, c, flops, params]
elif ('batch_norm' in op_type or 'layer_norm' in op_type):
inputs = re.search(r'in=(\((-?\d+,* *)+\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
features = [0, 0, 0]
for i in range(1, len(inputs)):
if inputs[i] == '':
continue
features[i - 1] = int(inputs[i])
elif 'pool2d' in op_type:
inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
outputs = re.search(r'out=(\(-*\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
cin = int(inputs[1])
in_h = int(inputs[2])
in_w = int(inputs[3])
out_h = int(outputs[2])
out_w = int(outputs[3])
kernel = int(
re.search(r'kernel=\d*x*\d*', param_key).group().split('x')[-1])
flag_global = int(
re.search(r'flag_global=\d', param_key).group().split('=')[-1])
if flag_global:
kernel = in_h
stride = int(re.search(r'stride=\d', param_key).group().split('=')[-1])
pad = int(re.search(r'pad=\d', param_key).group().split('=')[-1])
flag_type = 1 if 'type=avg' in param_key else 0
features = [
cin, kernel, stride, pad, in_h * in_w, out_h * out_w, flag_type
]
elif ('reshape' in op_type or 'scale' in op_type):
inputs = re.search(r'in=(\((-?\d+,* *)+\))',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
outputs = re.search(r'out=(\((-?\d+,* *)+\))',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
        # first four entries: input dims; last four: output dims
features = [0, 0, 0, 0, 0, 0, 0, 0]
for i in range(len(inputs)):
if inputs[i] == '':
continue
features[i] = int(inputs[i])
for i in range(len(outputs)):
if outputs[i] == '':
continue
features[i + 4] = int(outputs[i])
elif ('hard_swish' in op_type or 'relu' in op_type or
'leaky_relu' in op_type or 'tanh' in op_type or 'swish' in op_type or
'softmax' in op_type or 'hard_sigmoid' in op_type or
'sigmoid' in op_type or 'gelu' in op_type or 'clip' in op_type or
'shape' in op_type or 'transpose' in op_type or
'interp_v2' in op_type):
inputs = re.search(r'in=(\((-?\d+,* *)+\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
#cin, h, w
cin = int(inputs[1])
in_h = 0
in_w = 0
if len(inputs) == 4:
in_h = int(inputs[2])
in_w = int(inputs[3])
features = [cin, in_h, in_w]
elif 'elementwise' in op_type:
X = re.search(r'X=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip('('
')').split(',')
Y = re.search(r'Y=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip('('
')').split(',')
# X[1] X[2] X[3] Y[1] Y[2] Y[3]
features = [0, 0, 0, 0, 0, 0]
for i in range(1, len(X)):
if X[i] == '':
continue
features[i - 1] = int(X[i])
for i in range(0, len(Y)):
if len(Y) == 4 and i == 0:
continue
if Y[i] == '':
continue
features[i + 2] = int(Y[i])
elif 'concat' in op_type:
inputs = re.search(r'in=(\((-?\d+,* *)+\))+',
param_key).group().split('=')[-1].strip(
'('
')').split(')(')
channels = []
for ins in inputs:
channels.append(int(ins.split(', ')[1]))
#hw, c1,c2,c3,c4,c5,c6,c7,c8,c9
features = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
input1 = inputs[0].split(', ')
if len(input1) == 3:
features[0] = int(input1[2])
else:
features[0] = int(input1[2]) * int(input1[3])
for i in range(len(channels)):
features[i + 1] = channels[i]
elif 'yolo_box' in op_type:
outputs = re.search(r'out=(\(-?\d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
inputs = re.search(r'in=(\(-?\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
cin = int(inputs[1])
h = int(inputs[2])
w = int(inputs[3])
cout = int(outputs[1])
class_num = int(
re.search(r'class_num=\d*', param_key).group().split('=')[-1])
features = [cin, h * w, cout, class_num]
elif 'prior_box' in op_type:
inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
cin = int(inputs[1])
h = int(inputs[2])
w = int(inputs[3])
features = [cin, h, w]
elif 'slice' in op_type:
inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
features = [0, 0, 0, 0]
for i in range(len(inputs)):
if inputs[i] == '':
continue
features[i] = int(inputs[i])
elif 'exp' in op_type:
inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
features = [0, 0, 0, 0]
for i in range(len(inputs)):
if inputs[i] == '':
continue
features[i] = int(inputs[i])
elif 'dropout' in param_key:
inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
features = [0, 0, 0, 0]
for i in range(len(inputs)):
if inputs[i] == '':
continue
features[i] = int(inputs[i])
elif 'fc' in op_type:
weight = re.search(r'weight=(\(\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
cin = int(weight[0])
cout = int(weight[1])
flops, params = cal_flops_params('fc', cin, cout)
features = [cin, cout, flops, params]
elif 'shuffle_channel' in op_type:
inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
cin = int(inputs[1])
in_h = int(inputs[2])
in_w = int(inputs[3])
group = int(re.search(r'group=\d*', param_key).group().split('=')[1])
features = [cin, in_h, in_w, group]
elif 'split' in op_type:
inputs = re.search(r'in=(\(-*\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
cin = int(inputs[1])
in_h = int(inputs[2])
        in_w = int(inputs[3])
features = [cin, in_h, in_w]
elif 'squeeze' in op_type:
inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
features = [0, 0, 0, 0]
for i in range(len(inputs)):
if inputs[i] == '':
continue
features[i] = int(inputs[i])
elif 'flatten_contiguous_range' in op_type:
inputs = re.search(r'in=(\(-?\d*, \d*, \d*, \d*\))',
param_key).group().split('=')[-1].strip(
'('
')').split(', ')
features = [int(inputs[1]), int(inputs[2]), int(inputs[3])]
elif ('calib' in op_type or 'floor' in op_type):
inputs = re.search(r'in=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
outputs = re.search(r'out=\((-?\d+,* *)+\)',
param_key).group().split('=')[-1].strip(
'('
')').split(',')
features = [0, 0, 0, 0, 0, 0]
for i in range(1, len(inputs)):
features[i - 1] = int(inputs[i])
for i in range(1, len(outputs)):
features[i + 2] = int(outputs[i])
elif 'uniform_random' in op_type:
shape = re.search(r'shape=\[(-?\d+,* *)+\]',
param_key).group().split('=')[-1].strip(
'['
']').split(',')
features = [0, 0, 0, 0]
for i in range(len(shape)):
if shape[i] == '':
continue
features[i] = int(shape[i])
return features
@@ -18,8 +18,9 @@ import os
import pickle
import time
import subprocess
from ._utils import get_key_from_op, save_cls_model, save_det_model, save_seg_model
from .parse_ops import get_key_from_op
from .extract_features import get_data_from_tables, get_features_from_paramkey
from ._utils import opt_model, load_predictor, nearest_interpolate
import paddle
import paddleslim
__all__ = ["LatencyPredictor", "TableLatencyPredictor"]
@@ -52,140 +53,138 @@ class TableLatencyPredictor(LatencyPredictor):
    """The predictor used to get a model's latency on some devices and infer engines.

    Args:
        table_file(str): The path of the file that records the latency of operators on the device.
    """
    def __init__(self, table_file='SD710'):
        self.table_file = table_file
        self.table_dict = {}
        self.hardware = None
        self.threads = None
        self.predictor_state = False
        self._initial_table()
def _initial_table(self):
if self.table_file in ['SD625', 'SD710', 'SD845', 'SD865']:
self.hardware = self.table_file
if self.hardware in ['SD625', 'SD710']:
self.predictor_state = True
self.threads = 4
self.table_file = f'{self.hardware}_threads_4_power_mode_0.pkl'
if not os.path.exists(self.table_file):
subprocess.call(
f'wget https://paddlemodels.bj.bcebos.com/PaddleSlim/analysis/{self.table_file}',
shell=True)
        assert os.path.exists(
            self.table_file
        ), f'{self.table_file} does not exist. If you want to use one of the provided latency tables, set table_file to one of [SD625, SD710, SD845, SD865].'
with open(self.table_file, 'rb') as f:
self.table_dict = pickle.load(f)
print('Successfully load {}'.format(self.table_file))
def _change_table(self, threads=4):
assert threads == 4, 'Only 4 threads are available now.'
self.table_file = f'{self.hardware}_threads_{threads}_power_mode_0.pkl'
if not os.path.exists(self.table_file):
subprocess.call(
f'wget https://paddlemodels.bj.bcebos.com/PaddleSlim/analysis/{self.table_file}',
shell=True)
with open(self.table_file, 'rb') as f:
self.table_dict = pickle.load(f)
print('Successfully load {}'.format(self.table_file))
def _get_input_shape(self, graph):
in_shape = []
for op in graph.ops():
param_key = get_key_from_op(op)
if param_key != '':
in_shape = op.all_inputs()[-1].shape()
break
return in_shape
def predict(self,
model_file,
param_file,
data_type,
threads=4,
input_shape=None):
"""predict the latency of the model
Args:
model: The input model graph.
input_shape(list): The input shape of model. Default: [1,3,224,224].
save_dir: Where to save the pbmodel.
data_type: Data type, fp32 or int8. Default : int8
task_type: Task type, cls, det or seg, different task models need to use different quantization strategies. Default: cls.
model_file(str), param_file(str): The inference model(*.pdmodel, *.pdiparams).
data_type(str): Data type, fp32 or int8. Default : fp32
threads(int): threads num
input_shape(list): Generally, the input shape is confirmed when saving the inference model and the parameter is only effective for variable length input shape.
Returns:
latency(float): The latency of the pbmodel.
latency(float): The latency of the model.
"""
        assert data_type in ['fp32', 'int8'
                             ], f'data_type must be one of [fp32, int8]'

        if self.hardware and self.threads != threads:
            self._change_table(threads)

        pbmodel_file = opt_model(
            model_file=model_file,
            param_file=param_file,
            optimize_out_type='protobuf', )

        paddle.enable_static()
        with open(pbmodel_file, "rb") as f:
            fluid_program = paddle.fluid.framework.Program.parse_from_string(
                f.read())
        graph = paddleslim.core.GraphWrapper(fluid_program)

        if input_shape is not None:
            ori_shape = self._get_input_shape(graph)
            assert ori_shape == input_shape, "The parameter 'input_shape' does not work for now. The input shape is fixed when saving the inference model."
latency = 0.0
for op in graph.ops():
param_key = get_key_from_op(op)
if param_key == '':
continue
if param_key in self.table_dict:
latency += self.table_dict[param_key]
elif self.predictor_state:
latency += self.op_predictor(op.type(), param_key, data_type)
else:
raise AssertionError(f'{param_key} is not in the table.')
return latency
def op_predictor(self, op_type, param_key, data_type):
"""predict the latency of the operator which is not in the table
Args:
op_type: The operator's type
param_key: The operator's parameter information.
            data_type: Data type, fp32 or int8.
Returns:
latency(float): The latency of the operator.
"""
latency = 0.0
op_dir = self.table_file.split('.')[0] + '_batchsize_1'
if op_type in [
'depthwise_conv2d', 'conv2d', 'pool2d', 'matmul',
'elementwise_add', 'elementwise_mul', 'concat', 'calib', 'swish'
]:
predictor = load_predictor(op_type, op_dir, data_type)
features = get_features_from_paramkey(param_key, op_type, data_type)
latency = predictor.predict([features])
else:
data = get_data_from_tables(
table_dict=self.table_dict,
op_type=op_type,
data_type=data_type)
features = get_features_from_paramkey(param_key, op_type, data_type)
latency = nearest_interpolate(features, data)
        assert latency is not None, f'{param_key} is not in the table.'
return latency
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["get_key_from_op"]
def get_key_from_op(op):
"""Construct key of latency table according to the info of graph's op
"""
param_key = ''
op_type = op.type()
if 'conv2d' in op_type:
out_shape = op.all_outputs()[0].shape()
in_shape = op.all_inputs()[-1].shape()
weight_shape = op.all_inputs()[-2].shape()
kernel = weight_shape[2]
stride = op.attr('strides')[1]
padding = op.attr('paddings')[1]
groups = op.attr('groups')
dilation = op.attr('dilations')[1]
int8 = op.attr('enable_int8')
bit_length = op.attr('bit_length')
param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={int8} bit_length={bit_length}'
elif op_type == 'matmul' or op_type == 'matmul_v2':
X = op.all_inputs()[0].shape()
Y = op.all_inputs()[1].shape()
out_shape = op.all_outputs()[0].shape()
int8 = op.attr('enable_int8')
bit_length = op.attr('bit_length')
param_key = f'{op_type} X={X} Y={Y} out={out_shape} quant={int8} bit_length={bit_length}'
elif 'batch_norm' in op_type or 'layer_norm' in op_type:
out_shape = op.all_outputs()[-1].shape()
in_shape = op.all_inputs()[-1].shape()
param_key = f'{op_type} in={in_shape} out={out_shape}'
elif 'pool2d' in op_type:
out_shape = op.all_outputs()[0].shape()
data = op.all_inputs()
in_shape = data[-1].shape()
kernel = op.attr('ksize')[1]
stride = op.attr('strides')[1]
padding = op.attr('paddings')[1]
groups = op.attr('groups')
flag_global = 1 if op.attr('global_pooling') else 0
if op.attr('adaptive') and out_shape[-1] == 1:
flag_global = 1
pooling_type = op.attr('pooling_type')
param_key = f'{op_type} in={in_shape} out={out_shape} stride={stride} kernel={kernel}x{kernel} pad={padding} flag_global={flag_global} type={pooling_type})'
elif op_type in [
'hard_swish', 'relu', 'leaky_relu', 'tanh', 'swish', 'softmax',
'hard_sigmoid', 'sigmoid', 'gelu', 'clip', 'shape'
] or 'transpose' in op_type or 'interp_v2' in op_type:
in_shape = op.all_inputs()[-1].shape()
param_key = f'{op_type} in={in_shape}'
elif op_type in ['fill_constant', 'range', 'cast'] or 'expand' in op_type:
param_key = f'{op_type}'
elif op_type in ['scale'] or 'reshape' in op_type:
out_shape = op.all_outputs()[0].shape()
in_shape = op.all_inputs()[0].shape()
param_key = f'{op_type} in={in_shape} out={out_shape}'
elif 'elementwise' in op_type:
out_shape = op.all_outputs()[0].shape()
x = op.all_inputs()[0].shape()
y = op.all_inputs()[1].shape()
axis = op.attr('axis')
param_key = f'{op_type} X={x} Y={y} axis={axis} out={out_shape}'
elif op_type == 'concat':
data = op.all_inputs()
X = ""
for x in data:
X += f"{x.shape()}"
axis = op.attr('axis')
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={X} axis={axis} out={out_shape}'
elif op_type == 'yolo_box':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
class_num = op.attr('class_num')
param_key = f'{op_type} in={in_shape} out={out_shape} class_num={class_num}'
elif op_type == 'prior_box':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
aspect_ratios = op.attr('aspect_ratios')
max_sizes = op.attr('max_sizes')
min_sizes = op.attr('min_sizes')
param_key = f'{op_type} in={in_shape} out={out_shape} aspect_ratios={aspect_ratios} max_sizes={max_sizes} min_sizes={min_sizes}'
elif op_type == 'slice':
in_shape = op.all_inputs()[-1].shape()
axes = op.attr('axes')
param_key = f'{op_type} in={in_shape} axes={axes}'
elif op_type == 'stack':
data = op.all_inputs()
X = "["
for x in data:
X += f"{x.shape()}"
X += "]"
axis = op.attr('axis')
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} X={X} axis={axis} out={out_shape}'
elif op_type == 'exp':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
axes = op.attr('axes')
decrease_axis = op.attr('decrease_axis')
ends = op.attr('ends')
param_key = f'{op_type} in={in_shape} out={out_shape} axes={axes} decrease_axis={decrease_axis} ends={ends}'
elif op_type in ['multiclass_nms3', 'matrix_nms']:
boxs = op.all_inputs()[0].shape()
scores = op.all_inputs()[-1].shape()
keep_top_k = op.attr('keep_top_k')
nms_top_k = op.attr('nms_top_k')
param_key = f'{op_type} boxs={boxs} scores={scores} keep_top_k={keep_top_k} nms_top_k={nms_top_k}'
elif op_type == 'dropout':
in_shape = op.all_inputs()[0].shape()
param_key = f'{op_type} in={in_shape}'
elif op_type == 'fc':
in_shape = op.all_inputs()[-2].shape()
weight_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape}'
elif op_type == 'shuffle_channel':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
group = op.attr('group')
param_key = f'{op_type} in={in_shape} group={group} out={out_shape}'
elif op_type == 'split':
in_shape = op.all_inputs()[-1].shape()
axis = op.attr('axis')
sections = op.attr('sections')
param_key = f'{op_type} in={in_shape} axis={axis} sections={sections}'
elif op_type in ['unsqueeze2', 'squeeze2']:
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
axes = op.attr('axes')
param_key = f'{op_type} in={in_shape} axes={axes} out={out_shape}'
elif op_type == 'flatten_contiguous_range':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
start_axis = op.attr('start_axis')
        stop_axis = op.attr('stop_axis')
param_key = f'{op_type} in={in_shape} start_axis={start_axis} stop_axis={stop_axis} out={out_shape}'
elif op_type == 'sum':
in_shape1 = op.all_inputs()[0].shape()
in_shape2 = op.all_inputs()[1].shape()
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={in_shape1} in={in_shape2} out={out_shape}'
elif op_type in ['calib', 'floor']:
in_shape = op.all_inputs()[-1].shape()
        out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={in_shape} out={out_shape}'
elif op_type == 'uniform_random':
shape = op.attr('shape')
param_key = f'{op_type} shape={shape}'
elif op_type == 'greater_equal':
x = op.all_inputs()[0].shape()
y = op.all_inputs()[1].shape()
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} X={x} Y={y} out={out_shape}'
elif op_type == 'reduce_mean':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
dim = op.attr('dim')
param_key = f'{op_type} in={in_shape} out={out_shape} dim={dim}'
elif 'pad3d' in op_type:
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
paddings = op.attr('paddings')
param_key = f'{op_type} in={in_shape} out={out_shape} paddings={paddings}'
elif op_type in ['feed', 'fetch']:
pass
else:
print(op)
print(op._op)
        raise KeyError(f'The op "{op_type}" has never been seen.')
return param_key
@@ -19,17 +19,7 @@ import paddleslim
from paddleslim.analysis import LatencyPredictor, TableLatencyPredictor
from paddle.vision.models import mobilenet_v1, mobilenet_v2
from paddle.nn import Conv2D, BatchNorm2D, ReLU, LayerNorm
from paddleslim.analysis._utils import opt_model, save_cls_model, save_seg_model, save_det_model
def channel_shuffle(x, groups):
@@ -108,7 +98,7 @@ class ModelCase4(paddle.nn.Layer):
x = paddle.exp(x)
y += paddle.fluid.layers.uniform_random(y.shape)
y = paddle.fluid.layers.reduce_mean(y, dim=1, keep_dim=True)
return paddle.greater_equal(x, y)
class ModelCase5(paddle.nn.Layer):
@@ -143,29 +133,78 @@ class ModelCase5(paddle.nn.Layer):
return boxes, scores, box, var, out
class ModelCase6(paddle.nn.Layer):
def __init__(self):
super(ModelCase6, self).__init__()
self.bn1 = BatchNorm2D(3)
self.relu1 = ReLU()
self.fc1 = paddle.nn.Linear(3 * 16 * 16, 3 * 16 * 16)
self.dp = paddle.nn.Dropout(p=0.5)
def forward(self, inputs):
x = self.bn1(inputs)
x = paddle.reshape(x, [1, 3 * 16 * 16])
x = self.fc1(x)
x = paddle.fluid.layers.unsqueeze(input=x, axes=[2])
x = self.relu1(x)
y = paddle.fluid.layers.fill_constant(
x.shape, dtype=paddle.float32, value=1)
x = paddle.stack([x, y], axis=3)
x = paddle.slice(x, axes=[0], starts=[0], ends=[1])
x = paddle.exp(x)
y += paddle.fluid.layers.uniform_random(y.shape)
y = paddle.expand(y, shape=[1, 768, 768, 2])
x = paddle.expand(x, shape=[1, 768, 768, 2])
out = paddle.concat([x, y])
out = self.dp(out)
out = channel_shuffle(out, 2)
out1, out2 = paddle.split(out, num_or_sections=2, axis=1)
return out1, out2
class ModelCase7(paddle.nn.Layer):
def __init__(self):
super(ModelCase7, self).__init__()
self.bn1 = BatchNorm2D(255)
def forward(self, inputs):
image = inputs['image']
image = self.bn1(image)
img_size = paddle.fluid.data(
name='img_size', shape=[None, 2], dtype='int64')
anchors = [10, 13, 16, 30, 33, 23]
boxes, scores = paddle.fluid.layers.yolo_box(
x=image,
img_size=img_size,
class_num=80,
anchors=anchors,
conf_thresh=0.01,
downsample_ratio=32)
box, var = paddle.fluid.layers.prior_box(
input=image, image=image, min_sizes=[2.], clip=True, flip=True)
return boxes, scores, box, var
class TestCase1(unittest.TestCase):
def test_case1(self):
paddle.disable_static()
model = mobilenet_v1()
        predictor = TableLatencyPredictor(table_file='SD710')

        model_file, param_file = save_cls_model(
            model,
            input_shape=[1, 3, 224, 224],
            save_dir="./inference_model",
            data_type='fp32')
        latency = predictor.predict(
            model_file=model_file, param_file=param_file, data_type='fp32')
        assert latency > 0

        model_file, param_file = save_cls_model(
            model,
            input_shape=[1, 3, 224, 224],
            save_dir="./inference_model",
            data_type='int8')
        latency = predictor.predict(
            model_file=model_file, param_file=param_file, data_type='int8')
assert latency > 0
@@ -173,25 +212,14 @@ class TestCase2(unittest.TestCase):
def test_case2(self):
paddle.disable_static()
model = mobilenet_v2()
        predictor = TableLatencyPredictor(table_file='SD710')
        model_file, param_file = save_cls_model(
            model,
            input_shape=[1, 3, 224, 224],
            save_dir="./inference_model",
            data_type='fp32')
        latency = predictor.predict(
            model_file=model_file, param_file=param_file, data_type='fp32')
assert latency > 0
@@ -199,24 +227,21 @@ class TestCase3(unittest.TestCase):
def test_case3(self):
paddle.disable_static()
model = mobilenet_v2()
        model_file, param_file = save_cls_model(
            model,
            input_shape=[1, 3, 224, 224],
            save_dir="./inference_model",
            data_type='fp32')
        pbmodel_file = opt_model(
            model_file=model_file,
            param_file=param_file,
            optimize_out_type='protobuf')
        pred = LatencyPredictor()
        paddle.enable_static()
        with open(pbmodel_file, "rb") as f:
            fluid_program = paddle.fluid.framework.Program.parse_from_string(
                f.read())
graph = paddleslim.core.GraphWrapper(fluid_program)
graph_keys = pred._get_key_info_from_graph(graph=graph)
assert len(graph_keys) > 0
@@ -226,18 +251,14 @@ class TestCase4(unittest.TestCase):
def test_case4(self):
paddle.disable_static()
model = ModelCase1()
        model_file, param_file = save_cls_model(
            model,
            input_shape=[1, 116, 28, 28],
            save_dir="./inference_model",
            data_type='fp32')
predictor = TableLatencyPredictor(table_file='SD710')
latency = predictor.predict(
model_file=model_file, param_file=param_file, data_type='fp32')
assert latency > 0
@@ -245,18 +266,14 @@ class TestCase5(unittest.TestCase):
def test_case5(self):
paddle.disable_static()
model = mobilenet_v1()
        predictor = TableLatencyPredictor(table_file='SD710')
        model_file, param_file = save_seg_model(
            model,
            input_shape=[1, 3, 224, 224],
            save_dir="./inference_model",
            data_type='fp32')
latency = predictor.predict(
model_file=model_file, param_file=param_file, data_type='fp32')
assert latency > 0
@@ -264,25 +281,14 @@ class TestCase6(unittest.TestCase):
def test_case6(self):
paddle.disable_static()
model = ModelCase2()
        predictor = TableLatencyPredictor(table_file='SD710')
        model_file, param_file = save_det_model(
            model,
            input_shape=[1, 3, 224, 224],
            save_dir="./inference_model",
            data_type='fp32')
latency = predictor.predict(
model_file=model_file, param_file=param_file, data_type='fp32')
assert latency > 0
@@ -290,19 +296,15 @@ class TestCase7(unittest.TestCase):
def test_case7(self):
paddle.disable_static()
model = ModelCase3()
        predictor = TableLatencyPredictor(table_file='SD710')
        model_file, param_file = save_det_model(
            model,
            input_shape=[1, 3, 224, 224],
            save_dir="./inference_model",
            data_type='fp32',
            det_multi_input=True)
latency = predictor.predict(
model_file=model_file, param_file=param_file, data_type='fp32')
assert latency > 0
@@ -310,23 +312,21 @@ class TestCase8(unittest.TestCase):
def test_case8(self):
paddle.disable_static()
model = ModelCase4()
        predictor = LatencyPredictor()
        model_file, param_file = save_cls_model(
            model,
            input_shape=[1, 3, 16, 16],
            save_dir="./inference_model",
            data_type='int8')
        pbmodel_file = opt_model(
            model_file=model_file,
            param_file=param_file,
            optimize_out_type='protobuf')
        paddle.enable_static()
        with open(pbmodel_file, "rb") as f:
            fluid_program = paddle.fluid.framework.Program.parse_from_string(
                f.read())
graph = paddleslim.core.GraphWrapper(fluid_program)
graph_keys = predictor._get_key_info_from_graph(graph=graph)
assert len(graph_keys) > 0
@@ -336,23 +336,21 @@ class TestCase9(unittest.TestCase):
def test_case9(self):
paddle.disable_static()
model = ModelCase5()
        predictor = LatencyPredictor()
        model_file, param_file = save_det_model(
            model,
            input_shape=[1, 255, 13, 13],
            save_dir="./inference_model",
            data_type='fp32')
        pbmodel_file = opt_model(
            model_file=model_file,
            param_file=param_file,
            optimize_out_type='protobuf')
        paddle.enable_static()
        with open(pbmodel_file, "rb") as f:
            fluid_program = paddle.fluid.framework.Program.parse_from_string(
                f.read())
graph = paddleslim.core.GraphWrapper(fluid_program)
graph_keys = predictor._get_key_info_from_graph(graph=graph)
assert len(graph_keys) > 0
@@ -362,27 +360,69 @@ class TestCase10(unittest.TestCase):
def test_case10(self):
paddle.disable_static()
model = ModelCase1()
        predictor = LatencyPredictor()
        model_file, param_file = save_seg_model(
            model,
            input_shape=[1, 116, 28, 28],
            save_dir="./inference_model",
            data_type='int8')
        pbmodel_file = opt_model(
            model_file=model_file,
            param_file=param_file,
            optimize_out_type='protobuf')
        paddle.enable_static()
        with open(pbmodel_file, "rb") as f:
            fluid_program = paddle.fluid.framework.Program.parse_from_string(
                f.read())
graph = paddleslim.core.GraphWrapper(fluid_program)
graph_keys = predictor._get_key_info_from_graph(graph=graph)
assert len(graph_keys) > 0
class TestCase11(unittest.TestCase):
def test_case11(self):
paddle.disable_static()
model = mobilenet_v2()
model2 = ModelCase6()
model3 = ModelCase7()
predictor = TableLatencyPredictor(table_file='SD710')
model_file, param_file = save_cls_model(
model,
input_shape=[1, 3, 250, 250],
save_dir="./inference_model",
data_type='fp32')
latency = predictor.predict(
model_file=model_file, param_file=param_file, data_type='fp32')
assert latency > 0
model_file, param_file = save_cls_model(
model,
input_shape=[1, 3, 250, 250],
save_dir="./inference_model",
data_type='int8')
latency = predictor.predict(
model_file=model_file, param_file=param_file, data_type='int8')
assert latency > 0
model_file, param_file = save_cls_model(
model2,
input_shape=[1, 3, 16, 16],
save_dir="./inference_model",
data_type='fp32')
latency = predictor.predict(
model_file=model_file, param_file=param_file, data_type='fp32')
assert latency > 0
model_file, param_file = save_det_model(
model3,
input_shape=[1, 255, 14, 14],
save_dir="./inference_model",
data_type='fp32')
latency = predictor.predict(
model_file=model_file, param_file=param_file, data_type='fp32')
assert latency > 0
if __name__ == '__main__':
unittest.main()