Unverified commit c7ec8584, authored by ZichaoGuo, committed by GitHub

Add latency predictor function and doc (#905)

Parent 2550fb4e
import os
import subprocess
import argparse
import paddle
from paddleslim.analysis import TableLatencyPredictor
from paddle.vision.models import mobilenet_v1, mobilenet_v2
opt_tool = 'opt_ubuntu' # use in linux
# opt_tool = 'opt_M1_mac' # use in mac with M1 chip
# opt_tool = 'opt_intel_mac' # use in mac with intel chip
parser = argparse.ArgumentParser(description='latency predictor')
parser.add_argument('--model', type=str, help='which model to test.')
parser.add_argument('--data_type', type=str, default='fp32')
args = parser.parse_args()
if not os.path.exists(opt_tool):
subprocess.call(
f'wget https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/{opt_tool}',
shell=True)
subprocess.call(f'chmod +x {opt_tool}', shell=True)
def get_latency(model, data_type):
paddle.disable_static()
predictor = TableLatencyPredictor(
f'./{opt_tool}', hardware='845', threads=4, power_mode=3, batchsize=1)
latency = predictor.predict_latency(
model,
input_shape=[1, 3, 224, 224],
save_dir='./tmp_model',
data_type=data_type,
task_type='cls')
print('{} latency : {}'.format(data_type, latency))
subprocess.call('rm -rf ./tmp_model', shell=True)
paddle.disable_static()
return latency
if __name__ == '__main__':
if args.model == 'mobilenet_v1':
model = mobilenet_v1()
elif args.model == 'mobilenet_v2':
model = mobilenet_v2()
else:
assert False, f'model should be mobilenet_v1 or mobilenet_v2'
latency = get_latency(model, args.data_type)
if args.model == 'mobilenet_v1' and args.data_type == 'fp32':
assert latency == 41.92806607483133
elif args.model == 'mobilenet_v1' and args.data_type == 'int8':
assert latency == 36.64814722993898
elif args.model == 'mobilenet_v2' and args.data_type == 'fp32':
assert latency == 27.847896889217566
elif args.model == 'mobilenet_v2' and args.data_type == 'int8':
assert latency == 23.967800360138803
else:
assert False, 'model or data_type is wrong.'
Dynamic Graph
==============
.. toctree::
:maxdepth: 1
latency_predictor.md
# LatencyPredictor Tutorial
LatencyPredictor estimates the actual inference latency of a neural network on a specific hardware device from a provided op-latency lookup table. It is built on Paddle-Lite and applies to models deployed with Paddle-Lite. The table is stored as key-value pairs: each key encodes a fused op produced by Paddle-Lite graph optimization, and the corresponding value is the measured latency of that op on the target hardware.
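For illustration, an entry of the table might look like the sketch below. The key format follows the strings built by `get_key_from_op` (see `_utils.py` below); the latency value here is made up.
```
# Hypothetical op-latency table entry (the latency value is illustrative only).
latency_table = {
    'conv2d in=[1, 3, 224, 224] weight=[32, 3, 3, 3] out=[1, 32, 112, 112] '
    'pad=1 stride=2 group=1 dilation=1 quant=False bit_length=8': 1.23,  # milliseconds
}
```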
## Usage
1. Download or build the opt optimization tool
2. Construct a LatencyPredictor
3. Define the model and predict latency
### 1. Download or build the opt optimization tool
1.1 Download a prebuilt opt tool that matches your environment. Builds are currently provided for macOS ([M1 chip](https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/opt_M1_mac), [Intel chip](https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/opt_intel_mac)) and [Ubuntu](https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/opt_ubuntu).
1.2 Alternatively, build the opt tool from the Paddle-Lite source; see the Paddle-Lite [documentation](https://paddle-lite.readthedocs.io/zh/latest/user_guides/model_optimize_tool.html). When building, disable Paddle-Lite's memory-reuse feature by commenting out [these lines](https://github.com/PaddlePaddle/Paddle-Lite/blob/d76f45be989d3e01cebf2ac18e047cfd37d52666/lite/core/optimizer/optimizer.cc#L266-L268).
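For example, on Ubuntu the prebuilt tool can be downloaded and made executable as in the sketch below, which mirrors the demo script; switch `opt_tool` to `opt_M1_mac` or `opt_intel_mac` on macOS.
```
import os
import subprocess

opt_tool = 'opt_ubuntu'  # 'opt_M1_mac' or 'opt_intel_mac' on macOS
if not os.path.exists(opt_tool):
    # Fetch the prebuilt opt binary and make it executable.
    subprocess.call(
        f'wget https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/{opt_tool}',
        shell=True)
    subprocess.call(f'chmod +x {opt_tool}', shell=True)
```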
### 2. Construct a LatencyPredictor
Provide the path to the opt tool together with the chip and benchmark settings; LatencyPredictor automatically downloads the op-latency table that matches these parameters. In the example below, the target hardware is the '845' chip, the thread count `threads` is 4, the power mode `power_mode` is 3, and the batch size `batchsize` is 1.
```
import paddleslim
opt_path = {path to the opt tool}
predictor = paddleslim.TableLatencyPredictor(opt_path, hardware='845', threads=4, power_mode=3, batchsize=1)
```
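The file name of the lookup table is derived from these parameters (for the settings above, `845_threads_4_power_mode_3_batchsize_1.pkl`), and the predictor downloads it automatically if it is not present locally.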
### 3. Define the model and predict latency
After defining the model, call `predict_latency` to predict its inference latency directly. Here `input_shape` is the input size, `save_dir` is the path where the intermediate pbmodel is saved, `data_type` is either fp32 or int8, and `task_type='cls'` indicates a classification model.
```
import paddle
from paddle.vision.models import mobilenet_v1
model = mobilenet_v1()
latency = predictor.predict_latency(model, input_shape=[1,3,224,224], save_dir='./model', data_type='int8', task_type='cls')
print('predicted latency = {}ms'.format(latency))
```
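For detection or segmentation models, pass `task_type='det'` or `task_type='seg'`. A detection model that takes several inputs (image, im_shape, scale_factor) additionally needs `set_det_multi_input(True)`. A minimal sketch, assuming the `predictor` from step 2 is still in scope and using a toy multi-input model similar to the one in the unit tests:
```
import paddle
from paddle.nn import Conv2D

class ToyDetModel(paddle.nn.Layer):
    """Toy detection-style model with image, im_shape and scale_factor inputs."""
    def __init__(self):
        super(ToyDetModel, self).__init__()
        self.conv1 = Conv2D(3, 24, 3, stride=2, padding=1)

    def forward(self, inputs):
        return self.conv1(inputs['image']), inputs['im_shape'], inputs['scale_factor']

model = ToyDetModel()
predictor.set_det_multi_input(True)  # the model takes multiple inputs
latency = predictor.predict_latency(model, input_shape=[1, 3, 224, 224], save_dir='./det_model', data_type='fp32', task_type='det')
print('predicted latency = {}ms'.format(latency))
```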
@@ -14,8 +14,19 @@
from .flops import flops, dygraph_flops
from .model_size import model_size
from .latency import LatencyEvaluator, TableLatencyEvaluator
from .latency_predictor import LatencyPredictor, TableLatencyPredictor
from ._utils import get_key_from_op, save_cls_model, save_det_model, save_seg_model
__all__ = [
    'flops',
    'dygraph_flops',
    'model_size',
    'LatencyEvaluator',
    'TableLatencyEvaluator',
    "LatencyPredictor",
    "TableLatencyPredictor",
    "get_key_from_op",
    "save_cls_model",
    "save_det_model",
    "save_seg_model",
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddleslim
__all__ = [
"get_key_from_op", "save_cls_model", "save_det_model", "save_seg_model"
]
def get_key_from_op(op):
"""Construct key of latency table according to the info of graph's op
"""
param_key = ''
op_type = op.type()
if 'conv2d' in op_type:
out_shape = op.all_outputs()[0].shape()
in_shape = op.all_inputs()[-1].shape()
weight_shape = op.all_inputs()[-2].shape()
kernel = weight_shape[2]
stride = op.attr('strides')[1]
padding = op.attr('paddings')[1]
groups = op.attr('groups')
dilation = op.attr('dilations')[1]
int8 = op.attr('enable_int8')
bit_length = op.attr('bit_length')
param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={int8} bit_length={bit_length}'
elif op_type == 'matmul' or op_type == 'matmul_v2':
X = op.all_inputs()[0].shape()
Y = op.all_inputs()[1].shape()
out_shape = op.all_outputs()[0].shape()
int8 = op.attr('enable_int8')
bit_length = op.attr('bit_length')
param_key = f'{op_type} X={X} Y={Y} out={out_shape} quant={int8} bit_length={bit_length}'
elif 'batch_norm' in op_type or 'layer_norm' in op_type:
out_shape = op.all_outputs()[-1].shape()
in_shape = op.all_inputs()[-1].shape()
param_key = f'{op_type} in={in_shape} out={out_shape}'
elif 'pool2d' in op_type:
out_shape = op.all_outputs()[0].shape()
data = op.all_inputs()
in_shape = data[-1].shape()
kernel = op.attr('ksize')[1]
stride = op.attr('strides')[1]
padding = op.attr('paddings')[1]
groups = op.attr('groups')
flag_global = 1 if op.attr('global_pooling') else 0
if op.attr('adaptive') and out_shape[-1] == 1:
flag_global = 1
pooling_type = op.attr('pooling_type')
param_key = f'{op_type} in={in_shape} out={out_shape} stride={stride} kernel={kernel}x{kernel} pad={padding} flag_global={flag_global} type={pooling_type})'
elif op_type in [
'hard_swish', 'relu', 'leaky_relu', 'tanh', 'swish', 'softmax',
'hard_sigmoid', 'sigmoid', 'gelu', 'clip', 'shape'
] or 'transpose' in op_type or 'interp_v2' in op_type:
in_shape = op.all_inputs()[-1].shape()
param_key = f'{op_type} in={in_shape}'
elif op_type in ['fill_constant', 'range', 'cast'] or 'expand' in op_type:
param_key = f'{op_type}'
elif op_type in ['scale'] or 'reshape' in op_type:
out_shape = op.all_outputs()[0].shape()
in_shape = op.all_inputs()[0].shape()
param_key = f'{op_type} in={in_shape} out={out_shape}'
elif 'elementwise' in op_type:
out_shape = op.all_outputs()[0].shape()
x = op.all_inputs()[0].shape()
y = op.all_inputs()[1].shape()
axis = op.attr('axis')
param_key = f'{op_type} X={x} Y={y} axis={axis} out={out_shape}'
elif op_type == 'concat':
data = op.all_inputs()
X = ""
for x in data:
X += f"{x.shape()}"
axis = op.attr('axis')
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={X} axis={axis} out={out_shape}'
elif op_type == 'yolo_box':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
class_num = op.attr('class_num')
param_key = f'{op_type} in={in_shape} out={out_shape} class_num={class_num}'
elif op_type == 'prior_box':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
aspect_ratios = op.attr('aspect_ratios')
max_sizes = op.attr('max_sizes')
min_sizes = op.attr('min_sizes')
param_key = f'{op_type} in={in_shape} out={out_shape} aspect_ratios={aspect_ratios} max_sizes={max_sizes} min_sizes={min_sizes}'
elif op_type == 'slice':
in_shape = op.all_inputs()[-1].shape()
axes = op.attr('axes')
param_key = f'{op_type} in={in_shape} axes={axes}'
elif op_type == 'stack':
data = op.all_inputs()
X = "["
for x in data:
X += f"{x.shape()}"
X += "]"
axis = op.attr('axis')
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} X={X} axis={axis} out={out_shape}'
elif op_type == 'exp':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
axes = op.attr('axes')
decrease_axis = op.attr('decrease_axis')
ends = op.attr('ends')
param_key = f'{op_type} in={in_shape} out={out_shape} axes={axes} decrease_axis={decrease_axis} ends={ends}'
elif op_type in ['multiclass_nms3', 'matrix_nms']:
boxs = op.all_inputs()[0].shape()
scores = op.all_inputs()[-1].shape()
keep_top_k = op.attr('keep_top_k')
nms_top_k = op.attr('nms_top_k')
param_key = f'{op_type} boxs={boxs} scores={scores} keep_top_k={keep_top_k} nms_top_k={nms_top_k}'
elif op_type == 'dropout':
in_shape = op.all_inputs()[0].shape()
param_key = f'{op_type} in={in_shape}'
elif op_type == 'fc':
in_shape = op.all_inputs()[-2].shape()
weight_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape}'
elif op_type == 'shuffle_channel':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
group = op.attr('group')
param_key = f'{op_type} in={in_shape} group={group} out={out_shape}'
elif op_type == 'split':
in_shape = op.all_inputs()[-1].shape()
axis = op.attr('axis')
sections = op.attr('sections')
param_key = f'{op_type} in={in_shape} axis={axis} sections={sections}'
elif op_type in ['unsqueeze2', 'squeeze2']:
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
axes = op.attr('axes')
param_key = f'{op_type} in={in_shape} axes={axes} out={out_shape}'
elif op_type == 'flatten_contiguous_range':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
start_axis = op.attr('start_axis')
stop_axis = op.attr('stop_axis')
param_key = f'{op_type} in={in_shape} start_axis={start_axis} stop_axis={stop_axis} out={out_shape}'
elif op_type == 'sum':
in_shape1 = op.all_inputs()[0].shape()
in_shape2 = op.all_inputs()[1].shape()
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={in_shape1} in={in_shape2} out={out_shape}'
elif op_type in ['calib', 'floor']:
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} in={in_shape} out={out_shape}'
elif op_type == 'uniform_random':
shape = op.attr('shape')
param_key = f'{op_type} shape={shape}'
elif op_type == 'greater_equal':
x = op.all_inputs()[0].shape()
y = op.all_inputs()[1].shape()
out_shape = op.all_outputs()[0].shape()
param_key = f'{op_type} X={x} Y={y} out={out_shape}'
elif op_type == 'reduce_mean':
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
dim = op.attr('dim')
param_key = f'{op_type} in={in_shape} out={out_shape} dim={dim}'
elif 'pad3d' in op_type:
in_shape = op.all_inputs()[-1].shape()
out_shape = op.all_outputs()[0].shape()
paddings = op.attr('paddings')
param_key = f'{op_type} in={in_shape} out={out_shape} paddings={paddings}'
elif op_type in ['feed', 'fetch']:
pass
else:
print(op)
print(op._op)
raise KeyError(f'The op "{op_type}" has never been seen.')
return param_key
def sample_generator(input_shape, batch_num):
def __reader__():
for i in range(batch_num):
image = np.random.random(input_shape).astype('float32')
yield image
return __reader__
def save_cls_model(model, input_shape, save_dir, data_type):
paddle.jit.save(
model,
path=os.path.join(save_dir, 'fp32model'),
input_spec=[
paddle.static.InputSpec(
shape=input_shape, dtype='float32', name='x'),
])
model_file = os.path.join(save_dir, 'fp32model.pdmodel')
param_file = os.path.join(save_dir, 'fp32model.pdiparams')
if data_type == 'int8':
paddle.enable_static()
exe = paddle.fluid.Executor(paddle.fluid.CPUPlace())
save_dir = os.path.dirname(model_file)
quantize_model_path = os.path.join(save_dir, 'int8model')
if not os.path.exists(quantize_model_path):
os.makedirs(quantize_model_path)
paddleslim.quant.quant_post_static(
executor=exe,
model_dir=save_dir,
quantize_model_path=quantize_model_path,
sample_generator=sample_generator(input_shape, 1),
model_filename=model_file.split('/')[-1],
params_filename=param_file.split('/')[-1],
batch_size=input_shape[0],
batch_nums=1,
weight_bits=8,
activation_bits=8)
model_file = os.path.join(quantize_model_path, '__model__')
param_file = os.path.join(quantize_model_path, '__params__')
return model_file, param_file
def save_det_model(model,
input_shape,
save_dir,
data_type,
det_multi_input=False):
model.eval()
if det_multi_input:
input_spec = [{
"image": paddle.static.InputSpec(
shape=input_shape, name='image'),
"im_shape": paddle.static.InputSpec(
shape=[input_shape[0], 2], name='im_shape'),
"scale_factor": paddle.static.InputSpec(
shape=[input_shape[0], 2], name='scale_factor')
}]
data = {
"image": paddle.randn(
shape=input_shape, dtype='float32', name='image'),
"im_shape": paddle.randn(
shape=[input_shape[0], 2], dtype='float32', name='image'),
"scale_factor": paddle.ones(
shape=[input_shape[0], 2], dtype='float32', name='image')
}
else:
input_spec = [{
"image": paddle.static.InputSpec(
shape=input_shape, name='image'),
}]
data = {
"image": paddle.randn(
shape=input_shape, dtype='float32', name='image'),
}
if data_type == 'fp32':
static_model = paddle.jit.to_static(model, input_spec=input_spec)
paddle.jit.save(
static_model,
path=os.path.join(save_dir, 'fp32model'),
input_spec=input_spec)
model_file = os.path.join(save_dir, 'fp32model.pdmodel')
param_file = os.path.join(save_dir, 'fp32model.pdiparams')
else:
ptq = paddleslim.dygraph.quant.PTQ()
quant_model = ptq.quantize(model, fuse=True, fuse_list=None)
quant_model(data)
quantize_model_path = os.path.join(save_dir, 'int8model')
if not os.path.exists(quantize_model_path):
os.makedirs(quantize_model_path)
ptq.save_quantized_model(quant_model,
os.path.join(quantize_model_path, 'int8model'),
input_spec)
model_file = os.path.join(quantize_model_path, 'int8model.pdmodel')
param_file = os.path.join(quantize_model_path, 'int8model.pdiparams')
return model_file, param_file
def save_seg_model(model, input_shape, save_dir, data_type):
if data_type == 'fp32':
paddle.jit.save(
model,
path=os.path.join(save_dir, 'fp32model'),
input_spec=[
paddle.static.InputSpec(
shape=input_shape, dtype='float32', name='x'),
])
model_file = os.path.join(save_dir, 'fp32model.pdmodel')
param_file = os.path.join(save_dir, 'fp32model.pdiparams')
else:
save_dir = os.path.join(save_dir, 'int8model')
quant_config = {
'weight_preprocess_type': None,
'activation_preprocess_type': None,
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'weight_bits': 8,
'activation_bits': 8,
'dtype': 'int8',
'window_size': 10000,
'moving_rate': 0.9,
'quantizable_layer_type': ['Conv2D', 'Linear'],
}
quantizer = paddleslim.QAT(config=quant_config)
quantizer.quantize(model)
quantizer.save_quantized_model(
model,
save_dir,
input_spec=[
paddle.static.InputSpec(
shape=input_shape, dtype='float32')
])
model_file = f'{save_dir}.pdmodel'
param_file = f'{save_dir}.pdiparams'
return model_file, param_file
"""Define latency predictor that predict the latency of model on devices.
"""
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import time
import subprocess
from ._utils import get_key_from_op, save_cls_model, save_det_model, save_seg_model
import paddle
import paddleslim
__all__ = ["LatencyPredictor", "TableLatencyPredictor"]
class LatencyPredictor(object):
"""Base class of latency predictor.
"""
def predict_latency(self, model):
"""Get latency of model. It is an abstract method.
Args:
model: The model to be evaluated.
Returns:
latency(float): The latency of given model on current evaluator.
"""
raise NotImplementedError('Abstract method.')
def _get_key_info_from_graph(self, graph):
graph_keys = []
for op in graph.ops():
param_key = get_key_from_op(op)
graph_keys.append(param_key)
return graph_keys
class TableLatencyPredictor(LatencyPredictor):
"""The preditor used to get pbmodel's latency on some devices and infer engines.
Args:
table_file(str): The path of file that records the devices latency of operators.
opt_path(str): The path of opt tool to convert a paddle model to an optimized pbmodel that fuses operators.
"""
def __init__(self,
opt_path,
hardware='845',
threads=4,
power_mode=3,
batchsize=1):
self.table_file = f'{hardware}_threads_{threads}_power_mode_{power_mode}_batchsize_{batchsize}.pkl'
self.opt_path = opt_path
self.table_dict = {}
self._read_table()
self.det_multi_input = False
def _read_table(self):
if not os.path.exists(self.table_file):
subprocess.call(
f'wget https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/{self.table_file}',
shell=True)
assert os.path.exists(
self.table_file), f'{self.table_file} does not exist.'
with open(self.table_file, 'rb') as f:
self.table_dict = pickle.load(f)
print('Successfully loaded {}'.format(self.table_file))
def set_det_multi_input(self, det_multi_input):
"""If a detection model has multiple input, the self.det_multi_input should be True. Default: False.
"""
self.det_multi_input = det_multi_input
def opt_model(self, model, input_shape, save_dir, data_type, task_type):
"""Convert the model graph to an optimized pbmodel by using opt tool.
Args:
model: The input model graph.
input_shape(list): The input shape of model.
save_dir: Where to save the pbmodel.
data_type: Data type, fp32 or int8.
task_type: Task type ('cls', 'det' or 'seg'); models for different tasks use different quantization strategies.
Returns:
pbmodel_file: The path of optimized pbmodel.
"""
if task_type == 'cls':
model_file, param_file = save_cls_model(
model=model,
input_shape=input_shape,
save_dir=save_dir,
data_type=data_type)
elif task_type == 'det':
model_file, param_file = save_det_model(
model=model,
input_shape=input_shape,
save_dir=save_dir,
data_type=data_type,
det_multi_input=self.det_multi_input)
elif task_type == 'seg':
model_file, param_file = save_seg_model(
model=model,
input_shape=input_shape,
save_dir=save_dir,
data_type=data_type)
else:
assert task_type in ['cls', 'det', 'seg'
], f'task_type must be one of [cls, det, seg]'
pb_model = os.path.join(save_dir, f'{data_type}pbmodel')
if not os.path.exists(pb_model):
os.makedirs(pb_model)
cmd = f'{self.opt_path} --model_file={model_file} --param_file={param_file} --optimize_out_type=protobuf --optimize_out={pb_model} --valid_targets=arm'
print(f'commands:{cmd}')
m = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
out = m.communicate()
print(out, 'opt done!')
pbmodel_file = os.path.join(pb_model, 'model')
return pbmodel_file
def predict_latency(self,
model,
input_shape=[1, 3, 224, 224],
save_dir='',
data_type='int8',
task_type='cls'):
"""predict the latency of the model
Args:
model: The input model graph.
input_shape(list): The input shape of model. Default: [1,3,224,224].
save_dir: Where to save the pbmodel.
data_type: Data type, fp32 or int8. Default: int8.
task_type: Task type ('cls', 'det' or 'seg'); models for different tasks use different quantization strategies. Default: cls.
Returns:
latency(float): The latency of the pbmodel.
"""
assert data_type in ['fp32', 'int8'
], f'data_type must be one of [fp32, int8]'
assert task_type in ['cls', 'det', 'seg'
], f'task_type must be one of [cls, det, seg]'
if not os.path.exists(save_dir):
os.makedirs(save_dir)
pbmodel_file = self.opt_model(
model=model,
input_shape=input_shape,
save_dir=save_dir,
data_type=data_type,
task_type=task_type)
paddle.enable_static()
with open(pbmodel_file, "rb") as f:
program_desc_str = f.read()
program = paddle.fluid.proto.framework_pb2.ProgramDesc.FromString(
program_desc_str)
fluid_program = paddle.fluid.framework.Program.parse_from_string(
program_desc_str)
graph = paddleslim.core.GraphWrapper(fluid_program)
latency = 0.0
for op in graph.ops():
param_key = get_key_from_op(op)
if param_key != '':
assert param_key in self.table_dict, f'{param_key} is not in the table.'
latency += self.table_dict[param_key]
return latency
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys, os
sys.path.append("../")
import unittest
import paddle
import paddleslim
from paddleslim.analysis import LatencyPredictor, TableLatencyPredictor
from paddle.vision.models import mobilenet_v1, mobilenet_v2
from paddle.nn import Conv2D, BatchNorm2D, ReLU, LayerNorm
import subprocess
opt_tool = 'opt_ubuntu' # use in linux
# opt_tool = 'opt_M1_mac' # use in mac with M1 chip
# opt_tool = 'opt_intel_mac' # use in mac with intel chip
if not os.path.exists(opt_tool):
subprocess.call(
f'wget https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/{opt_tool}',
shell=True)
subprocess.call(f'chmod +x {opt_tool}', shell=True)
def channel_shuffle(x, groups):
batch_size, num_channels, height, width = x.shape[0:4]
channels_per_group = num_channels // groups
x = paddle.reshape(
x=x, shape=[batch_size, groups, channels_per_group, height, width])
x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width])
return x
class ModelCase1(paddle.nn.Layer):
def __init__(self):
super(ModelCase1, self).__init__()
self.conv1 = Conv2D(58, 58, 1)
self.conv2 = Conv2D(58, 58, 1)
def forward(self, inputs):
x1, x2 = paddle.split(
inputs,
num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2],
axis=1)
x1 = self.conv1(x1)
x2 = self.conv2(x2)
out = paddle.concat([x1, x2], axis=1)
return channel_shuffle(out, 2)
class ModelCase2(paddle.nn.Layer):
def __init__(self):
super(ModelCase2, self).__init__()
self.conv1 = Conv2D(3, 24, 3, stride=2, padding=1)
def forward(self, inputs):
image = inputs['image']
return self.conv1(image)
class ModelCase3(paddle.nn.Layer):
def __init__(self):
super(ModelCase3, self).__init__()
self.conv1 = Conv2D(3, 24, 3, stride=2, padding=1)
def forward(self, inputs):
image = inputs['image']
im_shape = inputs['im_shape']
scale_factor = inputs['scale_factor']
return self.conv1(image), im_shape, scale_factor
class ModelCase4(paddle.nn.Layer):
def __init__(self):
super(ModelCase4, self).__init__()
self.bn1 = BatchNorm2D(3)
self.ln1 = LayerNorm([3 * 16 * 16])
self.relu1 = ReLU()
self.fc1 = paddle.nn.Linear(3 * 16 * 16, 3 * 16 * 16)
def forward(self, inputs):
x = self.bn1(inputs)
x = paddle.reshape(x, [1, 3 * 16 * 16])
x = self.ln1(x)
x = self.fc1(x)
x = paddle.fluid.layers.unsqueeze(input=x, axes=[2])
x = self.relu1(x)
y = paddle.fluid.layers.fill_constant(
x.shape, dtype=paddle.float32, value=1)
x = paddle.stack([x, y], axis=3)
x = paddle.slice(x, axes=[0], starts=[0], ends=[1])
x = paddle.exp(x)
y += paddle.fluid.layers.uniform_random(y.shape)
y = paddle.fluid.layers.reduce_mean(y, dim=1, keep_dim=True)
return x + y
class ModelCase5(paddle.nn.Layer):
def __init__(self):
super(ModelCase5, self).__init__()
self.bn1 = BatchNorm2D(255)
def forward(self, inputs):
image = inputs['image']
image = self.bn1(image)
img_size = paddle.fluid.data(
name='img_size', shape=[None, 2], dtype='int64')
anchors = [10, 13, 16, 30, 33, 23]
boxes, scores = paddle.fluid.layers.yolo_box(
x=image,
img_size=img_size,
class_num=80,
anchors=anchors,
conf_thresh=0.01,
downsample_ratio=32)
out = paddle.fluid.layers.matrix_nms(
bboxes=boxes,
scores=scores,
background_label=0,
score_threshold=0.5,
post_threshold=0.1,
nms_top_k=400,
keep_top_k=200,
normalized=False)
box, var = paddle.fluid.layers.prior_box(
input=image, image=image, min_sizes=[2.], clip=True, flip=True)
return boxes, scores, box, var, out
class TestCase1(unittest.TestCase):
def test_case1(self):
paddle.disable_static()
model = mobilenet_v1()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
latency = predictor.predict_latency(
model,
input_shape=[1, 3, 224, 224],
save_dir='./model',
data_type='fp32',
task_type='cls')
assert latency > 0
latency = predictor.predict_latency(
model,
input_shape=[1, 3, 224, 224],
save_dir='./model',
data_type='int8',
task_type='cls')
assert latency > 0
class TestCase2(unittest.TestCase):
def test_case2(self):
paddle.disable_static()
model = mobilenet_v2()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
latency = predictor.predict_latency(
model,
input_shape=[1, 3, 224, 224],
save_dir='./model',
data_type='fp32',
task_type='cls')
assert latency > 0
latency = predictor.predict_latency(
model,
input_shape=[1, 3, 224, 224],
save_dir='./model',
data_type='int8',
task_type='cls')
assert latency > 0
class TestCase3(unittest.TestCase):
def test_case3(self):
paddle.disable_static()
model = mobilenet_v2()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
pred = LatencyPredictor()
pbmodel_file = predictor.opt_model(
model,
input_shape=[1, 3, 224, 224],
save_dir='./model',
data_type='fp32',
task_type='cls')
paddle.enable_static()
with open(pbmodel_file, "rb") as f:
program_desc_str = f.read()
fluid_program = paddle.fluid.framework.Program.parse_from_string(
program_desc_str)
graph = paddleslim.core.GraphWrapper(fluid_program)
graph_keys = pred._get_key_info_from_graph(graph=graph)
assert len(graph_keys) > 0
class TestCase4(unittest.TestCase):
def test_case4(self):
paddle.disable_static()
model = ModelCase1()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
latency = predictor.predict_latency(
model,
input_shape=[1, 116, 28, 28],
save_dir='./model',
data_type='fp32',
task_type='cls')
assert latency > 0
class TestCase5(unittest.TestCase):
def test_case5(self):
paddle.disable_static()
model = mobilenet_v1()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
latency = predictor.predict_latency(
model,
input_shape=[1, 3, 224, 224],
save_dir='./model',
data_type='fp32',
task_type='seg')
assert latency > 0
class TestCase6(unittest.TestCase):
def test_case6(self):
paddle.disable_static()
model = ModelCase2()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
pbmodel_file = predictor.opt_model(
model,
input_shape=[1, 3, 224, 224],
save_dir='./model',
data_type='int8',
task_type='det')
assert os.path.exists(pbmodel_file)
latency = predictor.predict_latency(
model,
input_shape=[1, 3, 224, 224],
save_dir='./model',
data_type='fp32',
task_type='det')
assert latency > 0
class TestCase7(unittest.TestCase):
def test_case7(self):
paddle.disable_static()
model = ModelCase3()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
predictor.set_det_multi_input(det_multi_input=True)
latency = predictor.predict_latency(
model,
input_shape=[1, 3, 224, 224],
save_dir='./model',
data_type='fp32',
task_type='det')
assert latency > 0
class TestCase8(unittest.TestCase):
def test_case8(self):
paddle.disable_static()
model = ModelCase4()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
pbmodel_file = predictor.opt_model(
model,
input_shape=[1, 3, 16, 16],
save_dir='./model',
data_type='int8',
task_type='cls')
paddle.enable_static()
with open(pbmodel_file, "rb") as f:
program_desc_str = f.read()
fluid_program = paddle.fluid.framework.Program.parse_from_string(
program_desc_str)
graph = paddleslim.core.GraphWrapper(fluid_program)
graph_keys = predictor._get_key_info_from_graph(graph=graph)
assert len(graph_keys) > 0
class TestCase9(unittest.TestCase):
def test_case9(self):
paddle.disable_static()
model = ModelCase5()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
pbmodel_file = predictor.opt_model(
model,
input_shape=[1, 255, 13, 13],
save_dir='./model',
data_type='fp32',
task_type='det')
paddle.enable_static()
with open(pbmodel_file, "rb") as f:
program_desc_str = f.read()
fluid_program = paddle.fluid.framework.Program.parse_from_string(
program_desc_str)
graph = paddleslim.core.GraphWrapper(fluid_program)
graph_keys = predictor._get_key_info_from_graph(graph=graph)
assert len(graph_keys) > 0
class TestCase10(unittest.TestCase):
def test_case10(self):
paddle.disable_static()
model = ModelCase1()
predictor = TableLatencyPredictor(
f'./{opt_tool}',
hardware='845',
threads=4,
power_mode=3,
batchsize=1)
pbmodel_file = predictor.opt_model(
model,
input_shape=[1, 116, 28, 28],
save_dir='./model',
data_type='int8',
task_type='seg')
paddle.enable_static()
with open(pbmodel_file, "rb") as f:
program_desc_str = f.read()
fluid_program = paddle.fluid.framework.Program.parse_from_string(
program_desc_str)
graph = paddleslim.core.GraphWrapper(fluid_program)
graph_keys = predictor._get_key_info_from_graph(graph=graph)
assert len(graph_keys) > 0
if __name__ == '__main__':
unittest.main()