From c7ec8584d342172e25ba738c8994139398482204 Mon Sep 17 00:00:00 2001
From: ZichaoGuo <757855223@qq.com>
Date: Tue, 19 Oct 2021 13:58:24 +0800
Subject: [PATCH] Add latency predictor function and doc (#905)

---
 demo/analysis/latency_predictor.py            |  63 +++
 .../tutorials/analysis/dygraph/index.rst      |   8 +
 .../analysis/dygraph/latency_predictor.md     |  35 ++
 paddleslim/analysis/__init__.py               |  15 +-
 paddleslim/analysis/_utils.py                 | 401 ++++++++++++++++++
 paddleslim/analysis/latency_predictor.py      | 191 +++++++++
 tests/test_latency_predictor.py               | 388 +++++++++++++++++
 7 files changed, 1099 insertions(+), 2 deletions(-)
 create mode 100644 demo/analysis/latency_predictor.py
 create mode 100644 docs/zh_cn/tutorials/analysis/dygraph/index.rst
 create mode 100644 docs/zh_cn/tutorials/analysis/dygraph/latency_predictor.md
 create mode 100644 paddleslim/analysis/_utils.py
 create mode 100644 paddleslim/analysis/latency_predictor.py
 create mode 100644 tests/test_latency_predictor.py

diff --git a/demo/analysis/latency_predictor.py b/demo/analysis/latency_predictor.py
new file mode 100644
index 00000000..fbcb3d7d
--- /dev/null
+++ b/demo/analysis/latency_predictor.py
@@ -0,0 +1,63 @@
+import os
+import subprocess
+import argparse
+
+import paddle
+from paddleslim.analysis import TableLatencyPredictor
+
+from paddle.vision.models import mobilenet_v1, mobilenet_v2
+
+opt_tool = 'opt_ubuntu'  # use on Linux
+# opt_tool = 'opt_M1_mac'  # use on a Mac with an M1 chip
+# opt_tool = 'opt_intel_mac'  # use on a Mac with an Intel chip
+
+parser = argparse.ArgumentParser(description='latency predictor')
+parser.add_argument('--model', type=str, help='which model to test.')
+parser.add_argument('--data_type', type=str, default='fp32')
+
+args = parser.parse_args()
+
+if not os.path.exists(opt_tool):
+    subprocess.call(
+        f'wget https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/{opt_tool}',
+        shell=True)
+    subprocess.call(f'chmod +x {opt_tool}', shell=True)
+
+
+def get_latency(model, data_type):
+    paddle.disable_static()
+    predictor = TableLatencyPredictor(
+        f'./{opt_tool}', hardware='845', threads=4, power_mode=3, batchsize=1)
+    latency = predictor.predict_latency(
+        model,
+        input_shape=[1, 3, 224, 224],
+        save_dir='./tmp_model',
+        data_type=data_type,
+        task_type='cls')
+    print('{} latency : {}'.format(data_type, latency))
+
+    subprocess.call('rm -rf ./tmp_model', shell=True)
+    # predict_latency enables static mode internally; switch back to dynamic mode.
+    paddle.disable_static()
+    return latency
+
+
+if __name__ == '__main__':
+    if args.model == 'mobilenet_v1':
+        model = mobilenet_v1()
+    elif args.model == 'mobilenet_v2':
+        model = mobilenet_v2()
+    else:
+        assert False, 'model should be mobilenet_v1 or mobilenet_v2'
+
+    latency = get_latency(model, args.data_type)
+
+    if args.model == 'mobilenet_v1' and args.data_type == 'fp32':
+        assert latency == 41.92806607483133
+    elif args.model == 'mobilenet_v1' and args.data_type == 'int8':
+        assert latency == 36.64814722993898
+    elif args.model == 'mobilenet_v2' and args.data_type == 'fp32':
+        assert latency == 27.847896889217566
+    elif args.model == 'mobilenet_v2' and args.data_type == 'int8':
+        assert latency == 23.967800360138803
+    else:
+        assert False, 'model or data_type is wrong.'
diff --git a/docs/zh_cn/tutorials/analysis/dygraph/index.rst b/docs/zh_cn/tutorials/analysis/dygraph/index.rst
new file mode 100644
index 00000000..f0220585
--- /dev/null
+++ b/docs/zh_cn/tutorials/analysis/dygraph/index.rst
@@ -0,0 +1,8 @@
+
+Dynamic Graph
+==============
+
+.. toctree::
+   :maxdepth: 1
+
+   latency_predictor.md
diff --git a/docs/zh_cn/tutorials/analysis/dygraph/latency_predictor.md b/docs/zh_cn/tutorials/analysis/dygraph/latency_predictor.md
new file mode 100644
index 00000000..4c6aef40
--- /dev/null
+++ b/docs/zh_cn/tutorials/analysis/dygraph/latency_predictor.md
@@ -0,0 +1,35 @@
+# LatencyPredictor Usage Tutorial
+
+LatencyPredictor estimates the actual inference latency of a neural network on a specific hardware device from a given op-latency lookup table. It is built on Paddle-Lite and applies to models deployed with Paddle-Lite. The table is stored as key-value pairs: each key encodes a fused op as it appears after Paddle-Lite graph optimization, and the value is that op's measured latency on the target hardware.
+
+## Usage
+
+1. Download or build the opt optimization tool
+2. Construct a LatencyPredictor
+3. Define a model and predict its latency
+
+### 1. Download or build the opt optimization tool
+1.1 Download a prebuilt opt binary that matches your environment. Binaries are currently provided for Mac ([M1 chip](https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/opt_M1_mac), [Intel chip](https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/opt_intel_mac)) and [Ubuntu](https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/opt_ubuntu).
+1.2 Alternatively, build the opt tool from Paddle-Lite source; see the Paddle-Lite [documentation](https://paddle-lite.readthedocs.io/zh/latest/user_guides/model_optimize_tool.html). When building, disable Paddle-Lite's memory-reuse optimization by commenting out [these lines](https://github.com/PaddlePaddle/Paddle-Lite/blob/d76f45be989d3e01cebf2ac18e047cfd37d52666/lite/core/optimizer/optimizer.cc#L266-L268).
+
+### 2. Construct a LatencyPredictor
+
+Given the path of the opt tool plus the chip and benchmark settings, LatencyPredictor automatically downloads the matching latency table. In the example below, the chip is the 845, the number of threads is 4, the power mode (power_mode) is 3, and the batch size is 1.
+```
+import paddleslim
+
+opt_path = {path to the opt tool}
+predictor = paddleslim.TableLatencyPredictor(opt_path, hardware='845', threads=4, power_mode=3, batchsize=1)
+```
+
+### 3. Define a model and predict its latency
+
+After defining a model, call the predict_latency function to estimate its inference latency directly. Here, input_shape is the input size, save_dir is where the intermediate pbmodel is saved, data_type is either fp32 or int8, and task_type='cls' marks the model as a classification model.
+```
+import paddle
+from paddle.vision.models import mobilenet_v1
+
+model = mobilenet_v1()
+latency = predictor.predict_latency(model, input_shape=[1,3,224,224], save_dir='./model', data_type='int8', task_type='cls')
+print('predicted latency = {}ms'.format(latency))
+```
diff --git a/paddleslim/analysis/__init__.py b/paddleslim/analysis/__init__.py
index a23da8e9..32e97890 100644
--- a/paddleslim/analysis/__init__.py
+++ b/paddleslim/analysis/__init__.py
@@ -14,8 +14,19 @@
 from .flops import flops, dygraph_flops
 from .model_size import model_size
 from .latency import LatencyEvaluator, TableLatencyEvaluator
+from .latency_predictor import LatencyPredictor, TableLatencyPredictor
+from ._utils import get_key_from_op, save_cls_model, save_det_model, save_seg_model
 __all__ = [
-    'flops', 'dygraph_flops', 'model_size', 'LatencyEvaluator',
-    'TableLatencyEvaluator'
+    'flops',
+    'dygraph_flops',
+    'model_size',
+    'LatencyEvaluator',
+    'TableLatencyEvaluator',
+    'LatencyPredictor',
+    'TableLatencyPredictor',
+    'get_key_from_op',
+    'save_cls_model',
+    'save_det_model',
+    'save_seg_model',
 ]
diff --git a/paddleslim/analysis/_utils.py b/paddleslim/analysis/_utils.py
new file mode 100644
index 00000000..e8b21cf3
--- /dev/null
+++ b/paddleslim/analysis/_utils.py
@@ -0,0 +1,401 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+
+import paddle
+import paddleslim
+__all__ = [
+    "get_key_from_op", "save_cls_model", "save_det_model", "save_seg_model"
+]
+
+
+def get_key_from_op(op):
+    """Construct the latency-table key from the info of a graph op.
+    """
+    param_key = ''
+    op_type = op.type()
+
+    if 'conv2d' in op_type:
+        out_shape = op.all_outputs()[0].shape()
+        in_shape = op.all_inputs()[-1].shape()
+        weight_shape = op.all_inputs()[-2].shape()
+        kernel = weight_shape[2]
+        stride = op.attr('strides')[1]
+        padding = op.attr('paddings')[1]
+        groups = op.attr('groups')
+        dilation = op.attr('dilations')[1]
+        int8 = op.attr('enable_int8')
+        bit_length = op.attr('bit_length')
+
+        param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape} pad={padding} stride={stride} group={groups} dilation={dilation} quant={int8} bit_length={bit_length}'
+
+    elif op_type == 'matmul' or op_type == 'matmul_v2':
+        X = op.all_inputs()[0].shape()
+        Y = op.all_inputs()[1].shape()
+        out_shape = op.all_outputs()[0].shape()
+        int8 = op.attr('enable_int8')
+        bit_length = op.attr('bit_length')
+
+        param_key = f'{op_type} X={X} Y={Y} out={out_shape} quant={int8} bit_length={bit_length}'
+
+    elif 'batch_norm' in op_type or 'layer_norm' in op_type:
+        out_shape = op.all_outputs()[-1].shape()
+        in_shape = op.all_inputs()[-1].shape()
+
+        param_key = f'{op_type} in={in_shape} out={out_shape}'
+
+    elif 'pool2d' in op_type:
+        out_shape = op.all_outputs()[0].shape()
+        data = op.all_inputs()
+        in_shape = data[-1].shape()
+        kernel = op.attr('ksize')[1]
+        stride = op.attr('strides')[1]
+        padding = op.attr('paddings')[1]
+        groups = op.attr('groups')
+        flag_global = 1 if op.attr('global_pooling') else 0
+        if op.attr('adaptive') and out_shape[-1] == 1:
+            flag_global = 1
+        pooling_type = op.attr('pooling_type')
+
+        param_key = f'{op_type} in={in_shape} out={out_shape} stride={stride} kernel={kernel}x{kernel} pad={padding} flag_global={flag_global} type={pooling_type})'
+
+    elif op_type in [
+            'hard_swish', 'relu', 'leaky_relu', 'tanh', 'swish', 'softmax',
+            'hard_sigmoid', 'sigmoid', 'gelu', 'clip', 'shape'
+    ] or 'transpose' in op_type or 'interp_v2' in op_type:
+        in_shape = op.all_inputs()[-1].shape()
+
+        param_key = f'{op_type} in={in_shape}'
+
+    elif op_type in ['fill_constant', 'range', 'cast'] or 'expand' in op_type:
+
+        param_key = f'{op_type}'
+
+    elif op_type in ['scale'] or 'reshape' in op_type:
+        out_shape = op.all_outputs()[0].shape()
+        in_shape = op.all_inputs()[0].shape()
+
+        param_key = f'{op_type} in={in_shape} out={out_shape}'
+
+    elif 'elementwise' in op_type:
+        out_shape = op.all_outputs()[0].shape()
+        x = op.all_inputs()[0].shape()
+        y = op.all_inputs()[1].shape()
+        axis = op.attr('axis')
+
+        param_key = f'{op_type} X={x} Y={y} axis={axis} out={out_shape}'
+
+    elif op_type == 'concat':
+        data = op.all_inputs()
+        X = ""
+        for x in data:
+            X += f"{x.shape()}"
+        axis = op.attr('axis')
+        out_shape = op.all_outputs()[0].shape()
+
+        param_key = f'{op_type} in={X} axis={axis} out={out_shape}'
+
+    elif op_type == 'yolo_box':
+        in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+        class_num = op.attr('class_num')
+
+        param_key = f'{op_type} in={in_shape} out={out_shape} class_num={class_num}'
+
+    elif op_type == 'prior_box':
+        in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+        aspect_ratios = op.attr('aspect_ratios')
+        max_sizes = op.attr('max_sizes')
+        min_sizes = op.attr('min_sizes')
+
+        param_key = f'{op_type} in={in_shape} out={out_shape} aspect_ratios={aspect_ratios} max_sizes={max_sizes} min_sizes={min_sizes}'
+
+    elif op_type == 'slice':
+        in_shape = op.all_inputs()[-1].shape()
+        axes = op.attr('axes')
+
+        param_key = f'{op_type} in={in_shape} axes={axes}'
+
+    elif op_type == 'stack':
+        data = op.all_inputs()
+        X = "["
+        for x in data:
+            X += f"{x.shape()}"
+        X += "]"
+        axis = op.attr('axis')
+        out_shape = op.all_outputs()[0].shape()
+
+        param_key = f'{op_type} X={X} axis={axis} out={out_shape}'
+
+    elif op_type == 'exp':
+        in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+        axes = op.attr('axes')
+        decrease_axis = op.attr('decrease_axis')
+        ends = op.attr('ends')
+
+        param_key = f'{op_type} in={in_shape} out={out_shape} axes={axes} decrease_axis={decrease_axis} ends={ends}'
+
+    elif op_type in ['multiclass_nms3', 'matrix_nms']:
+        boxs = op.all_inputs()[0].shape()
+        scores = op.all_inputs()[-1].shape()
+        keep_top_k = op.attr('keep_top_k')
+        nms_top_k = op.attr('nms_top_k')
+
+        param_key = f'{op_type} boxs={boxs} scores={scores} keep_top_k={keep_top_k} nms_top_k={nms_top_k}'
+
+    elif op_type == 'dropout':
+        in_shape = op.all_inputs()[0].shape()
+
+        param_key = f'{op_type} in={in_shape}'
+
+    elif op_type == 'fc':
+        in_shape = op.all_inputs()[-2].shape()
+        weight_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+
+        param_key = f'{op_type} in={in_shape} weight={weight_shape} out={out_shape}'
+
+    elif op_type == 'shuffle_channel':
+        in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+        group = op.attr('group')
+
+        param_key = f'{op_type} in={in_shape} group={group} out={out_shape}'
+
+    elif op_type == 'split':
+        in_shape = op.all_inputs()[-1].shape()
+        axis = op.attr('axis')
+        sections = op.attr('sections')
+
+        param_key = f'{op_type} in={in_shape} axis={axis} sections={sections}'
+
+    elif op_type in ['unsqueeze2', 'squeeze2']:
+        in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+        axes = op.attr('axes')
+
+        param_key = f'{op_type} in={in_shape} axes={axes} out={out_shape}'
+
+    elif op_type == 'flatten_contiguous_range':
+        in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+        start_axis = op.attr('start_axis')
+        stop_axis = op.attr('stop_axis')
+
+        param_key = f'{op_type} in={in_shape} start_axis={start_axis} stop_axis={stop_axis} out={out_shape}'
+
+    elif op_type == 'sum':
+        in_shape1 = op.all_inputs()[0].shape()
+        in_shape2 = op.all_inputs()[1].shape()
+        out_shape = op.all_outputs()[0].shape()
+
+        param_key = f'{op_type} in={in_shape1} in={in_shape2} out={out_shape}'
+
+    elif op_type in ['calib', 'floor']:
+        in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+
+        param_key = f'{op_type} in={in_shape} out={out_shape}'
+
+    elif op_type == 'uniform_random':
+        shape = op.attr('shape')
+
+        param_key = f'{op_type} shape={shape}'
+
+    elif op_type == 'greater_equal':
+        x = op.all_inputs()[0].shape()
+        y = op.all_inputs()[1].shape()
+        out_shape = op.all_outputs()[0].shape()
+
+        param_key = f'{op_type} X={x} Y={y} out={out_shape}'
+
+    elif op_type == 'reduce_mean':
+        in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+        dim = op.attr('dim')
+
+        param_key = f'{op_type} in={in_shape} out={out_shape} dim={dim}'
+
+    elif 'pad3d' in op_type:
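+        # pad3d: the key records the input/output shapes plus the padding widths.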
+        in_shape = op.all_inputs()[-1].shape()
+        out_shape = op.all_outputs()[0].shape()
+        paddings = op.attr('paddings')
+
+        param_key = f'{op_type} in={in_shape} out={out_shape} paddings={paddings}'
+
+    elif op_type in ['feed', 'fetch']:
+        pass
+
+    else:
+        print(op)
+        print(op._op)
+        raise KeyError(f'The op "{op_type}" has never been seen.')
+
+    return param_key
+
+
+def sample_generator(input_shape, batch_num):
+    def __reader__():
+        for i in range(batch_num):
+            image = np.random.random(input_shape).astype('float32')
+            yield image
+
+    return __reader__
+
+
+def save_cls_model(model, input_shape, save_dir, data_type):
+    paddle.jit.save(
+        model,
+        path=os.path.join(save_dir, 'fp32model'),
+        input_spec=[
+            paddle.static.InputSpec(
+                shape=input_shape, dtype='float32', name='x'),
+        ])
+    model_file = os.path.join(save_dir, 'fp32model.pdmodel')
+    param_file = os.path.join(save_dir, 'fp32model.pdiparams')
+
+    if data_type == 'int8':
+        paddle.enable_static()
+        exe = paddle.fluid.Executor(paddle.fluid.CPUPlace())
+        save_dir = os.path.dirname(model_file)
+
+        quantize_model_path = os.path.join(save_dir, 'int8model')
+        if not os.path.exists(quantize_model_path):
+            os.makedirs(quantize_model_path)
+
+        paddleslim.quant.quant_post_static(
+            executor=exe,
+            model_dir=save_dir,
+            quantize_model_path=quantize_model_path,
+            sample_generator=sample_generator(input_shape, 1),
+            model_filename=model_file.split('/')[-1],
+            params_filename=param_file.split('/')[-1],
+            batch_size=input_shape[0],
+            batch_nums=1,
+            weight_bits=8,
+            activation_bits=8)
+
+        model_file = os.path.join(quantize_model_path, '__model__')
+        param_file = os.path.join(quantize_model_path, '__params__')
+
+    return model_file, param_file
+
+
+def save_det_model(model,
+                   input_shape,
+                   save_dir,
+                   data_type,
+                   det_multi_input=False):
+    model.eval()
+    if det_multi_input:
+        input_spec = [{
+            "image": paddle.static.InputSpec(
+                shape=input_shape, name='image'),
+            "im_shape": paddle.static.InputSpec(
+                shape=[input_shape[0], 2], name='im_shape'),
+            "scale_factor": paddle.static.InputSpec(
+                shape=[input_shape[0], 2], name='scale_factor')
+        }]
+        data = {
+            "image": paddle.randn(
+                shape=input_shape, dtype='float32', name='image'),
+            "im_shape": paddle.randn(
+                shape=[input_shape[0], 2], dtype='float32', name='im_shape'),
+            "scale_factor": paddle.ones(
+                shape=[input_shape[0], 2], dtype='float32',
+                name='scale_factor')
+        }
+    else:
+        input_spec = [{
+            "image": paddle.static.InputSpec(
+                shape=input_shape, name='image'),
+        }]
+        data = {
+            "image": paddle.randn(
+                shape=input_shape, dtype='float32', name='image'),
+        }
+
+    if data_type == 'fp32':
+        static_model = paddle.jit.to_static(model, input_spec=input_spec)
+        paddle.jit.save(
+            static_model,
+            path=os.path.join(save_dir, 'fp32model'),
+            input_spec=input_spec)
+        model_file = os.path.join(save_dir, 'fp32model.pdmodel')
+        param_file = os.path.join(save_dir, 'fp32model.pdiparams')
+
+    else:
+        ptq = paddleslim.dygraph.quant.PTQ()
+        quant_model = ptq.quantize(model, fuse=True, fuse_list=None)
+        quant_model(data)
+        quantize_model_path = os.path.join(save_dir, 'int8model')
+        if not os.path.exists(quantize_model_path):
+            os.makedirs(quantize_model_path)
+
+        ptq.save_quantized_model(quant_model,
+                                 os.path.join(quantize_model_path,
+                                              'int8model'), input_spec)
+
+        model_file = os.path.join(quantize_model_path, 'int8model.pdmodel')
+        param_file = os.path.join(quantize_model_path, 'int8model.pdiparams')
+
+    return model_file, param_file
+
+
+def save_seg_model(model, input_shape, save_dir, data_type):
+    if data_type == 'fp32':
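+        # Export to the static-graph format (.pdmodel/.pdiparams) consumed by the opt tool.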
+        paddle.jit.save(
+            model,
+            path=os.path.join(save_dir, 'fp32model'),
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=input_shape, dtype='float32', name='x'),
+            ])
+        model_file = os.path.join(save_dir, 'fp32model.pdmodel')
+        param_file = os.path.join(save_dir, 'fp32model.pdiparams')
+
+    else:
+        save_dir = os.path.join(save_dir, 'int8model')
+        quant_config = {
+            'weight_preprocess_type': None,
+            'activation_preprocess_type': None,
+            'weight_quantize_type': 'channel_wise_abs_max',
+            'activation_quantize_type': 'moving_average_abs_max',
+            'weight_bits': 8,
+            'activation_bits': 8,
+            'dtype': 'int8',
+            'window_size': 10000,
+            'moving_rate': 0.9,
+            'quantizable_layer_type': ['Conv2D', 'Linear'],
+        }
+        quantizer = paddleslim.QAT(config=quant_config)
+        quantizer.quantize(model)
+        quantizer.save_quantized_model(
+            model,
+            save_dir,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=input_shape, dtype='float32')
+            ])
+
+        model_file = f'{save_dir}.pdmodel'
+        param_file = f'{save_dir}.pdiparams'
+
+    return model_file, param_file
diff --git a/paddleslim/analysis/latency_predictor.py b/paddleslim/analysis/latency_predictor.py
new file mode 100644
index 00000000..3bd41691
--- /dev/null
+++ b/paddleslim/analysis/latency_predictor.py
@@ -0,0 +1,191 @@
+"""Define latency predictors that predict the latency of models on devices.
+"""
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pickle
+import subprocess
+
+from ._utils import get_key_from_op, save_cls_model, save_det_model, save_seg_model
+import paddle
+import paddleslim
+__all__ = ["LatencyPredictor", "TableLatencyPredictor"]
+
+
+class LatencyPredictor(object):
+    """Base class of latency predictor.
+    """
+
+    def predict_latency(self, model):
+        """Get the latency of a model. It is an abstract method.
+
+        Args:
+            model: The model to be evaluated.
+
+        Returns:
+            latency(float): The latency of the given model on the current evaluator.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def _get_key_info_from_graph(self, graph):
+        graph_keys = []
+        for op in graph.ops():
+            param_key = get_key_from_op(op)
+            graph_keys.append(param_key)
+        return graph_keys
+
+
+class TableLatencyPredictor(LatencyPredictor):
+    """The predictor used to get a pbmodel's latency on some devices and inference engines.
+
+    Args:
+        opt_path(str): The path of the opt tool that converts a paddle model into an optimized pbmodel with fused operators.
+        hardware(str): Target hardware, e.g. '845'. Together with threads, power_mode and batchsize, it determines which latency table is downloaded. Default: '845'.
+        threads(int): Number of threads the latency table was measured with. Default: 4.
+        power_mode(int): Paddle-Lite power mode the latency table was measured with. Default: 3.
+        batchsize(int): Batch size the latency table was measured with. Default: 1.
+ """ + + def __init__(self, + opt_path, + hardware='845', + threads=4, + power_mode=3, + batchsize=1): + self.table_file = f'{hardware}_threads_{threads}_power_mode_{power_mode}_batchsize_{batchsize}.pkl' + self.opt_path = opt_path + self.table_dict = {} + self._read_table() + self.det_multi_input = False + + def _read_table(self): + if not os.path.exists(self.table_file): + subprocess.call( + f'wget https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/{self.table_file}', + shell=True) + + assert os.path.exists( + self.table_file), f'{self.table_file} is not existed.' + with open(self.table_file, 'rb') as f: + self.table_dict = pickle.load(f) + print('Successfully load {}'.format(self.table_file)) + + def set_det_multi_input(self, det_multi_input): + """If a detection model has multiple input, the self.det_multi_input should be True. Default: False. + """ + self.det_multi_input = det_multi_input + + def opt_model(self, model, input_shape, save_dir, data_type, task_type): + """Convert the model graph to an optimized pbmodel by using opt tool. + + Args: + model: The input model graph. + input_shape(list): The input shape of model. + save_dir: Where to save the pbmodel. + data_type: Data type, fp32 or int8. + task_type: Task type, cls, det or seg, different task models need to use different quantization strategies. + Returns: + pbmodel_file: The path of optimized pbmodel. + """ + + if task_type == 'cls': + model_file, param_file = save_cls_model( + model=model, + input_shape=input_shape, + save_dir=save_dir, + data_type=data_type) + + elif task_type == 'det': + model_file, param_file = save_det_model( + model=model, + input_shape=input_shape, + save_dir=save_dir, + data_type=data_type, + det_multi_input=self.det_multi_input) + + elif task_type == 'seg': + model_file, param_file = save_seg_model( + model=model, + input_shape=input_shape, + save_dir=save_dir, + data_type=data_type) + + else: + assert task_type in ['cls', 'det', 'seg' + ], f'task_type must be one of [cls, det, seg]' + + pb_model = os.path.join(save_dir, f'{data_type}pbmodel') + if not os.path.exists(pb_model): + os.makedirs(pb_model) + + cmd = f'{self.opt_path} --model_file={model_file} --param_file={param_file} --optimize_out_type=protobuf --optimize_out={pb_model} --valid_targets=arm' + print(f'commands:{cmd}') + m = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + out = m.communicate() + print(out, 'opt done!') + + pbmodel_file = os.path.join(pb_model, 'model') + + return pbmodel_file + + def predict_latency(self, + model, + input_shape=[1, 3, 224, 224], + save_dir='', + data_type='int8', + task_type='cls'): + """predict the latency of the model + + Args: + model: The input model graph. + input_shape(list): The input shape of model. Default: [1,3,224,224]. + save_dir: Where to save the pbmodel. + data_type: Data type, fp32 or int8. Default : int8 + task_type: Task type, cls, det or seg, different task models need to use different quantization strategies. Default: cls. + Returns: + latency(float): The latency of the pbmodel. 
+ """ + + assert data_type in ['fp32', 'int8' + ], f'data_type must be one of [fp32, int8]' + assert task_type in ['cls', 'det', 'seg' + ], f'task_type must be one of [cls, det, seg]' + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pbmodel_file = self.opt_model( + model=model, + input_shape=input_shape, + save_dir=save_dir, + data_type=data_type, + task_type=task_type) + + paddle.enable_static() + with open(pbmodel_file, "rb") as f: + program_desc_str = f.read() + program = paddle.fluid.proto.framework_pb2.ProgramDesc.FromString( + program_desc_str) + fluid_program = paddle.fluid.framework.Program.parse_from_string( + program_desc_str) + + graph = paddleslim.core.GraphWrapper(fluid_program) + latency = 0.0 + for op in graph.ops(): + param_key = get_key_from_op(op) + if param_key != '': + assert param_key in self.table_dict, f'{param_key} is not in the tabel.' + latency += self.table_dict[param_key] + + return latency diff --git a/tests/test_latency_predictor.py b/tests/test_latency_predictor.py new file mode 100644 index 00000000..6b6131e9 --- /dev/null +++ b/tests/test_latency_predictor.py @@ -0,0 +1,388 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys, os +sys.path.append("../") +import unittest +import paddle +import paddleslim +from paddleslim.analysis import LatencyPredictor, TableLatencyPredictor +from paddle.vision.models import mobilenet_v1, mobilenet_v2 +from paddle.nn import Conv2D, BatchNorm2D, ReLU, LayerNorm +import subprocess + +opt_tool = 'opt_ubuntu' # use in linux +# opt_tool = 'opt_M1_mac' # use in mac with M1 chip +# opt_tool = 'opt_intel_mac' # use in mac with intel chip + +if not os.path.exists(opt_tool): + subprocess.call( + f'wget https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/{opt_tool}', + shell=True) + subprocess.call(f'chmod +x {opt_tool}', shell=True) + + +def channel_shuffle(x, groups): + batch_size, num_channels, height, width = x.shape[0:4] + channels_per_group = num_channels // groups + + x = paddle.reshape( + x=x, shape=[batch_size, groups, channels_per_group, height, width]) + + x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4]) + + x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width]) + return x + + +class ModelCase1(paddle.nn.Layer): + def __init__(self): + super(ModelCase1, self).__init__() + self.conv1 = Conv2D(58, 58, 1) + self.conv2 = Conv2D(58, 58, 1) + + def forward(self, inputs): + x1, x2 = paddle.split( + inputs, + num_or_sections=[inputs.shape[1] // 2, inputs.shape[1] // 2], + axis=1) + x1 = self.conv1(x1) + x2 = self.conv2(x2) + out = paddle.concat([x1, x2], axis=1) + return channel_shuffle(out, 2) + + +class ModelCase2(paddle.nn.Layer): + def __init__(self): + super(ModelCase2, self).__init__() + self.conv1 = Conv2D(3, 24, 3, stride=2, padding=1) + + def forward(self, inputs): + image = inputs['image'] + + return self.conv1(image) + + +class ModelCase3(paddle.nn.Layer): + def __init__(self): + super(ModelCase3, self).__init__() 
+        self.conv1 = Conv2D(3, 24, 3, stride=2, padding=1)
+
+    def forward(self, inputs):
+        image = inputs['image']
+        im_shape = inputs['im_shape']
+        scale_factor = inputs['scale_factor']
+
+        return self.conv1(image), im_shape, scale_factor
+
+
+class ModelCase4(paddle.nn.Layer):
+    def __init__(self):
+        super(ModelCase4, self).__init__()
+        self.bn1 = BatchNorm2D(3)
+        self.ln1 = LayerNorm([3 * 16 * 16])
+        self.relu1 = ReLU()
+        self.fc1 = paddle.nn.Linear(3 * 16 * 16, 3 * 16 * 16)
+
+    def forward(self, inputs):
+        x = self.bn1(inputs)
+        x = paddle.reshape(x, [1, 3 * 16 * 16])
+        x = self.ln1(x)
+        x = self.fc1(x)
+        x = paddle.fluid.layers.unsqueeze(input=x, axes=[2])
+        x = self.relu1(x)
+        y = paddle.fluid.layers.fill_constant(
+            x.shape, dtype=paddle.float32, value=1)
+        x = paddle.stack([x, y], axis=3)
+        x = paddle.slice(x, axes=[0], starts=[0], ends=[1])
+        x = paddle.exp(x)
+        y += paddle.fluid.layers.uniform_random(y.shape)
+        y = paddle.fluid.layers.reduce_mean(y, dim=1, keep_dim=True)
+        return x + y
+
+
+class ModelCase5(paddle.nn.Layer):
+    def __init__(self):
+        super(ModelCase5, self).__init__()
+        self.bn1 = BatchNorm2D(255)
+
+    def forward(self, inputs):
+        image = inputs['image']
+        image = self.bn1(image)
+        img_size = paddle.fluid.data(
+            name='img_size', shape=[None, 2], dtype='int64')
+        anchors = [10, 13, 16, 30, 33, 23]
+        boxes, scores = paddle.fluid.layers.yolo_box(
+            x=image,
+            img_size=img_size,
+            class_num=80,
+            anchors=anchors,
+            conf_thresh=0.01,
+            downsample_ratio=32)
+        out = paddle.fluid.layers.matrix_nms(
+            bboxes=boxes,
+            scores=scores,
+            background_label=0,
+            score_threshold=0.5,
+            post_threshold=0.1,
+            nms_top_k=400,
+            keep_top_k=200,
+            normalized=False)
+        box, var = paddle.fluid.layers.prior_box(
+            input=image, image=image, min_sizes=[2.], clip=True, flip=True)
+        return boxes, scores, box, var, out
+
+
+class TestCase1(unittest.TestCase):
+    def test_case1(self):
+        paddle.disable_static()
+        model = mobilenet_v1()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        latency = predictor.predict_latency(
+            model,
+            input_shape=[1, 3, 224, 224],
+            save_dir='./model',
+            data_type='fp32',
+            task_type='cls')
+        assert latency > 0
+        latency = predictor.predict_latency(
+            model,
+            input_shape=[1, 3, 224, 224],
+            save_dir='./model',
+            data_type='int8',
+            task_type='cls')
+        assert latency > 0
+
+
+class TestCase2(unittest.TestCase):
+    def test_case2(self):
+        paddle.disable_static()
+        model = mobilenet_v2()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        latency = predictor.predict_latency(
+            model,
+            input_shape=[1, 3, 224, 224],
+            save_dir='./model',
+            data_type='fp32',
+            task_type='cls')
+        assert latency > 0
+        latency = predictor.predict_latency(
+            model,
+            input_shape=[1, 3, 224, 224],
+            save_dir='./model',
+            data_type='int8',
+            task_type='cls')
+        assert latency > 0
+
+
+class TestCase3(unittest.TestCase):
+    def test_case3(self):
+        paddle.disable_static()
+        model = mobilenet_v2()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        pred = LatencyPredictor()
+        pbmodel_file = predictor.opt_model(
+            model,
+            input_shape=[1, 3, 224, 224],
+            save_dir='./model',
+            data_type='fp32',
+            task_type='cls')
+        paddle.enable_static()
+        with open(pbmodel_file, "rb") as f:
+            program_desc_str = f.read()
+            fluid_program = paddle.fluid.framework.Program.parse_from_string(
+                program_desc_str)
+            graph = paddleslim.core.GraphWrapper(fluid_program)
+            graph_keys = pred._get_key_info_from_graph(graph=graph)
+            assert len(graph_keys) > 0
+
+
+class TestCase4(unittest.TestCase):
+    def test_case4(self):
+        paddle.disable_static()
+        model = ModelCase1()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        latency = predictor.predict_latency(
+            model,
+            input_shape=[1, 116, 28, 28],
+            save_dir='./model',
+            data_type='fp32',
+            task_type='cls')
+        assert latency > 0
+
+
+class TestCase5(unittest.TestCase):
+    def test_case5(self):
+        paddle.disable_static()
+        model = mobilenet_v1()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        latency = predictor.predict_latency(
+            model,
+            input_shape=[1, 3, 224, 224],
+            save_dir='./model',
+            data_type='fp32',
+            task_type='seg')
+        assert latency > 0
+
+
+class TestCase6(unittest.TestCase):
+    def test_case6(self):
+        paddle.disable_static()
+        model = ModelCase2()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        pbmodel_file = predictor.opt_model(
+            model,
+            input_shape=[1, 3, 224, 224],
+            save_dir='./model',
+            data_type='int8',
+            task_type='det')
+        assert os.path.exists(pbmodel_file)
+        latency = predictor.predict_latency(
+            model,
+            input_shape=[1, 3, 224, 224],
+            save_dir='./model',
+            data_type='fp32',
+            task_type='det')
+        assert latency > 0
+
+
+class TestCase7(unittest.TestCase):
+    def test_case7(self):
+        paddle.disable_static()
+        model = ModelCase3()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        predictor.set_det_multi_input(det_multi_input=True)
+        latency = predictor.predict_latency(
+            model,
+            input_shape=[1, 3, 224, 224],
+            save_dir='./model',
+            data_type='fp32',
+            task_type='det')
+        assert latency > 0
+
+
+class TestCase8(unittest.TestCase):
+    def test_case8(self):
+        paddle.disable_static()
+        model = ModelCase4()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        pbmodel_file = predictor.opt_model(
+            model,
+            input_shape=[1, 3, 16, 16],
+            save_dir='./model',
+            data_type='int8',
+            task_type='cls')
+        paddle.enable_static()
+        with open(pbmodel_file, "rb") as f:
+            program_desc_str = f.read()
+            fluid_program = paddle.fluid.framework.Program.parse_from_string(
+                program_desc_str)
+            graph = paddleslim.core.GraphWrapper(fluid_program)
+            graph_keys = predictor._get_key_info_from_graph(graph=graph)
+            assert len(graph_keys) > 0
+
+
+class TestCase9(unittest.TestCase):
+    def test_case9(self):
+        paddle.disable_static()
+        model = ModelCase5()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        pbmodel_file = predictor.opt_model(
+            model,
+            input_shape=[1, 255, 13, 13],
+            save_dir='./model',
+            data_type='fp32',
+            task_type='det')
+        paddle.enable_static()
+        with open(pbmodel_file, "rb") as f:
+            program_desc_str = f.read()
+            fluid_program = paddle.fluid.framework.Program.parse_from_string(
+                program_desc_str)
+            graph = paddleslim.core.GraphWrapper(fluid_program)
+            graph_keys = predictor._get_key_info_from_graph(graph=graph)
+            assert len(graph_keys) > 0
+
+
+class TestCase10(unittest.TestCase):
+    def test_case10(self):
+        paddle.disable_static()
+        model = ModelCase1()
+        predictor = TableLatencyPredictor(
+            f'./{opt_tool}',
+            hardware='845',
+            threads=4,
+            power_mode=3,
+            batchsize=1)
+        pbmodel_file = predictor.opt_model(
+            model,
+            input_shape=[1, 116, 28, 28],
+            save_dir='./model',
+            data_type='int8',
+            task_type='seg')
+        paddle.enable_static()
+        with open(pbmodel_file, "rb") as f:
+            program_desc_str = f.read()
+            fluid_program = paddle.fluid.framework.Program.parse_from_string(
+                program_desc_str)
+            graph = paddleslim.core.GraphWrapper(fluid_program)
+            graph_keys = predictor._get_key_info_from_graph(graph=graph)
+            assert len(graph_keys) > 0
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
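
A minimal end-to-end sketch, assuming the opt_ubuntu binary from the URL above has been downloaded and made executable, and that the 845 latency table is reachable (it is fetched automatically by the predictor):

```
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim.analysis import TableLatencyPredictor

# Work in dynamic-graph mode; predict_latency switches to static mode internally.
paddle.disable_static()
model = mobilenet_v1()
predictor = TableLatencyPredictor(
    './opt_ubuntu', hardware='845', threads=4, power_mode=3, batchsize=1)
latency = predictor.predict_latency(
    model,
    input_shape=[1, 3, 224, 224],
    save_dir='./tmp_model',
    data_type='fp32',
    task_type='cls')
print(f'predicted latency: {latency} ms')
```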