From 456b9fed2a50d269785126dc1092abb8e37b809e Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Fri, 30 Sep 2022 15:22:12 +0800 Subject: [PATCH] update act paddle inference demo (#1432) Co-authored-by: ceci3 --- example/auto_compression/detection/README.md | 64 ++- .../detection/paddle_inference_eval.py | 498 ++++++++++++++++++ .../detection/paddle_trt_infer.py | 323 ------------ .../image_classification/README.md | 70 +-- .../image_classification/infer.py | 232 -------- .../paddle_inference_eval.py | 250 +++++++++ example/auto_compression/nlp/README.md | 36 +- .../{infer.py => paddle_inference_eval.py} | 365 +++++++------ .../pytorch_huggingface/README.md | 52 +- .../{infer.py => paddle_inference_eval.py} | 298 ++++++----- .../pytorch_yolo_series/README.md | 79 ++- .../paddle_inference_eval.py | 472 +++++++++++++++++ .../semantic_segmentation/README.md | 141 ++--- .../{infer.py => paddle_inference_eval.py} | 241 +++++---- paddleslim/common/load_model.py | 1 - 15 files changed, 1953 insertions(+), 1169 deletions(-) create mode 100644 example/auto_compression/detection/paddle_inference_eval.py delete mode 100644 example/auto_compression/detection/paddle_trt_infer.py delete mode 100644 example/auto_compression/image_classification/infer.py create mode 100644 example/auto_compression/image_classification/paddle_inference_eval.py rename example/auto_compression/nlp/{infer.py => paddle_inference_eval.py} (54%) rename example/auto_compression/pytorch_huggingface/{infer.py => paddle_inference_eval.py} (51%) create mode 100644 example/auto_compression/pytorch_yolo_series/paddle_inference_eval.py rename example/auto_compression/semantic_segmentation/{infer.py => paddle_inference_eval.py} (53%) diff --git a/example/auto_compression/detection/README.md b/example/auto_compression/detection/README.md index 01ced5b6..4ca20249 100644 --- a/example/auto_compression/detection/README.md +++ b/example/auto_compression/detection/README.md @@ -7,8 +7,7 @@ - [3.1 环境准备](#31-准备环境) - [3.2 准备数据集](#32-准备数据集) - [3.3 准备预测模型](#33-准备预测模型) - - [3.4 测试模型精度](#34-测试模型精度) - - [3.5 自动压缩并产出模型](#35-自动压缩并产出模型) + - [3.4 自动压缩并产出模型](#34-自动压缩并产出模型) - [4.预测部署](#4预测部署) - [5.FAQ](5FAQ) @@ -110,23 +109,52 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log - --config_path=./configs/ppyoloe_l_qat_dis.yaml --save_dir='./output/' ``` -#### 3.5 测试模型精度 -使用eval.py脚本得到模型的mAP: -``` -export CUDA_VISIBLE_DEVICES=0 -python eval.py --config_path=./configs/ppyoloe_l_qat_dis.yaml -``` +## 4.预测部署 -**注意**: -- 要测试的模型路径可以在配置文件中`model_dir`字段下进行修改。 +#### 4.1 Paddle Inference 验证性能 -## 4.预测部署 +量化模型在GPU上可以使用TensorRT进行加速,在CPU上可以使用MKLDNN进行加速。 + +以下字段用于配置预测参数: -- 如果模型包含NMS,可以参考[PaddleDetection部署教程](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy),GPU上量化模型开启TensorRT并设置trt_int8模式进行部署。 +| 参数名 | 含义 | +|:------:|:------:| +| model_path | inference 模型文件所在目录,该目录下需要有文件 model.pdmodel 和 model.pdiparams 两个文件 | +| reader_config | eval时模型reader的配置文件路径 | +| image_file | 如果只测试单张图片效果,直接根据image_file指定图片路径 | +| device | 使用GPU或者CPU预测,可选CPU/GPU | +| use_trt | 是否使用 TesorRT 预测引擎 | +| use_mkldnn | 是否启用```MKL-DNN```加速库,注意```use_mkldnn```与```use_gpu```同时为```True```时,将忽略```enable_mkldnn```,而使用```GPU```预测 | +| cpu_threads | CPU预测时,使用CPU线程数量,默认10 | +| precision | 预测精度,包括`fp32/fp16/int8` | -- 模型为PPYOLOE,同时不包含NMS,使用以下预测demo进行部署: - - Paddle-TensorRT C++部署 + +- TensorRT预测: + +环境配置:如果使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 
的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python) + +```shell +python paddle_inference_eval.py \ + --model_path=models/ppyoloe_crn_l_300e_coco_quant \ + --reader_config=configs/yoloe_reader.yml \ + --use_trt=True \ + --precision=int8 +``` + +- MKLDNN预测: + +```shell +python paddle_inference_eval.py \ + --model_path=models/ppyoloe_crn_l_300e_coco_quant \ + --reader_config=configs/yoloe_reader.yml \ + --device=CPU \ + --use_mkldnn=True \ + --cpu_threads=10 \ + --precision=int8 +``` + +- 模型为PPYOLOE,同时不包含NMS,可以使用C++预测demo进行测速: 进入[cpp_infer](./cpp_infer_ppyoloe)文件夹内,请按照[C++ TensorRT Benchmark测试教程](./cpp_infer_ppyoloe/README.md)进行准备环境及编译,然后开始测试: ```shell @@ -136,14 +164,6 @@ python eval.py --config_path=./configs/ppyoloe_l_qat_dis.yaml ./build/trt_run --model_file ppyoloe_s_quant/model.pdmodel --params_file ppyoloe_s_quant/model.pdiparams --run_mode=trt_int8 ``` - - Paddle-TensorRT Python部署: - - 首先安装带有TensorRT的[Paddle安装包](https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python)。然后使用[paddle_trt_infer.py](./paddle_trt_infer.py)进行部署: - ```shell - python paddle_trt_infer.py --model_path=output --image_file=images/000000570688.jpg --benchmark=True --run_mode=trt_int8 - ``` - ## 5.FAQ - - 如果想对模型进行离线量化,可进入[Detection模型离线量化示例](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/detection)中进行实验。 diff --git a/example/auto_compression/detection/paddle_inference_eval.py b/example/auto_compression/detection/paddle_inference_eval.py new file mode 100644 index 00000000..b318b8de --- /dev/null +++ b/example/auto_compression/detection/paddle_inference_eval.py @@ -0,0 +1,498 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
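+
+# Example usage (mirrors the README in this directory; model and config paths
+# are illustrative):
+#   GPU + TensorRT INT8:
+#     python paddle_inference_eval.py \
+#         --model_path=models/ppyoloe_crn_l_300e_coco_quant \
+#         --reader_config=configs/yoloe_reader.yml \
+#         --use_trt=True --precision=int8
+#   CPU + MKLDNN INT8:
+#     python paddle_inference_eval.py \
+#         --model_path=models/ppyoloe_crn_l_300e_coco_quant \
+#         --reader_config=configs/yoloe_reader.yml \
+#         --device=CPU --use_mkldnn=True --cpu_threads=10 --precision=int8
+# Note: on the first TensorRT run the script only collects dynamic shapes into
+# <model_path>/dynamic_shape.txt and asks to be rerun; the second run loads the
+# tuned shapes and reports the real metrics.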
+ +import os +import argparse +import time +import sys +import cv2 +import numpy as np + +import paddle +from paddle.inference import Config +from paddle.inference import create_predictor +from ppdet.core.workspace import load_config, create +from ppdet.metrics import COCOMetric + +from post_process import PPYOLOEPostProcess + + +def argsparser(): + """ + argsparser func + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_path", type=str, help="inference model filepath") + parser.add_argument( + "--image_file", + type=str, + default=None, + help="image path, if set image_file, it will not eval coco.") + parser.add_argument( + "--reader_config", + type=str, + default=None, + help="path of datset and reader config.") + parser.add_argument( + "--benchmark", + type=bool, + default=False, + help="Whether run benchmark or not.") + parser.add_argument( + "--use_trt", + type=bool, + default=False, + help="Whether use TensorRT or not.") + parser.add_argument( + "--precision", + type=str, + default="paddle", + help="mode of running(fp32/fp16/int8)") + parser.add_argument( + "--device", + type=str, + default="GPU", + help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU", + ) + parser.add_argument( + "--use_dynamic_shape", + type=bool, + default=True, + help="Whether use dynamic shape or not.") + parser.add_argument( + "--use_mkldnn", + type=bool, + default=False, + help="Whether use mkldnn or not.") + parser.add_argument( + "--cpu_threads", type=int, default=10, help="Num of cpu threads.") + parser.add_argument("--img_shape", type=int, default=640, help="input_size") + parser.add_argument( + '--include_nms', + type=bool, + default=True, + help="Whether include nms or not.") + + return parser + + +CLASS_LABEL = [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', + 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', + 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', + 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', + 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush' +] + + +def generate_scale(im, target_shape, keep_ratio=True): + """ + Args: + im (np.ndarray): image (np.ndarray) + Returns: + im_scale_x: the resize ratio of X + im_scale_y: the resize ratio of Y + """ + origin_shape = im.shape[:2] + if keep_ratio: + im_size_min = np.min(origin_shape) + im_size_max = np.max(origin_shape) + target_size_min = np.min(target_shape) + target_size_max = np.max(target_shape) + im_scale = float(target_size_min) / float(im_size_min) + if np.round(im_scale * im_size_max) > target_size_max: + im_scale = float(target_size_max) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + else: + resize_h, resize_w = target_shape + im_scale_y = resize_h / float(origin_shape[0]) + im_scale_x = resize_w / float(origin_shape[1]) + return im_scale_y, im_scale_x + + +def image_preprocess(img_path, target_shape): + """ + image_preprocess func + 
""" + img = cv2.imread(img_path) + im_scale_y, im_scale_x = generate_scale(img, target_shape, keep_ratio=False) + img = cv2.resize( + img, (target_shape[0], target_shape[0]), + interpolation=cv2.INTER_LANCZOS4) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = np.transpose(img, [2, 0, 1]) / 255 + img = np.expand_dims(img, 0) + img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) + img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + img -= img_mean + img /= img_std + scale_factor = np.array([[im_scale_y, im_scale_x]]) + return img.astype(np.float32), scale_factor.astype(np.float32) + + +def get_color_map_list(num_classes): + """ + get_color_map_list func + """ + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j) + color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j) + color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j) + j += 1 + lab >>= 3 + color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] + return color_map + + +def draw_box(image_file, results, class_label, threshold=0.5): + """ + draw_box func + """ + srcimg = cv2.imread(image_file, 1) + for i in range(len(results)): + color_list = get_color_map_list(len(class_label)) + clsid2color = {} + classid, conf = int(results[i, 0]), results[i, 1] + if conf < threshold: + continue + xmin, ymin, xmax, ymax = int(results[i, 2]), int(results[i, 3]), int( + results[i, 4]), int(results[i, 5]) + + if classid not in clsid2color: + clsid2color[classid] = color_list[classid] + color = tuple(clsid2color[classid]) + + cv2.rectangle(srcimg, (xmin, ymin), (xmax, ymax), color, thickness=2) + print(class_label[classid] + ": " + str(round(conf, 3))) + cv2.putText( + srcimg, + class_label[classid] + ":" + str(round(conf, 3)), + (xmin, ymin - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.8, + (0, 255, 0), + thickness=2, ) + return srcimg + + +def load_predictor( + model_dir, + precision="fp32", + use_trt=False, + use_mkldnn=False, + batch_size=1, + device="CPU", + min_subgraph_size=3, + use_dynamic_shape=False, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + cpu_threads=1, ): + """set AnalysisConfig, generate AnalysisPredictor + Args: + model_dir (str): root path of __model__ and __params__ + precision (str): mode of running(fp32/fp16/int8) + use_trt (bool): whether use TensorRT or not. + use_mkldnn (bool): whether use MKLDNN or not in CPU. + device (str): Choose the device you want to run, it can be: CPU/GPU, default is CPU + use_dynamic_shape (bool): use dynamic shape or not + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + Returns: + predictor (PaddlePredictor): AnalysisPredictor + Raises: + ValueError: predict by TensorRT need device == 'GPU'. + """ + rerun_flag = False + if device != "GPU" and use_trt: + raise ValueError( + "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}". 
+ format(precision, device)) + config = Config( + os.path.join(model_dir, "model.pdmodel"), + os.path.join(model_dir, "model.pdiparams")) + if device == "GPU": + # initial GPU memory(M), device ID + config.enable_use_gpu(200, 0) + # optimize graph and fuse op + config.switch_ir_optim(True) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(cpu_threads) + config.switch_ir_optim() + if use_mkldnn: + config.enable_mkldnn() + if precision == "int8": + config.enable_mkldnn_int8( + {"conv2d", "depthwise_conv2d", "transpose2", "pool2d"}) + + precision_map = { + "int8": Config.Precision.Int8, + "fp32": Config.Precision.Float32, + "fp16": Config.Precision.Half, + } + if precision in precision_map.keys() and use_trt: + config.enable_tensorrt_engine( + workspace_size=(1 << 25) * batch_size, + max_batch_size=batch_size, + min_subgraph_size=min_subgraph_size, + precision_mode=precision_map[precision], + use_static=True, + use_calib_mode=False, ) + + if use_dynamic_shape: + dynamic_shape_file = os.path.join(FLAGS.model_path, + "dynamic_shape.txt") + if os.path.exists(dynamic_shape_file): + config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, + True) + print("trt set dynamic shape done!") + else: + config.collect_shape_range_info(dynamic_shape_file) + print("Start collect dynamic shape...") + rerun_flag = True + + # enable shared memory + config.enable_memory_optim() + predictor = create_predictor(config) + return predictor, rerun_flag + + +def get_current_memory_mb(): + """ + It is used to Obtain the memory usage of the CPU and GPU during the running of the program. + And this function Current program is time-consuming. + """ + try: + pkg.require('pynvml') + except: + from pip._internal import main + main(['install', 'pynvml']) + try: + pkg.require('psutil') + except: + from pip._internal import main + main(['install', 'psutil']) + try: + pkg.require('GPUtil') + except: + from pip._internal import main + main(['install', 'GPUtil']) + import pynvml + import psutil + import GPUtil + + gpu_id = int(os.environ.get("CUDA_VISIBLE_DEVICES", 0)) + + pid = os.getpid() + p = psutil.Process(pid) + info = p.memory_full_info() + cpu_mem = info.uss / 1024.0 / 1024.0 + gpu_mem = 0 + gpu_percent = 0 + gpus = GPUtil.getGPUs() + if gpu_id is not None and len(gpus) > 0: + gpu_percent = gpus[gpu_id].load + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) + gpu_mem = meminfo.used / 1024.0 / 1024.0 + return round(cpu_mem, 4), round(gpu_mem, 4) + + +def predict_image(predictor, + image_file, + image_shape=[640, 640], + warmup=1, + repeats=1, + threshold=0.5): + """ + predict image main func + """ + img, scale_factor = image_preprocess(image_file, image_shape) + inputs = {} + inputs["image"] = img + if include_nms: + inputs['scale_factor'] = scale_factor + input_names = predictor.get_input_names() + for i, _ in enumerate(input_names): + input_tensor = predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + + for i in range(warmup): + predictor.run() + + np_boxes, np_boxes_num = None, None + cpu_mems, gpu_mems = 0, 0 + predict_time = 0.0 + time_min = float("inf") + time_max = float("-inf") + for i in range(repeats): + start_time = time.time() + predictor.run() + output_names = predictor.get_output_names() + boxes_tensor = predictor.get_output_handle(output_names[0]) + np_boxes = boxes_tensor.copy_to_cpu() + if FLAGS.include_nms: + boxes_num = predictor.get_output_handle(output_names[1]) + 
np_boxes_num = boxes_num.copy_to_cpu() + end_time = time.time() + timed = end_time - start_time + time_min = min(time_min, timed) + time_max = max(time_max, timed) + predict_time += timed + cpu_mem, gpu_mem = get_current_memory_mb() + cpu_mems += cpu_mem + gpu_mems += gpu_mem + + time_avg = predict_time / repeats + print("[Benchmark]Avg cpu_mem:{} MB, avg gpu_mem: {} MB".format( + cpu_mems / repeats, gpu_mems / repeats)) + print("[Benchmark]Inference time(ms): min={}, max={}, avg={}".format( + round(time_min * 1000, 2), + round(time_max * 1000, 1), round(time_avg * 1000, 1))) + if not FLAGS.include_nms: + postprocess = PPYOLOEPostProcess(score_threshold=0.3, nms_threshold=0.6) + res = postprocess(np_boxes, scale_factor) + else: + res = {'bbox': np_boxes, 'bbox_num': np_boxes_num} + res_img = draw_box( + image_file, res["bbox"], CLASS_LABEL, threshold=threshold) + cv2.imwrite("result.jpg", res_img) + + +def eval(predictor, val_loader, metric, rerun_flag=False): + """ + eval main func + """ + cpu_mems, gpu_mems = 0, 0 + predict_time = 0.0 + time_min = float("inf") + time_max = float("-inf") + sample_nums = len(val_loader) + input_names = predictor.get_input_names() + output_names = predictor.get_output_names() + boxes_tensor = predictor.get_output_handle(output_names[0]) + boxes_num = predictor.get_output_handle(output_names[1]) + for batch_id, data in enumerate(val_loader): + data_all = {k: np.array(v) for k, v in data.items()} + for i, _ in enumerate(input_names): + input_tensor = predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(data_all[input_names[i]]) + start_time = time.time() + predictor.run() + np_boxes = boxes_tensor.copy_to_cpu() + if FLAGS.include_nms: + np_boxes_num = boxes_num.copy_to_cpu() + if rerun_flag: + return + end_time = time.time() + timed = end_time - start_time + time_min = min(time_min, timed) + time_max = max(time_max, timed) + predict_time += timed + cpu_mem, gpu_mem = get_current_memory_mb() + cpu_mems += cpu_mem + gpu_mems += gpu_mem + if not FLAGS.include_nms: + postprocess = PPYOLOEPostProcess( + score_threshold=0.3, nms_threshold=0.6) + res = postprocess(np_boxes, data_all['scale_factor']) + else: + res = {'bbox': np_boxes, 'bbox_num': np_boxes_num} + metric.update(data_all, res) + if batch_id % 100 == 0: + print("Eval iter:", batch_id) + sys.stdout.flush() + metric.accumulate() + metric.log() + map_res = metric.get_results() + metric.reset() + time_avg = predict_time / sample_nums + print("[Benchmark]Avg cpu_mem:{} MB, avg gpu_mem: {} MB".format( + cpu_mems / sample_nums, gpu_mems / sample_nums)) + print("[Benchmark]Inference time(ms): min={}, max={}, avg={}".format( + round(time_min * 1000, 2), + round(time_max * 1000, 1), round(time_avg * 1000, 1))) + print("[Benchmark] COCO mAP: {}".format(map_res["bbox"][0])) + sys.stdout.flush() + + +def main(): + """ + main func + """ + predictor, rerun_flag = load_predictor( + FLAGS.model_path, + device=FLAGS.device, + use_trt=FLAGS.use_trt, + use_mkldnn=FLAGS.use_mkldnn, + precision=FLAGS.precision, + use_dynamic_shape=FLAGS.use_dynamic_shape, + cpu_threads=FLAGS.cpu_threads) + + if FLAGS.image_file: + warmup, repeats = 1, 1 + if FLAGS.benchmark: + warmup, repeats = 50, 100 + predict_image( + predictor, + FLAGS.image_file, + image_shape=[FLAGS.img_shape, FLAGS.img_shape], + warmup=warmup, + repeats=repeats) + else: + reader_cfg = load_config(FLAGS.reader_config) + + dataset = reader_cfg["EvalDataset"] + global val_loader + val_loader = create("EvalReader")(reader_cfg["EvalDataset"], + 
reader_cfg["worker_num"], + return_list=True) + clsid2catid = {v: k for k, v in dataset.catid2clsid.items()} + anno_file = dataset.get_anno() + metric = COCOMetric( + anno_file=anno_file, clsid2catid=clsid2catid, IouType="bbox") + eval(predictor, val_loader, metric, rerun_flag=rerun_flag) + + if rerun_flag: + print( + "***** Collect dynamic shape done, Please rerun the program to get correct results. *****" + ) + + +if __name__ == "__main__": + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + + # DataLoader need run on cpu + paddle.set_device("cpu") + + main() diff --git a/example/auto_compression/detection/paddle_trt_infer.py b/example/auto_compression/detection/paddle_trt_infer.py deleted file mode 100644 index d6d1e66c..00000000 --- a/example/auto_compression/detection/paddle_trt_infer.py +++ /dev/null @@ -1,323 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import cv2 -import numpy as np -import argparse -import time - -from paddle.inference import Config -from paddle.inference import create_predictor - -from post_process import PPYOLOEPostProcess - -CLASS_LABEL = [ - 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', - 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', - 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', - 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', - 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', - 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', - 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', - 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', - 'hair drier', 'toothbrush' -] - - -def generate_scale(im, target_shape, keep_ratio=True): - """ - Args: - im (np.ndarray): image (np.ndarray) - Returns: - im_scale_x: the resize ratio of X - im_scale_y: the resize ratio of Y - """ - origin_shape = im.shape[:2] - if keep_ratio: - im_size_min = np.min(origin_shape) - im_size_max = np.max(origin_shape) - target_size_min = np.min(target_shape) - target_size_max = np.max(target_shape) - im_scale = float(target_size_min) / float(im_size_min) - if np.round(im_scale * im_size_max) > target_size_max: - im_scale = float(target_size_max) / float(im_size_max) - im_scale_x = im_scale - im_scale_y = im_scale - else: - resize_h, resize_w = target_shape - im_scale_y = resize_h / float(origin_shape[0]) - im_scale_x = resize_w / float(origin_shape[1]) - return im_scale_y, im_scale_x - - -def image_preprocess(img_path, target_shape): - img = cv2.imread(img_path) - im_scale_y, im_scale_x = generate_scale(img, target_shape, 
keep_ratio=False) - img = cv2.resize( - img, (target_shape[0], target_shape[0]), - interpolation=cv2.INTER_LANCZOS4) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = np.transpose(img, [2, 0, 1]) / 255 - img = np.expand_dims(img, 0) - img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) - img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) - img -= img_mean - img /= img_std - scale_factor = np.array([[im_scale_y, im_scale_x]]) - return img.astype(np.float32), scale_factor.astype(np.float32) - - -def get_color_map_list(num_classes): - color_map = num_classes * [0, 0, 0] - for i in range(0, num_classes): - j = 0 - lab = i - while lab: - color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) - color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) - color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) - j += 1 - lab >>= 3 - color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] - return color_map - - -def draw_box(image_file, results, class_label, threshold=0.5): - srcimg = cv2.imread(image_file, 1) - for i in range(len(results)): - color_list = get_color_map_list(len(class_label)) - clsid2color = {} - classid, conf = int(results[i, 0]), results[i, 1] - if conf < threshold: - continue - xmin, ymin, xmax, ymax = int(results[i, 2]), int(results[i, 3]), int( - results[i, 4]), int(results[i, 5]) - - if classid not in clsid2color: - clsid2color[classid] = color_list[classid] - color = tuple(clsid2color[classid]) - - cv2.rectangle(srcimg, (xmin, ymin), (xmax, ymax), color, thickness=2) - print(class_label[classid] + ': ' + str(round(conf, 3))) - cv2.putText( - srcimg, - class_label[classid] + ':' + str(round(conf, 3)), (xmin, ymin - 10), - cv2.FONT_HERSHEY_SIMPLEX, - 0.8, (0, 255, 0), - thickness=2) - return srcimg - - -def load_predictor(model_dir, - run_mode='paddle', - batch_size=1, - device='CPU', - min_subgraph_size=3, - use_dynamic_shape=False, - trt_min_shape=3, - trt_max_shape=1280, - trt_opt_shape=640, - trt_calib_mode=False, - cpu_threads=1, - enable_mkldnn=False, - enable_mkldnn_bfloat16=False, - delete_shuffle_pass=False): - """set AnalysisConfig, generate AnalysisPredictor - Args: - model_dir (str): root path of __model__ and __params__ - device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU - run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8) - use_dynamic_shape (bool): use dynamic shape or not - trt_min_shape (int): min shape for dynamic shape in trt - trt_max_shape (int): max shape for dynamic shape in trt - trt_opt_shape (int): opt shape for dynamic shape in trt - trt_calib_mode (bool): If the model is produced by TRT offline quantitative - calibration, trt_calib_mode need to set True - delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT. - Used by action model. - Returns: - predictor (PaddlePredictor): AnalysisPredictor - Raises: - ValueError: predict by TensorRT need device == 'GPU'. 
- """ - if device != 'GPU' and run_mode != 'paddle': - raise ValueError( - "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}" - .format(run_mode, device)) - config = Config( - os.path.join(model_dir, 'model.pdmodel'), - os.path.join(model_dir, 'model.pdiparams')) - if device == 'GPU': - # initial GPU memory(M), device ID - config.enable_use_gpu(200, 0) - # optimize graph and fuse op - config.switch_ir_optim(True) - elif device == 'XPU': - config.enable_lite_engine() - config.enable_xpu(10 * 1024 * 1024) - else: - config.disable_gpu() - config.set_cpu_math_library_num_threads(cpu_threads) - if enable_mkldnn: - try: - # cache 10 different shapes for mkldnn to avoid memory leak - config.set_mkldnn_cache_capacity(10) - config.enable_mkldnn() - if enable_mkldnn_bfloat16: - config.enable_mkldnn_bfloat16() - except Exception as e: - print( - "The current environment does not support `mkldnn`, so disable mkldnn." - ) - pass - - precision_map = { - 'trt_int8': Config.Precision.Int8, - 'trt_fp32': Config.Precision.Float32, - 'trt_fp16': Config.Precision.Half - } - if run_mode in precision_map.keys(): - config.enable_tensorrt_engine( - workspace_size=(1 << 25) * batch_size, - max_batch_size=batch_size, - min_subgraph_size=min_subgraph_size, - precision_mode=precision_map[run_mode], - use_static=False, - use_calib_mode=trt_calib_mode) - - if use_dynamic_shape: - min_input_shape = { - 'image': [batch_size, 3, trt_min_shape, trt_min_shape] - } - max_input_shape = { - 'image': [batch_size, 3, trt_max_shape, trt_max_shape] - } - opt_input_shape = { - 'image': [batch_size, 3, trt_opt_shape, trt_opt_shape] - } - config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, - opt_input_shape) - print('trt set dynamic shape done!') - - # enable shared memory - config.enable_memory_optim() - # disable feed, fetch OP, needed by zero_copy_run - config.switch_use_feed_fetch_ops(False) - if delete_shuffle_pass: - config.delete_pass("shuffle_channel_detect_pass") - predictor = create_predictor(config) - return predictor - - -def predict_image(predictor, - image_file, - image_shape=[640, 640], - warmup=1, - repeats=1, - threshold=0.5, - include_nms=True): - img, scale_factor = image_preprocess(image_file, image_shape) - inputs = {} - inputs['image'] = img - if include_nms: - inputs['scale_factor'] = scale_factor - input_names = predictor.get_input_names() - for i in range(len(input_names)): - input_tensor = predictor.get_input_handle(input_names[i]) - input_tensor.copy_from_cpu(inputs[input_names[i]]) - - for i in range(warmup): - predictor.run() - - np_boxes, np_boxes_num = None, None - predict_time = 0. 
- time_min = float("inf") - time_max = float('-inf') - for i in range(repeats): - start_time = time.time() - predictor.run() - output_names = predictor.get_output_names() - boxes_tensor = predictor.get_output_handle(output_names[0]) - np_boxes = boxes_tensor.copy_to_cpu() - if include_nms: - boxes_num = predictor.get_output_handle(output_names[1]) - np_boxes_num = boxes_num.copy_to_cpu() - end_time = time.time() - timed = end_time - start_time - time_min = min(time_min, timed) - time_max = max(time_max, timed) - predict_time += timed - - time_avg = predict_time / repeats - print('Inference time(ms): min={}, max={}, avg={}'.format( - round(time_min * 1000, 2), - round(time_max * 1000, 1), round(time_avg * 1000, 1))) - if not include_nms: - postprocess = PPYOLOEPostProcess(score_threshold=0.3, nms_threshold=0.6) - res = postprocess(np_boxes, scale_factor) - else: - res = {'bbox': np_boxes, 'bbox_num': np_boxes_num} - res_img = draw_box( - image_file, res['bbox'], CLASS_LABEL, threshold=threshold) - cv2.imwrite('result.jpg', res_img) - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument( - '--image_file', type=str, default=None, help="image path") - parser.add_argument( - '--model_path', type=str, help="inference model filepath") - parser.add_argument( - '--benchmark', - type=bool, - default=False, - help="Whether run benchmark or not.") - parser.add_argument( - '--run_mode', - type=str, - default='paddle', - help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)") - parser.add_argument( - '--device', - type=str, - default='GPU', - help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU" - ) - parser.add_argument('--img_shape', type=int, default=640, help="input_size") - parser.add_argument( - '--include_nms', - type=bool, - default=True, - help="Whether include nms or not.") - args = parser.parse_args() - - predictor = load_predictor( - args.model_path, run_mode=args.run_mode, device=args.device) - warmup, repeats = 1, 1 - if args.benchmark: - warmup, repeats = 50, 100 - predict_image( - predictor, - args.image_file, - image_shape=[args.img_shape, args.img_shape], - warmup=warmup, - repeats=repeats, - include_nms=args.include_nms) diff --git a/example/auto_compression/image_classification/README.md b/example/auto_compression/image_classification/README.md index 23a0e381..973c04e5 100644 --- a/example/auto_compression/image_classification/README.md +++ b/example/auto_compression/image_classification/README.md @@ -113,48 +113,56 @@ python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev1/' - 注意 ```learning rate``` 与 ```batch size``` 呈线性关系,这里单卡 ```batch size``` 为32,对应的 ```learning rate``` 为0.015,那么如果 ```batch size``` 减小4倍改为8,```learning rate``` 也需除以4;多卡时 ```batch size``` 为32,```learning rate``` 需乘上卡数。所以改变 ```batch size``` 或改变训练卡数都需要对应修改 ```learning rate```。 -**验证精度** - -根据训练log可以看到模型验证的精度,若需再次验证精度,修改配置文件```./configs/MobileNetV1/qat_dis.yaml```中所需验证模型的文件夹路径及模型和参数名称```model_dir, model_filename, params_filename```,然后使用以下命令进行验证: - -```shell -export CUDA_VISIBLE_DEVICES=0 -python eval.py --config_path='./configs/MobileNetV1/qat_dis.yaml' -``` - ## 4.预测部署 -#### 4.1 Python预测推理 -环境配置:若使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python) +#### 4.1 Paddle Inference 验证性能 + +量化模型在GPU上可以使用TensorRT进行加速,在CPU上可以使用MKLDNN进行加速。 以下字段用于配置预测参数: -- ```inference_model_dir```:inference 模型文件所在目录,该目录下需要有文件 .pdmodel 和 .pdiparams 两个文件 -- 
```model_filename```:inference_model_dir文件夹下的模型文件名称 -- ```params_filename```:inference_model_dir文件夹下的参数文件名称 - -- ```batch_size```:预测一个batch的大小 -- ```image_size```:输入图像的大小 -- ```use_tensorrt```:是否使用 TesorRT 预测引擎 -- ```use_gpu```:是否使用 GPU 预测 -- ```enable_mkldnn```:是否启用```MKL-DNN```加速库,注意```enable_mkldnn```与```use_gpu```同时为```True```时,将忽略```enable_mkldnn```,而使用```GPU```预测 -- ```use_fp16```:是否启用```FP16``` -- ```use_int8```:是否启用```INT8``` + +| 参数名 | 含义 | +|:------:|:------:| +| model_path | inference 模型文件所在目录,该目录下需要有文件 .pdmodel 和 .pdiparams 两个文件 | +| model_filename | inference_model_dir文件夹下的模型文件名称 | +| params_filename | inference_model_dir文件夹下的参数文件名称 | +| data_path | 数据集路径 | +| batch_size | 预测一个batch的大小 | +| image_size | 输入图像的大小 | +| use_gpu | 是否使用 GPU 预测 | +| use_trt | 是否使用 TesorRT 预测引擎 | +| use_mkldnn | 是否启用```MKL-DNN```加速库,注意```use_mkldnn```与```use_gpu```同时为```True```时,将忽略```use_mkldnn```,而使用```GPU```预测 | +| cpu_num_threads | CPU预测时,使用CPU线程数量,默认10 | +| use_fp16 | 使用TensorRT时,是否启用```FP16``` | +| use_int8 | 是否启用```INT8``` | 注意: - 请注意模型的输入数据尺寸,如InceptionV3输入尺寸为299,部分模型需要修改参数:```image_size``` -- 如果希望提升评测模型速度,使用 ```GPU``` 评测时,建议开启 ```TensorRT``` 加速预测,使用 ```CPU``` 评测时,建议开启 ```MKL-DNN``` 加速预测 -准备好inference模型后,使用以下命令进行预测: +- TensorRT预测: + +环境配置:如果使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python) + +```shell +python paddle_inference_eval.py \ + --model_path=models/ResNet50_vd_QAT \ + --use_trt=True \ + --use_int8=True \ + --use_gpu=True \ + --data_path=./dataset/ILSVRC2012/ +``` + +- MKLDNN预测: + ```shell -python infer.py --model_dir='MobileNetV1_infer' \ ---model_filename='inference.pdmodel' \ ---model_filename='inference.pdiparams' \ ---eval=True \ ---use_gpu=True \ ---enable_mkldnn=True \ ---use_int8=True +python paddle_inference_eval.py \ + --model_path=models/ResNet50_vd_QAT \ + --data_path=./dataset/ILSVRC2012/ \ + --cpu_num_threads=10 \ + --use_mkldnn=True \ + --use_int8=True ``` #### 4.2 PaddleLite端侧部署 diff --git a/example/auto_compression/image_classification/infer.py b/example/auto_compression/image_classification/infer.py deleted file mode 100644 index 1d727db1..00000000 --- a/example/auto_compression/image_classification/infer.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import numpy as np -import cv2 -import time -import sys -import argparse -import yaml -from tqdm import tqdm - -from utils import preprocess, postprocess -import paddle -from paddle.inference import create_predictor -from paddleslim.common import load_config -from paddle.io import DataLoader -from imagenet_reader import ImageNetDataset, process_image - - -def argsparser(): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - '--model_dir', - type=str, - default='./MobileNetV1_infer', - help='model directory') - parser.add_argument( - '--model_filename', - type=str, - default='inference.pdmodel', - help='model file name') - parser.add_argument( - '--params_filename', - type=str, - default='inference.pdiparams', - help='params file name') - parser.add_argument('--batch_size', type=int, default=1) - parser.add_argument('--img_size', type=int, default=224) - parser.add_argument('--resize_size', type=int, default=256) - parser.add_argument( - '--eval', type=bool, default=False, help='Whether to evaluate') - parser.add_argument('--data_path', type=str, default='./ILSVRC2012/') - parser.add_argument( - '--use_gpu', type=bool, default=False, help='Whether to use gpu') - parser.add_argument( - '--enable_mkldnn', - type=bool, - default=False, - help='Whether to use mkldnn') - parser.add_argument( - '--cpu_num_threads', type=int, default=10, help='Number of cpu threads') - parser.add_argument( - '--use_fp16', type=bool, default=False, help='Whether to use fp16') - parser.add_argument( - '--use_int8', type=bool, default=False, help='Whether to use int8') - parser.add_argument( - '--use_tensorrt', - type=bool, - default=True, - help='Whether to use tensorrt') - parser.add_argument( - '--enable_profile', - type=bool, - default=False, - help='Whether to enable profile') - parser.add_argument('--gpu_mem', type=int, default=8000, help='GPU memory') - parser.add_argument('--ir_optim', type=bool, default=True) - return parser - - -def eval_reader(data_dir, batch_size, crop_size, resize_size): - val_reader = ImageNetDataset( - mode='val', - data_dir=data_dir, - crop_size=crop_size, - resize_size=resize_size) - val_loader = DataLoader( - val_reader, - batch_size=args.batch_size, - shuffle=False, - drop_last=False, - num_workers=0) - return val_loader - - -class Predictor(object): - def __init__(self, args): - - # HALF precission predict only work when using tensorrt - if args.use_fp16 is True: - assert args.use_tensorrt is True - self.args = args - - self.paddle_predictor = self.create_paddle_predictor() - input_names = self.paddle_predictor.get_input_names() - self.input_tensor = self.paddle_predictor.get_input_handle(input_names[ - 0]) - - output_names = self.paddle_predictor.get_output_names() - self.output_tensor = self.paddle_predictor.get_output_handle( - output_names[0]) - - def create_paddle_predictor(self): - inference_model_dir = self.args.model_dir - model_file = os.path.join(inference_model_dir, self.args.model_filename) - params_file = os.path.join(inference_model_dir, - self.args.params_filename) - config = paddle.inference.Config(model_file, params_file) - precision = paddle.inference.Config.Precision.Float32 - if self.args.use_int8: - precision = paddle.inference.Config.Precision.Int8 - elif self.args.use_fp16: - precision = paddle.inference.Config.Precision.Half - - if self.args.use_gpu: - config.enable_use_gpu(self.args.gpu_mem, 0) - else: - config.disable_gpu() - if self.args.enable_mkldnn: - # cache 10 different shapes for mkldnn to avoid memory 
leak - config.set_mkldnn_cache_capacity(10) - config.enable_mkldnn() - config.set_cpu_math_library_num_threads(self.args.cpu_num_threads) - - if self.args.enable_profile: - config.enable_profile() - config.switch_ir_optim(self.args.ir_optim) # default true - if self.args.use_tensorrt: - config.enable_tensorrt_engine( - precision_mode=precision, - max_batch_size=self.args.batch_size, - workspace_size=1 << 30, - min_subgraph_size=30, - use_calib_mode=False) - - config.enable_memory_optim() - # use zero copy - config.switch_use_feed_fetch_ops(False) - predictor = create_predictor(config) - - return predictor - - def predict(self): - test_num = 1000 - test_time = 0.0 - for i in range(0, test_num + 10): - inputs = np.random.rand(self.args.batch_size, 3, self.args.img_size, - self.args.img_size).astype(np.float32) - start_time = time.time() - self.input_tensor.copy_from_cpu(inputs) - self.paddle_predictor.run() - batch_output = self.output_tensor.copy_to_cpu().flatten() - if i >= 10: - test_time += time.time() - start_time - time.sleep(0.01) # sleep for T4 GPU - - fp_message = "FP16" if self.args.use_fp16 else "FP32" - fp_message = "INT8" if self.args.use_int8 else fp_message - trt_msg = "using tensorrt" if self.args.use_tensorrt else "not using tensorrt" - print("{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format( - trt_msg, fp_message, args.batch_size, 1000 * test_time / test_num)) - - def eval(self): - if os.path.exists(self.args.data_path): - val_loader = eval_reader( - self.args.data_path, - batch_size=self.args.batch_size, - crop_size=self.args.img_size, - resize_size=self.args.resize_size) - else: - image = np.ones((1, 3, self.args.img_size, - self.args.img_size)).astype(np.float32) - label = None - val_loader = [[image, label]] - results = [] - with tqdm( - total=len(val_loader), - bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', - ncols=80) as t: - for batch_id, (image, label) in enumerate(val_loader): - input_names = self.paddle_predictor.get_input_names() - input_tensor = self.paddle_predictor.get_input_handle( - input_names[0]) - output_names = self.paddle_predictor.get_output_names() - output_tensor = self.paddle_predictor.get_output_handle( - output_names[0]) - - image = np.array(image) - - input_tensor.copy_from_cpu(image) - self.paddle_predictor.run() - batch_output = output_tensor.copy_to_cpu() - sort_array = batch_output.argsort(axis=1) - top_1_pred = sort_array[:, -1:][:, ::-1] - if label is None: - results.append(top_1_pred) - break - label = np.array(label) - top_1 = np.mean(label == top_1_pred) - top_5_pred = sort_array[:, -5:][:, ::-1] - acc_num = 0 - for i in range(len(label)): - if label[i][0] in top_5_pred[i]: - acc_num += 1 - top_5 = float(acc_num) / len(label) - results.append([top_1, top_5]) - - result = np.mean(np.array(results), axis=0) - t.update() - print('Evaluation result: {}'.format(result[0])) - - -if __name__ == "__main__": - parser = argsparser() - global args - args = parser.parse_args() - predictor = Predictor(args) - predictor.predict() - if args.eval: - predictor.eval() diff --git a/example/auto_compression/image_classification/paddle_inference_eval.py b/example/auto_compression/image_classification/paddle_inference_eval.py new file mode 100644 index 00000000..e086a953 --- /dev/null +++ b/example/auto_compression/image_classification/paddle_inference_eval.py @@ -0,0 +1,250 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import sys +import argparse +import numpy as np +import cv2 +import yaml + +import paddle +from paddle.inference import create_predictor +from paddle.io import DataLoader +from imagenet_reader import ImageNetDataset + + +def argsparser(): + """ + argsparser func + """ + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--model_path", + type=str, + default="./MobileNetV1_infer", + help="model directory") + parser.add_argument( + "--model_filename", + type=str, + default="inference.pdmodel", + help="model file name") + parser.add_argument( + "--params_filename", + type=str, + default="inference.pdiparams", + help="params file name") + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--img_size", type=int, default=224) + parser.add_argument("--resize_size", type=int, default=256) + parser.add_argument( + "--data_path", type=str, default="./dataset/ILSVRC2012/") + parser.add_argument( + "--use_gpu", type=bool, default=False, help="Whether to use gpu") + parser.add_argument( + "--use_trt", type=bool, default=False, help="Whether to use tensorrt") + parser.add_argument( + "--use_mkldnn", type=bool, default=False, help="Whether to use mkldnn") + parser.add_argument( + "--cpu_num_threads", type=int, default=10, help="Number of cpu threads") + parser.add_argument( + "--use_fp16", type=bool, default=False, help="Whether to use fp16") + parser.add_argument( + "--use_int8", type=bool, default=False, help="Whether to use int8") + parser.add_argument("--gpu_mem", type=int, default=8000, help="GPU memory") + parser.add_argument("--ir_optim", type=bool, default=True) + parser.add_argument( + "--use_dynamic_shape", + type=bool, + default=True, + help="Whether use dynamic shape or not.") + return parser + + +def eval_reader(data_dir, batch_size, crop_size, resize_size): + """ + eval reader func + """ + val_reader = ImageNetDataset( + mode="val", + data_dir=data_dir, + crop_size=crop_size, + resize_size=resize_size) + val_loader = DataLoader( + val_reader, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + num_workers=0) + return val_loader + + +class Predictor(object): + """ + Paddle Inference Predictor class + """ + + def __init__(self): + # HALF precission predict only work when using tensorrt + if args.use_fp16 is True: + assert args.use_trt is True + + self.rerun_flag = False + self.paddle_predictor = self._create_paddle_predictor() + input_names = self.paddle_predictor.get_input_names() + self.input_tensor = self.paddle_predictor.get_input_handle(input_names[ + 0]) + + output_names = self.paddle_predictor.get_output_names() + self.output_tensor = self.paddle_predictor.get_output_handle( + output_names[0]) + + def _create_paddle_predictor(self): + inference_model_dir = args.model_path + model_file = os.path.join(inference_model_dir, args.model_filename) + params_file = os.path.join(inference_model_dir, args.params_filename) + config = 
paddle.inference.Config(model_file, params_file) + precision = paddle.inference.Config.Precision.Float32 + if args.use_int8: + precision = paddle.inference.Config.Precision.Int8 + elif args.use_fp16: + precision = paddle.inference.Config.Precision.Half + + if args.use_gpu: + config.enable_use_gpu(args.gpu_mem, 0) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(args.cpu_num_threads) + config.switch_ir_optim() + if args.use_mkldnn: + config.enable_mkldnn() + if args.use_int8: + config.enable_mkldnn_int8( + {"conv2d", "depthwise_conv2d", "transpose2", "pool2d"}) + + config.switch_ir_optim(args.ir_optim) # default true + if args.use_trt: + config.enable_tensorrt_engine( + precision_mode=precision, + max_batch_size=args.batch_size, + workspace_size=1 << 30, + min_subgraph_size=30, + use_static=True, + use_calib_mode=False, ) + + if args.use_dynamic_shape: + dynamic_shape_file = os.path.join(inference_model_dir, + "dynamic_shape.txt") + if os.path.exists(dynamic_shape_file): + config.enable_tuned_tensorrt_dynamic_shape( + dynamic_shape_file, True) + print("trt set dynamic shape done!") + else: + config.collect_shape_range_info(dynamic_shape_file) + print("Start collect dynamic shape...") + self.rerun_flag = True + + config.enable_memory_optim() + predictor = create_predictor(config) + + return predictor + + def eval(self): + """ + eval func + """ + if os.path.exists(args.data_path): + val_loader = eval_reader( + args.data_path, + batch_size=args.batch_size, + crop_size=args.img_size, + resize_size=args.resize_size) + else: + image = np.ones( + (1, 3, args.img_size, args.img_size)).astype(np.float32) + label = None + val_loader = [[image, label]] + results = [] + input_names = self.paddle_predictor.get_input_names() + input_tensor = self.paddle_predictor.get_input_handle(input_names[0]) + output_names = self.paddle_predictor.get_output_names() + output_tensor = self.paddle_predictor.get_output_handle(output_names[0]) + predict_time = 0.0 + time_min = float("inf") + time_max = float("-inf") + sample_nums = len(val_loader) + for batch_id, (image, label) in enumerate(val_loader): + image = np.array(image) + + input_tensor.copy_from_cpu(image) + start_time = time.time() + self.paddle_predictor.run() + batch_output = output_tensor.copy_to_cpu() + end_time = time.time() + timed = end_time - start_time + time_min = min(time_min, timed) + time_max = max(time_max, timed) + predict_time += timed + if self.rerun_flag: + return + sort_array = batch_output.argsort(axis=1) + top_1_pred = sort_array[:, -1:][:, ::-1] + if label is None: + results.append(top_1_pred) + break + label = np.array(label) + top_1 = np.mean(label == top_1_pred) + top_5_pred = sort_array[:, -5:][:, ::-1] + acc_num = 0 + for i, _ in enumerate(label): + if label[i][0] in top_5_pred[i]: + acc_num += 1 + top_5 = float(acc_num) / len(label) + results.append([top_1, top_5]) + if batch_id % 100 == 0: + print("Eval iter:", batch_id) + sys.stdout.flush() + + result = np.mean(np.array(results), axis=0) + fp_message = "FP16" if args.use_fp16 else "FP32" + fp_message = "INT8" if args.use_int8 else fp_message + print_msg = "Paddle" + if args.use_trt: + print_msg = "using TensorRT" + elif args.use_mkldnn: + print_msg = "using MKLDNN" + time_avg = predict_time / sample_nums + print( + "[Benchmark]{}\t{}\tbatch size: {}.Inference time(ms): min={}, max={}, avg={}". 
+ format( + print_msg, + fp_message, + args.batch_size, + round(time_min * 1000, 2), + round(time_max * 1000, 1), + round(time_avg * 1000, 1), )) + print("[Benchmark] Evaluation acc result: {}".format(result[0])) + sys.stdout.flush() + + +if __name__ == "__main__": + parser = argsparser() + args = parser.parse_args() + predictor = Predictor() + predictor.eval() + if predictor.rerun_flag: + print( + "***** Collect dynamic shape done, Please rerun the program to get correct results. *****" + ) diff --git a/example/auto_compression/nlp/README.md b/example/auto_compression/nlp/README.md index 38724e6b..af1a5cf3 100644 --- a/example/auto_compression/nlp/README.md +++ b/example/auto_compression/nlp/README.md @@ -194,25 +194,41 @@ Quantization: ## 5. 预测部署 -- Python部署: +量化模型在GPU上可以使用TensorRT进行加速,在CPU上可以使用MKLDNN进行加速。 -首先安装带有TensorRT的[Paddle安装包](https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python)。 -然后使用[infer.py](./infer.py)进行部署: +- TensorRT预测: -本示例将以ERNIE 3.0-Medium模型、afqmc数据集的为例,介绍如何利用Paddle—TensorRT测试压缩后模型的精度和速度。 +环境配置:如果使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python) -精度测试方法: +首先下载量化好的模型: ```shell -python infer.py --task_name='afqmc' --model_path='./save_ernie3.0_afqmc/' --device='gpu' --use_trt --int8 +wget https://bj.bcebos.com/v1/paddle-slim-models/act/save_ppminilm_afqmc_new_calib.tar +tar -xf save_ppminilm_afqmc_new_calib.tar ``` -速度测试方法 ```shell -python infer.py --task_name='afqmc' --model_path='./save_ernie3.0_afqmc/' --device='gpu' --use_trt --int8 --perf +python paddle_inference_eval.py \ + --model_path=save_ernie3_afqmc_new_cablib \ + --model_filename=infer.pdmodel \ + --params_filename=infer.pdiparams \ + --task_name='afqmc' \ + --use_trt \ + --precision=int8 ``` -- [PP-MiniLM Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/model_compression/pp-minilm) -- [ERNIE-3.0 Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/ernie-3.0) +- MKLDNN预测: + +```shell +python paddle_inference_eval.py \ + --model_path=save_ernie3_afqmc_new_cablib \ + --model_filename=infer.pdmodel \ + --params_filename=infer.pdiparams \ + --task_name='afqmc' \ + --device=cpu \ + --use_mkldnn=True \ + --cpu_threads=10 \ + --precision=int8 +``` ## 6. 
FAQ diff --git a/example/auto_compression/nlp/infer.py b/example/auto_compression/nlp/paddle_inference_eval.py similarity index 54% rename from example/auto_compression/nlp/infer.py rename to example/auto_compression/nlp/paddle_inference_eval.py index 01f67ebe..f48e2069 100644 --- a/example/auto_compression/nlp/infer.py +++ b/example/auto_compression/nlp/paddle_inference_eval.py @@ -45,96 +45,42 @@ METRIC_CLASSES = { } -def convert_example(example, dataset, tokenizer, label_list, - max_seq_length=512): - assert dataset in ['glue', 'clue' - ], "This demo only supports for dataset glue or clue" - """Convert a glue example into necessary features.""" - if dataset == 'glue': - # `label_list == None` is for regression task - label_dtype = "int64" if label_list else "float32" - # Get the label - label = example['labels'] - label = np.array([label], dtype=label_dtype) - # Convert raw text to feature - example = tokenizer(example['sentence'], max_seq_len=max_seq_length) - - return example['input_ids'], example['token_type_ids'], label - - else: #if dataset == 'clue': - # `label_list == None` is for regression task - label_dtype = "int64" if label_list else "float32" - # Get the label - example['label'] = np.array( - example["label"], dtype="int64").reshape((-1, 1)) - label = example['label'] - # Convert raw text to feature - if 'keyword' in example: # CSL - sentence1 = " ".join(example['keyword']) - example = { - 'sentence1': sentence1, - 'sentence2': example['abst'], - 'label': example['label'] - } - elif 'target' in example: # wsc - text, query, pronoun, query_idx, pronoun_idx = example[ - 'text'], example['target']['span1_text'], example['target'][ - 'span2_text'], example['target']['span1_index'], example[ - 'target']['span2_index'] - text_list = list(text) - assert text[pronoun_idx:(pronoun_idx + len( - pronoun))] == pronoun, "pronoun: {}".format(pronoun) - assert text[query_idx:(query_idx + len(query) - )] == query, "query: {}".format(query) - if pronoun_idx > query_idx: - text_list.insert(query_idx, "_") - text_list.insert(query_idx + len(query) + 1, "_") - text_list.insert(pronoun_idx + 2, "[") - text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") - else: - text_list.insert(pronoun_idx, "[") - text_list.insert(pronoun_idx + len(pronoun) + 1, "]") - text_list.insert(query_idx + 2, "_") - text_list.insert(query_idx + len(query) + 2 + 1, "_") - text = "".join(text_list) - example['sentence'] = text - if tokenizer is None: - return example - if 'sentence' in example: - example = tokenizer(example['sentence'], max_seq_len=max_seq_length) - elif 'sentence1' in example: - example = tokenizer( - example['sentence1'], - text_pair=example['sentence2'], - max_seq_len=max_seq_length) - return example['input_ids'], example['token_type_ids'], label - - def parse_args(): + """ + parse_args func + """ parser = argparse.ArgumentParser() - - # Required parameters + parser.add_argument( + "--model_path", + default="./afqmc", + type=str, + required=True, + help="The path prefix of inference model to be used.", ) + parser.add_argument( + "--model_filename", + type=str, + default="inference.pdmodel", + help="model file name") + parser.add_argument( + "--params_filename", + type=str, + default="inference.pdiparams", + help="params file name") parser.add_argument( "--task_name", - default='afqmc', + default="afqmc", type=str, help="The name of the task to perform predict, selected in the list: " + ", ".join(METRIC_CLASSES.keys()), ) parser.add_argument( "--dataset", - default='clue', + default="clue", 
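+        # supported values: 'glue' or 'clue' (see _convert_example below)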
type=str, help="The dataset of model.", ) - parser.add_argument( - "--model_path", - default='./quant_models/model', - type=str, - required=True, - help="The path prefix of inference model to be used.", ) parser.add_argument( "--device", default="gpu", - choices=["gpu", "cpu", "xpu"], + choices=["gpu", "cpu"], help="Device selected for inference.", ) parser.add_argument( "--batch_size", @@ -154,25 +100,101 @@ def parse_args(): help="Warmup steps for performance test.", ) parser.add_argument( "--use_trt", - action='store_true', + action="store_true", help="Whether to use inference engin TensorRT.", ) parser.add_argument( - "--perf", - action='store_true', - help="Whether to test performance.", ) + "--precision", + type=str, + default="fp32", + choices=["fp32", "fp16", "int8"], + help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.", + ) parser.add_argument( - "--int8", - action='store_true', - help="Whether to use int8 inference.", ) + "--use_mkldnn", + type=bool, + default=False, + help="Whether use mkldnn or not.") parser.add_argument( - "--fp16", - action='store_true', - help="Whether to use float16 inference.", ) + "--cpu_threads", type=int, default=1, help="Num of cpu threads.") args = parser.parse_args() return args +def _convert_example(example, + dataset, + tokenizer, + label_list, + max_seq_length=512): + assert dataset in ["glue", "clue" + ], "This demo only supports for dataset glue or clue" + """Convert a glue example into necessary features.""" + if dataset == "glue": + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + label = example["labels"] + label = np.array([label], dtype=label_dtype) + # Convert raw text to feature + example = tokenizer(example["sentence"], max_seq_len=max_seq_length) + + return example["input_ids"], example["token_type_ids"], label + + else: # if dataset == 'clue': + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + example["label"] = np.array( + example["label"], dtype="int64").reshape((-1, 1)) + label = example["label"] + # Convert raw text to feature + if "keyword" in example: # CSL + sentence1 = " ".join(example["keyword"]) + example = { + "sentence1": sentence1, + "sentence2": example["abst"], + "label": example["label"] + } + elif "target" in example: # wsc + text, query, pronoun, query_idx, pronoun_idx = ( + example["text"], + example["target"]["span1_text"], + example["target"]["span2_text"], + example["target"]["span1_index"], + example["target"]["span2_index"], ) + text_list = list(text) + assert text[pronoun_idx:(pronoun_idx + len( + pronoun))] == pronoun, "pronoun: {}".format(pronoun) + assert text[query_idx:(query_idx + len(query) + )] == query, "query: {}".format(query) + if pronoun_idx > query_idx: + text_list.insert(query_idx, "_") + text_list.insert(query_idx + len(query) + 1, "_") + text_list.insert(pronoun_idx + 2, "[") + text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") + else: + text_list.insert(pronoun_idx, "[") + text_list.insert(pronoun_idx + len(pronoun) + 1, "]") + text_list.insert(query_idx + 2, "_") + text_list.insert(query_idx + len(query) + 2 + 1, "_") + text = "".join(text_list) + example["sentence"] = text + if tokenizer is None: + return example + if "sentence" in example: + example = tokenizer(example["sentence"], max_seq_len=max_seq_length) + elif "sentence1" in example: + example = tokenizer( + example["sentence1"], + 
text_pair=example["sentence2"], + max_seq_len=max_seq_length) + return example["input_ids"], example["token_type_ids"], label + + class Predictor(object): + """ + Inference Predictor class + """ + def __init__(self, predictor, input_handles, output_handles): self.predictor = predictor self.input_handles = input_handles @@ -180,60 +202,50 @@ class Predictor(object): @classmethod def create_predictor(cls, args): - config = paddle.inference.Config(args.model_path + "infer.pdmodel", - args.model_path + "infer.pdiparams") + """ + create_predictor func + """ + cls.rerun_flag = False + config = paddle.inference.Config( + os.path.join(args.model_path, args.model_filename), + os.path.join(args.model_path, args.params_filename)) if args.device == "gpu": # set GPU configs accordingly config.enable_use_gpu(100, 0) cls.device = paddle.set_device("gpu") - elif args.device == "cpu": - # set CPU configs accordingly, - # such as enable_mkldnn, set_cpu_math_library_num_threads + else: config.disable_gpu() - cls.device = paddle.set_device("cpu") - elif args.device == "xpu": - # set XPU configs accordingly - config.enable_xpu(100) - if args.use_trt: - if args.int8: - config.enable_tensorrt_engine( - workspace_size=1 << 30, - precision_mode=inference.PrecisionType.Int8, - max_batch_size=args.batch_size, - min_subgraph_size=5, - use_static=False, - use_calib_mode=False) - elif args.fp16: - config.enable_tensorrt_engine( - workspace_size=1 << 30, - precision_mode=inference.PrecisionType.Half, - max_batch_size=args.batch_size, - min_subgraph_size=5, - use_static=False, - use_calib_mode=False) - else: - config.enable_tensorrt_engine( - workspace_size=1 << 30, - precision_mode=inference.PrecisionType.Float32, - max_batch_size=args.batch_size, - min_subgraph_size=5, - use_static=False, - use_calib_mode=False) - print("Enable TensorRT is: {}".format( - config.tensorrt_engine_enabled())) + config.set_cpu_math_library_num_threads(args.cpu_threads) + config.switch_ir_optim() + if args.use_mkldnn: + config.enable_mkldnn() + if args.precision == "int8": + config.enable_mkldnn_int8() + + precision_map = { + "int8": inference.PrecisionType.Int8, + "fp32": inference.PrecisionType.Float32, + "fp16": inference.PrecisionType.Half, + } + if args.precision in precision_map.keys() and args.use_trt: + config.enable_tensorrt_engine( + workspace_size=1 << 30, + max_batch_size=args.batch_size, + min_subgraph_size=5, + precision_mode=precision_map[args.precision], + use_static=True, + use_calib_mode=False, ) dynamic_shape_file = os.path.join(args.model_path, - 'dynamic_shape.txt') + "dynamic_shape.txt") if os.path.exists(dynamic_shape_file): config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, True) - print('trt set dynamic shape done!') + print("trt set dynamic shape done!") else: config.collect_shape_range_info(dynamic_shape_file) - print( - 'Start collect dynamic shape... 
Please eval again to get real result in TensorRT' - ) - sys.exit() + print("Start collect dynamic shape...") + cls.rerun_flag = True predictor = paddle.inference.create_predictor(config) @@ -249,6 +261,9 @@ class Predictor(object): return cls(predictor, input_handles, output_handles) def predict_batch(self, data): + """ + predict from batch func + """ for input_field, input_handle in zip(data, self.input_handles): input_handle.copy_from_cpu(input_field) self.predictor.run() @@ -257,11 +272,11 @@ class Predictor(object): ] return output - def convert_predict_batch(self, args, data, tokenizer, batchify_fn, - label_list): + def _convert_predict_batch(self, args, data, tokenizer, batchify_fn, + label_list): examples = [] for example in data: - example = convert_example( + example = _convert_example( example, args.dataset, tokenizer, @@ -272,64 +287,82 @@ class Predictor(object): return examples def predict(self, dataset, tokenizer, batchify_fn, args): + """ + predict func + """ batches = [ dataset[idx:idx + args.batch_size] for idx in range(0, len(dataset), args.batch_size) ] - if args.perf: - for i, batch in enumerate(batches): - examples = self.convert_predict_batch( - args, batch, tokenizer, batchify_fn, dataset.label_list) - input_ids, segment_ids, label = batchify_fn(examples) - output = self.predict_batch([input_ids, segment_ids]) - if i > args.perf_warmup_steps: - break - start_time = time.time() - for i, batch in enumerate(batches): - examples = self.convert_predict_batch( - args, batch, tokenizer, batchify_fn, dataset.label_list) - input_ids, segment_ids, _ = batchify_fn(examples) - output = self.predict_batch([input_ids, segment_ids]) + for i, batch in enumerate(batches): + examples = self._convert_predict_batch( + args, batch, tokenizer, batchify_fn, dataset.label_list) + input_ids, segment_ids, label = batchify_fn(examples) + output = self.predict_batch([input_ids, segment_ids]) + if i > args.perf_warmup_steps: + break + if self.rerun_flag: + return + + metric = METRIC_CLASSES[args.task_name]() + metric.reset() + predict_time = 0.0 + for i, batch in enumerate(batches): + examples = self._convert_predict_batch( + args, batch, tokenizer, batchify_fn, dataset.label_list) + input_ids, segment_ids, label = batchify_fn(examples) + start_time = time.time() + output = self.predict_batch([input_ids, segment_ids]) end_time = time.time() - sequences_num = i * args.batch_size - print("task name: %s, time: %s qps/s, " % - (args.task_name, sequences_num / (end_time - start_time))) + predict_time += end_time - start_time + correct = metric.compute( + paddle.to_tensor(output), + paddle.to_tensor(np.array(label).flatten())) + metric.update(correct) - else: - metric = METRIC_CLASSES[args.task_name]() - metric.reset() - for i, batch in enumerate(batches): - examples = self.convert_predict_batch( - args, batch, tokenizer, batchify_fn, dataset.label_list) - input_ids, segment_ids, label = batchify_fn(examples) - output = self.predict_batch([input_ids, segment_ids]) - correct = metric.compute( - paddle.to_tensor(output), - paddle.to_tensor(np.array(label).flatten())) - metric.update(correct) - - res = metric.accumulate() - print("task name: %s, acc: %s, \n" % (args.task_name, res), end='') + sequences_num = i * args.batch_size + print( + "[benchmark]task name: {}, batch size: {} Inference time per batch: {}ms, qps: {}.". 
+ format( + args.task_name, + args.batch_size, + round(predict_time * 1000 / i, 2), + round(sequences_num / predict_time, 2), )) + res = metric.accumulate() + print( + "[benchmark]task name: %s, acc: %s. \n" % (args.task_name, res), + end="") + sys.stdout.flush() def main(): + """ + main func + """ paddle.seed(42) args = parse_args() args.task_name = args.task_name.lower() + if args.use_mkldnn: + paddle.set_device("cpu") predictor = Predictor.create_predictor(args) - dev_ds = load_dataset('clue', args.task_name, splits='dev') + dev_ds = load_dataset("clue", args.task_name, splits="dev") tokenizer = AutoTokenizer.from_pretrained(args.model_path) batchify_fn = lambda samples, fn=Tuple( Pad(axis=0, pad_val=tokenizer.pad_token_id), # input Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment - Stack(dtype="int64" if dev_ds.label_list else "float32") # label + Stack(dtype="int64" if dev_ds.label_list else "float32"), # label ): fn(samples) - outputs = predictor.predict(dev_ds, tokenizer, batchify_fn, args) + predictor.predict(dev_ds, tokenizer, batchify_fn, args) + if predictor.rerun_flag: + print( + "***** Collect dynamic shape done, Please rerun the program to get correct results. *****" + ) if __name__ == "__main__": + paddle.set_device("cpu") main() diff --git a/example/auto_compression/pytorch_huggingface/README.md b/example/auto_compression/pytorch_huggingface/README.md index ac362463..b7cc1437 100644 --- a/example/auto_compression/pytorch_huggingface/README.md +++ b/example/auto_compression/pytorch_huggingface/README.md @@ -14,12 +14,9 @@ ## 1. 简介 飞桨模型转换工具[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)支持将```Caffe/TensorFlow/ONNX/PyTorch```的模型一键转为飞桨(PaddlePaddle)的预测模型。借助X2Paddle的能力,PaddleSlim的自动压缩功能可方便地用于各种框架的推理模型。 - 本示例将以[Pytorch](https://github.com/pytorch/pytorch)框架的自然语言处理模型为例,介绍如何自动压缩其他框架中的自然语言处理模型。本示例会利用[huggingface](https://github.com/huggingface/transformers)开源transformers库,将Pytorch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为剪枝蒸馏和量化训练。 - - ## 2. Benchmark [BERT](https://arxiv.org/abs/1810.04805) (```Bidirectional Encoder Representations from Transformers```)以Transformer 编码器为网络基本组件,使用掩码语言模型(```Masked Language Model```)和邻接句子预测(```Next Sentence Prediction```)两个任务在大规模无标注文本语料上进行预训练(pre-train),得到融合了双向内容的通用语义表示模型。以预训练产生的通用语义表示模型为基础,结合任务适配的简单输出层,微调(fine-tune)后即可应用到下游的NLP任务,效果通常也较直接在下游的任务上训练的模型更优。此前BERT即在[GLUE](https://gluebenchmark.com/tasks)评测任务上取得了SOTA的结果。 @@ -192,41 +189,38 @@ python run.py --config_path=./configs/cola.yaml --eval True ## 4. 
预测部署 -环境配置:若使用 Paddle TensorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python) +量化模型在GPU上可以使用TensorRT进行加速,在CPU上可以使用MKLDNN进行加速。 + -启动配置: +- TensorRT预测: -除需传入```task_name```任务名称,```model_name_or_path```模型名称,```model_path```保存inference模型的路径等基本参数外,还需根据预测环境传入预测参数: -- ```device```:默认为gpu,可选为gpu, cpu, xpu -- ```use_trt```:是否使用 TesorRT 预测引擎 -- ```int8```:是否启用```INT8``` -- ```fp16```:是否启用```FP16``` +环境配置:如果使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python) -准备好inference模型后,可以使用```infer.py```进行预测,如使用 TesorRT 预测引擎测试 FP32 模型: +首先下载量化好的模型: ```shell -python -u ./infer.py \ - --task_name cola \ - --model_name_or_path bert-base-cased \ - --model_path ./x2paddle_cola/model \ - --batch_size 1 \ - --max_seq_length 128 \ - --device gpu \ - --use_trt +wget https://bj.bcebos.com/v1/paddle-slim-models/act/x2paddle_cola_new_calib.tar +tar -xf x2paddle_cola_new_calib.tar ``` -如使用 TesorRT 预测引擎测试 INT8 模型: ```shell -python -u ./infer.py \ - --task_name cola \ - --model_name_or_path bert-base-cased \ - --model_path ./output/cola/model \ - --batch_size 1 \ - --max_seq_length 128 \ - --device gpu \ - --use_trt \ - --int8 +python paddle_inference_eval.py \ + --model_path=x2paddle_cola_new_calib \ + --use_trt \ + --precision=int8 \ + --batch_size=1 ``` +- MKLDNN预测: + +```shell +python paddle_inference_eval.py \ + --model_path=x2paddle_cola_new_calib \ + --device=cpu \ + --use_mkldnn=True \ + --cpu_threads=10 \ + --batch_size=1 \ + --precision=int8 +``` diff --git a/example/auto_compression/pytorch_huggingface/infer.py b/example/auto_compression/pytorch_huggingface/paddle_inference_eval.py similarity index 51% rename from example/auto_compression/pytorch_huggingface/infer.py rename to example/auto_compression/pytorch_huggingface/paddle_inference_eval.py index 7bbefb50..d17407cc 100644 --- a/example/auto_compression/pytorch_huggingface/infer.py +++ b/example/auto_compression/pytorch_huggingface/paddle_inference_eval.py @@ -22,9 +22,9 @@ import numpy as np import paddle from paddle import inference +from paddle.metric import Metric, Accuracy, Precision, Recall from paddlenlp.datasets import load_dataset from paddlenlp.data import Stack, Tuple, Pad -from paddle.metric import Metric, Accuracy, Precision, Recall from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer @@ -53,35 +53,46 @@ task_to_keys = { def parse_args(): + """ + parse_args func + """ parser = argparse.ArgumentParser() - - # Required parameters + parser.add_argument( + "--model_path", + default="./x2paddle_cola", + type=str, + required=True, + help="The path prefix of inference model to be used.", ) + parser.add_argument( + "--model_filename", + type=str, + default="model.pdmodel", + help="model file name") + parser.add_argument( + "--params_filename", + type=str, + default="model.pdiparams", + help="params file name") parser.add_argument( "--task_name", - default='cola', + default="cola", type=str, help="The name of the task to perform predict, selected in the list: " + ", ".join(METRIC_CLASSES.keys()), ) parser.add_argument( "--model_type", - default='bert-base-cased', + default="bert-base-cased", type=str, help="Model type selected in bert.") parser.add_argument( "--model_name_or_path", - default='bert-base-cased', + default="bert-base-cased", type=str, help="The directory or 
name of model.", ) - parser.add_argument( - "--model_path", - default='./quant_models/model', - type=str, - required=True, - help="The path prefix of inference model to be used.", ) parser.add_argument( "--device", default="gpu", - choices=["gpu", "cpu", "xpu"], + choices=["gpu", "cpu"], help="Device selected for inference.", ) parser.add_argument( "--batch_size", @@ -101,42 +112,45 @@ def parse_args(): help="Warmup steps for performance test.", ) parser.add_argument( "--use_trt", - action='store_true', + action="store_true", help="Whether to use inference engin TensorRT.", ) parser.add_argument( - "--perf", - action='store_true', - help="Whether to test performance.", ) + "--precision", + type=str, + default="fp32", + choices=["fp32", "fp16", "int8"], + help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.", + ) parser.add_argument( - "--int8", - action='store_true', - help="Whether to use int8 inference.", ) + "--use_mkldnn", + type=bool, + default=False, + help="Whether use mkldnn or not.") parser.add_argument( - "--fp16", - action='store_true', - help="Whether to use float16 inference.", ) + "--cpu_threads", type=int, default=1, help="Num of cpu threads.") args = parser.parse_args() return args -def convert_example(example, - tokenizer, - label_list, - max_seq_length=512, - task_name=None, - is_test=False, - padding='max_length', - return_attention_mask=True): +def _convert_example( + example, + tokenizer, + label_list, + max_seq_length=512, + task_name=None, + is_test=False, + padding="max_length", + return_attention_mask=True, ): if not is_test: # `label_list == None` is for regression task label_dtype = "int64" if label_list else "float32" # Get the label - label = example['labels'] + label = example["labels"] label = np.array([label], dtype=label_dtype) # Convert raw text to feature sentence1_key, sentence2_key = task_to_keys[task_name] - texts = ((example[sentence1_key], ) if sentence2_key is None else - (example[sentence1_key], example[sentence2_key])) + texts = (example[sentence1_key], ) if sentence2_key is None else ( + example[sentence1_key], example[sentence2_key]) example = tokenizer( *texts, max_seq_len=max_seq_length, @@ -144,19 +158,23 @@ def convert_example(example, return_attention_mask=return_attention_mask) if not is_test: if return_attention_mask: - return example['input_ids'], example['attention_mask'], example[ - 'token_type_ids'], label + return example["input_ids"], example["attention_mask"], example[ + "token_type_ids"], label else: - return example['input_ids'], example['token_type_ids'], label + return example["input_ids"], example["token_type_ids"], label else: if return_attention_mask: - return example['input_ids'], example['attention_mask'], example[ - 'token_type_ids'] + return example["input_ids"], example["attention_mask"], example[ + "token_type_ids"] else: - return example['input_ids'], example['token_type_ids'] + return example["input_ids"], example["token_type_ids"] class Predictor(object): + """ + Inference Predictor class + """ + def __init__(self, predictor, input_handles, output_handles): self.predictor = predictor self.input_handles = input_handles @@ -164,60 +182,51 @@ class Predictor(object): @classmethod def create_predictor(cls, args): - config = paddle.inference.Config(args.model_path + ".pdmodel", - args.model_path + ".pdiparams") + """ + create_predictor func + """ + cls.rerun_flag = False + config = paddle.inference.Config( + os.path.join(args.model_path, args.model_filename), + 
os.path.join(args.model_path, args.params_filename)) if args.device == "gpu": # set GPU configs accordingly config.enable_use_gpu(100, 0) cls.device = paddle.set_device("gpu") - elif args.device == "cpu": - # set CPU configs accordingly, - # such as enable_mkldnn, set_cpu_math_library_num_threads + else: config.disable_gpu() - cls.device = paddle.set_device("cpu") - elif args.device == "xpu": - # set XPU configs accordingly - config.enable_xpu(100) - if args.use_trt: - if args.int8: - config.enable_tensorrt_engine( - workspace_size=1 << 30, - precision_mode=inference.PrecisionType.Int8, - max_batch_size=args.batch_size, - min_subgraph_size=5, - use_static=False, - use_calib_mode=False) - elif args.fp16: - config.enable_tensorrt_engine( - workspace_size=1 << 30, - precision_mode=inference.PrecisionType.Half, - max_batch_size=args.batch_size, - min_subgraph_size=5, - use_static=False, - use_calib_mode=False) - else: - config.enable_tensorrt_engine( - workspace_size=1 << 30, - precision_mode=inference.PrecisionType.Float32, - max_batch_size=args.batch_size, - min_subgraph_size=5, - use_static=False, - use_calib_mode=False) - print("Enable TensorRT is: {}".format( - config.tensorrt_engine_enabled())) - - model_dir = os.path.dirname(args.model_path) - dynamic_shape_file = os.path.join(model_dir, 'dynamic_shape.txt') + config.set_cpu_math_library_num_threads(args.cpu_threads) + config.switch_ir_optim() + if args.use_mkldnn: + config.enable_mkldnn() + if args.precision == "int8": + config.enable_mkldnn_int8( + {"fc", "reshape2", "transpose2", "slice"}) + + precision_map = { + "int8": inference.PrecisionType.Int8, + "fp32": inference.PrecisionType.Float32, + "fp16": inference.PrecisionType.Half, + } + if args.precision in precision_map.keys() and args.use_trt: + config.enable_tensorrt_engine( + workspace_size=1 << 30, + max_batch_size=args.batch_size, + min_subgraph_size=5, + precision_mode=precision_map[args.precision], + use_static=True, + use_calib_mode=False, ) + + dynamic_shape_file = os.path.join(args.model_path, + "dynamic_shape.txt") if os.path.exists(dynamic_shape_file): config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, True) - print('trt set dynamic shape done!') + print("trt set dynamic shape done!") else: config.collect_shape_range_info(dynamic_shape_file) - print( - 'Start collect dynamic shape... 
Please eval again to get real result in TensorRT' - ) - sys.exit() + print("Start collect dynamic shape...") + cls.rerun_flag = True predictor = paddle.inference.create_predictor(config) @@ -233,6 +242,9 @@ class Predictor(object): return cls(predictor, input_handles, output_handles) def predict(self, dataset, collate_fn, args): + """ + predict func + """ batch_sampler = paddle.io.BatchSampler( dataset, batch_size=args.batch_size, shuffle=False) data_loader = paddle.io.DataLoader( @@ -241,94 +253,92 @@ class Predictor(object): collate_fn=collate_fn, num_workers=0, return_list=True) - end_time = 0 - if args.perf: - for i, data in enumerate(data_loader): - for input_field, input_handle in zip(data, self.input_handles): - input_handle.copy_from_cpu(input_field.numpy( - ) if isinstance(input_field, paddle.Tensor) else - input_field) - - self.predictor.run() - - output = [ - output_handle.copy_to_cpu() - for output_handle in self.output_handles - ] - - if i > args.perf_warmup_steps: - break - - time1 = time.time() - for i, data in enumerate(data_loader): - for input_field, input_handle in zip(data, self.input_handles): - input_handle.copy_from_cpu(input_field.numpy( - ) if isinstance(input_field, paddle.Tensor) else - input_field) - self.predictor.run() - output = [ - output_handle.copy_to_cpu() - for output_handle in self.output_handles - ] - - sequences_num = i * args.batch_size - print("task name: %s, time: %s qps/s, " % - (args.task_name, sequences_num / (time.time() - time1))) - - else: - metric = METRIC_CLASSES[args.task_name]() - metric.reset() - for i, data in enumerate(data_loader): - for input_field, input_handle in zip(data, self.input_handles): - input_handle.copy_from_cpu(input_field.numpy( - ) if isinstance(input_field, paddle.Tensor) else - input_field) - self.predictor.run() - output = [ - output_handle.copy_to_cpu() - for output_handle in self.output_handles - ] - - label = data[-1] - correct = metric.compute( - paddle.to_tensor(output[0]), - paddle.to_tensor(np.array(label).flatten())) - print(correct) - metric.update(correct) - res = metric.accumulate() - print("task name: %s, acc: %s, \n" % (args.task_name, res), end='') + for i, data in enumerate(data_loader): + for input_field, input_handle in zip(data, self.input_handles): + input_handle.copy_from_cpu(input_field.numpy() if isinstance( + input_field, paddle.Tensor) else input_field) + self.predictor.run() + output = [ + output_handle.copy_to_cpu() + for output_handle in self.output_handles + ] + if i > args.perf_warmup_steps: + break + if self.rerun_flag: + return + + metric = METRIC_CLASSES[args.task_name]() + metric.reset() + predict_time = 0.0 + for i, data in enumerate(data_loader): + for input_field, input_handle in zip(data, self.input_handles): + input_handle.copy_from_cpu(input_field.numpy() if isinstance( + input_field, paddle.Tensor) else input_field) + start_time = time.time() + self.predictor.run() + output = [ + output_handle.copy_to_cpu() + for output_handle in self.output_handles + ] + end_time = time.time() + predict_time += end_time - start_time + label = data[-1] + correct = metric.compute( + paddle.to_tensor(output[0]), + paddle.to_tensor(np.array(label).flatten())) + metric.update(correct) + + sequences_num = i * args.batch_size + print( + "[benchmark]task name: {}, batch size: {} Inference time per batch: {}ms, qps: {}.". 
+ format( + args.task_name, + args.batch_size, + round(predict_time * 1000 / i, 2), + round(sequences_num / predict_time, 2), )) + res = metric.accumulate() + print( + "[benchmark]task name: %s, acc: %s. \n" % (args.task_name, res), + end="") + sys.stdout.flush() def main(): + """ + main func + """ paddle.seed(42) args = parse_args() + if args.use_mkldnn: + paddle.set_device("cpu") predictor = Predictor.create_predictor(args) args.task_name = args.task_name.lower() args.model_type = args.model_type.lower() - dev_ds = load_dataset('glue', args.task_name, splits='dev') + dev_ds = load_dataset("glue", args.task_name, splits="dev") tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path) trans_func = partial( - convert_example, + _convert_example, tokenizer=tokenizer, label_list=dev_ds.label_list, max_seq_length=args.max_seq_length, task_name=args.task_name, - return_attention_mask=True) + return_attention_mask=True, ) dev_ds = dev_ds.map(trans_func) batchify_fn = lambda samples, fn=Tuple( Pad(axis=0, pad_val=tokenizer.pad_token_id), # input Pad(axis=0, pad_val=0), Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment - Stack(dtype="int64" if dev_ds.label_list else "float32") # label + Stack(dtype="int64" if dev_ds.label_list else "float32"), # label ): fn(samples) predictor.predict(dev_ds, batchify_fn, args) if __name__ == "__main__": + paddle.set_device("cpu") main() diff --git a/example/auto_compression/pytorch_yolo_series/README.md b/example/auto_compression/pytorch_yolo_series/README.md index eeb40c37..e75ac8b9 100644 --- a/example/auto_compression/pytorch_yolo_series/README.md +++ b/example/auto_compression/pytorch_yolo_series/README.md @@ -8,7 +8,6 @@ - [3.2 准备数据集](#32-准备数据集) - [3.3 准备预测模型](#33-准备预测模型) - [3.4 自动压缩并产出模型](#34-自动压缩并产出模型) - - [3.5 测试模型精度](#35-测试模型精度) - [4.预测部署](#4预测部署) - [5.FAQ](5FAQ) @@ -149,14 +148,6 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log - --config_path=./configs/yolov7_tiny_qat_dis.yaml --save_dir='./output/' ``` -#### 3.5 测试模型精度 - -修改[yolov7_qat_dis.yaml](./configs/yolov7_qat_dis.yaml)中`model_dir`字段为模型存储路径,然后使用eval.py脚本得到模型的mAP: -``` -export CUDA_VISIBLE_DEVICES=0 -python eval.py --config_path=./configs/yolov7_tiny_qat_dis.yaml -``` - ## 4.预测部署 @@ -164,31 +155,60 @@ python eval.py --config_path=./configs/yolov7_tiny_qat_dis.yaml ```shell ├── model.pdiparams # Paddle预测模型权重 ├── model.pdmodel # Paddle预测模型文件 -├── calibration_table.txt # Paddle量化后校准表 ├── ONNX │ ├── quant_model.onnx # 量化后转出的ONNX模型 │ ├── calibration.cache # TensorRT可以直接加载的校准表 ``` -#### 导出至ONNX使用TensorRT部署 +#### Paddle Inference部署测试 -加载`quant_model.onnx`和`calibration.cache`,可以直接使用TensorRT测试脚本进行验证,详细代码可参考[TensorRT部署](./TensorRT) +量化模型在GPU上可以使用TensorRT进行加速,在CPU上可以使用MKLDNN进行加速。 + +以下字段用于配置预测参数: + +| 参数名 | 含义 | +|:------:|:------:| +| model_path | inference 模型文件所在目录,该目录下需要有文件 model.pdmodel 和 model.pdiparams 两个文件 | +| dataset_dir | eval时数据验证集路径, 默认`dataset/coco` | +| image_file | 如果只测试单张图片效果,直接根据image_file指定图片路径 | +| device | 使用GPU或者CPU预测,可选CPU/GPU | +| use_trt | 是否使用 TesorRT 预测引擎 | +| use_mkldnn | 是否启用```MKL-DNN```加速库,注意```use_mkldnn```与```use_gpu```同时为```True```时,将忽略```enable_mkldnn```,而使用```GPU```预测 | +| cpu_threads | CPU预测时,使用CPU线程数量,默认10 | +| precision | 预测精度,包括`fp32/fp16/int8` | + + TensorRT Python部署: + +首先安装带有TensorRT的[Paddle安装包](https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python)。 + +然后使用[paddle_inference_eval.py](./paddle_inference_eval.py)进行部署: -- python测试: ```shell -cd TensorRT -python trt_eval.py 
--onnx_model_file=output/ONNX/quant_model.onnx \ - --calibration_file=output/ONNX/calibration.cache \ - --image_file=../images/000000570688.jpg \ - --precision_mode=int8 +python paddle_inference_eval.py \ + --model_path=output \ + --reader_config=configs/yoloe_reader.yml \ + --use_trt=True \ + --precision=int8 ``` -- 速度测试 +- MKLDNN预测: + ```shell -trtexec --onnx=output/ONNX/quant_model.onnx --avgRuns=1000 --workspace=1024 --calib=output/ONNX/calibration.cache --int8 +python paddle_inference_eval.py \ + --model_path=output \ + --reader_config=configs/yoloe_reader.yml \ + --device=CPU \ + --use_mkldnn=True \ + --cpu_threads=10 \ + --precision=int8 +``` + +- 测试单张图片 + +```shell +python paddle_inference_eval.py --model_path=output --image_file=images/000000570688.jpg --use_trt=True --precision=int8 ``` -#### Paddle-TensorRT部署 - C++部署 进入[cpp_infer](./cpp_infer)文件夹内,请按照[C++ TensorRT Benchmark测试教程](./cpp_infer/README.md)进行准备环境及编译,然后开始测试: @@ -199,13 +219,22 @@ bash compile.sh ./build/trt_run --model_file yolov7_quant/model.pdmodel --params_file yolov7_quant/model.pdiparams --run_mode=trt_int8 ``` -- Python部署: +#### 导出至ONNX使用TensorRT部署 -首先安装带有TensorRT的[Paddle安装包](https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python)。 +加载`quant_model.onnx`和`calibration.cache`,可以直接使用TensorRT测试脚本进行验证,详细代码可参考[TensorRT部署](./TensorRT) + +- python测试: +```shell +cd TensorRT +python trt_eval.py --onnx_model_file=output/ONNX/quant_model.onnx \ + --calibration_file=output/ONNX/calibration.cache \ + --image_file=../images/000000570688.jpg \ + --precision_mode=int8 +``` -然后使用[paddle_trt_infer.py](./paddle_trt_infer.py)进行部署: +- 速度测试 ```shell -python paddle_trt_infer.py --model_path=output --image_file=images/000000570688.jpg --benchmark=True --run_mode=trt_int8 +trtexec --onnx=output/ONNX/quant_model.onnx --avgRuns=1000 --workspace=1024 --calib=output/ONNX/calibration.cache --int8 ``` ## 5.FAQ diff --git a/example/auto_compression/pytorch_yolo_series/paddle_inference_eval.py b/example/auto_compression/pytorch_yolo_series/paddle_inference_eval.py new file mode 100644 index 00000000..a1df31b7 --- /dev/null +++ b/example/auto_compression/pytorch_yolo_series/paddle_inference_eval.py @@ -0,0 +1,472 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
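+"""
+Paddle Inference evaluation entry for the compressed YOLO-series models.
+
+With --image_file the script runs single-image inference (detections are drawn and
+saved to output.jpg); otherwise it evaluates COCO mAP with COCOValDataset from
+dataset.py. The backend is selected via --device, --use_trt, --use_mkldnn and
+--precision (fp32/fp16/int8); TensorRT prediction requires --device=GPU. On the
+first TensorRT run only the dynamic shapes are collected into dynamic_shape.txt
+under --model_path, and the script asks to be rerun to get the real results.
+"""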
+ +import time +import os +import sys +import argparse +import cv2 +import numpy as np +from tqdm import tqdm +import pkg_resources as pkg + +import paddle +from paddle.inference import Config +from paddle.inference import create_predictor +from dataset import COCOValDataset +from post_process import YOLOPostProcess, coco_metric + + +def argsparser(): + """ + argsparser func + """ + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--model_path", type=str, help="inference model filepath") + parser.add_argument( + "--image_file", + type=str, + default=None, + help="image path, if set image_file, it will not eval coco.") + parser.add_argument( + "--dataset_dir", + type=str, + default="dataset/coco", + help="COCO dataset dir.") + parser.add_argument( + "--val_image_dir", + type=str, + default="val2017", + help="COCO dataset val image dir.") + parser.add_argument( + "--val_anno_path", + type=str, + default="annotations/instances_val2017.json", + help="COCO dataset anno path.") + parser.add_argument( + "--benchmark", + type=bool, + default=False, + help="Whether run benchmark or not.") + parser.add_argument( + "--use_dynamic_shape", + type=bool, + default=True, + help="Whether use dynamic shape or not.") + parser.add_argument( + "--use_trt", + type=bool, + default=False, + help="Whether use TensorRT or not.") + parser.add_argument( + "--precision", + type=str, + default="paddle", + help="mode of running(fp32/fp16/int8)") + parser.add_argument( + "--device", + type=str, + default="GPU", + help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU", + ) + parser.add_argument( + "--arch", type=str, default="YOLOv5", help="architectures name.") + parser.add_argument("--img_shape", type=int, default=640, help="input_size") + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size of model input.") + parser.add_argument( + "--use_mkldnn", + type=bool, + default=False, + help="Whether use mkldnn or not.") + parser.add_argument( + "--cpu_threads", type=int, default=1, help="Num of cpu threads.") + return parser + + +CLASS_LABEL = [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', + 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', + 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', + 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', + 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush' +] + + +def preprocess(image, input_size, mean=None, std=None, swap=(2, 0, 1)): + """ + image preprocess func + """ + if len(image.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0 + else: + padded_img = np.ones(input_size) * 114.0 + img = np.array(image) + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, ).astype(np.float32) + padded_img[:int(img.shape[0] * r), 
:int(img.shape[1] * r)] = resized_img + + padded_img = padded_img[:, :, ::-1] + padded_img /= 255.0 + if mean is not None: + padded_img -= mean + if std is not None: + padded_img /= std + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + + +def get_color_map_list(num_classes): + """ + get_color_map_list func + """ + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j) + color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j) + color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j) + j += 1 + lab >>= 3 + color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] + return color_map + + +def draw_box(img, boxes, scores, cls_ids, conf=0.5, class_names=None): + """ + draw_box func + """ + color_list = get_color_map_list(len(class_names)) + for i, _ in enumerate(boxes): + box = boxes[i] + cls_id = int(cls_ids[i]) + color = tuple(color_list[cls_id]) + score = scores[i] + if score < conf: + continue + x0 = int(box[0]) + y0 = int(box[1]) + x1 = int(box[2]) + y1 = int(box[3]) + + text = "{}:{:.1f}%".format(class_names[cls_id], score * 100) + font = cv2.FONT_HERSHEY_SIMPLEX + + txt_size = cv2.getTextSize(text, font, 0.4, 1)[0] + cv2.rectangle(img, (x0, y0), (x1, y1), color, 2) + cv2.rectangle(img, (x0, y0 + 1), ( + x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])), color, -1) + cv2.putText( + img, + text, (x0, y0 + txt_size[1]), + font, + 0.8, (0, 255, 0), + thickness=2) + + return img + + +def get_current_memory_mb(): + """ + It is used to Obtain the memory usage of the CPU and GPU during the running of the program. + And this function Current program is time-consuming. + """ + import pynvml + import psutil + import GPUtil + + gpu_id = int(os.environ.get("CUDA_VISIBLE_DEVICES", 0)) + + pid = os.getpid() + p = psutil.Process(pid) + info = p.memory_full_info() + cpu_mem = info.uss / 1024.0 / 1024.0 + gpu_mem = 0 + gpu_percent = 0 + gpus = GPUtil.getGPUs() + if gpu_id is not None and len(gpus) > 0: + gpu_percent = gpus[gpu_id].load + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) + gpu_mem = meminfo.used / 1024.0 / 1024.0 + return round(cpu_mem, 4), round(gpu_mem, 4) + + +def load_predictor( + model_dir, + precision="fp32", + use_trt=False, + use_mkldnn=False, + batch_size=1, + device="CPU", + min_subgraph_size=3, + use_dynamic_shape=False, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + cpu_threads=1, ): + """set AnalysisConfig, generate AnalysisPredictor + Args: + model_dir (str): root path of __model__ and __params__ + precision (str): mode of running(fp32/fp16/int8) + use_trt (bool): whether use TensorRT or not. + use_mkldnn (bool): whether use MKLDNN or not in CPU. + device (str): Choose the device you want to run, it can be: CPU/GPU, default is CPU + use_dynamic_shape (bool): use dynamic shape or not + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + Returns: + predictor (PaddlePredictor): AnalysisPredictor + Raises: + ValueError: predict by TensorRT need device == 'GPU'. + """ + rerun_flag = False + if device != "GPU" and use_trt: + raise ValueError( + "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}". 
+ format(precision, device)) + config = Config( + os.path.join(model_dir, "model.pdmodel"), + os.path.join(model_dir, "model.pdiparams")) + if device == "GPU": + # initial GPU memory(M), device ID + config.enable_use_gpu(200, 0) + # optimize graph and fuse op + config.switch_ir_optim(True) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(cpu_threads) + config.switch_ir_optim() + if use_mkldnn: + config.enable_mkldnn() + if precision == "int8": + config.enable_mkldnn_int8({"conv2d", "transpose2", "pool2d"}) + + precision_map = { + "int8": Config.Precision.Int8, + "fp32": Config.Precision.Float32, + "fp16": Config.Precision.Half, + } + if precision in precision_map.keys() and use_trt: + config.enable_tensorrt_engine( + workspace_size=(1 << 25) * batch_size, + max_batch_size=batch_size, + min_subgraph_size=min_subgraph_size, + precision_mode=precision_map[precision], + use_static=True, + use_calib_mode=False, ) + + if use_dynamic_shape: + dynamic_shape_file = os.path.join(FLAGS.model_path, + "dynamic_shape.txt") + if os.path.exists(dynamic_shape_file): + config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, + True) + print("trt set dynamic shape done!") + else: + config.collect_shape_range_info(dynamic_shape_file) + print("Start collect dynamic shape...") + rerun_flag = True + + # enable shared memory + config.enable_memory_optim() + predictor = create_predictor(config) + return predictor, rerun_flag + + +def eval(predictor, val_loader, anno_file, rerun_flag=False): + """ + eval main func + """ + bboxes_list, bbox_nums_list, image_id_list = [], [], [] + cpu_mems, gpu_mems = 0, 0 + sample_nums = len(val_loader) + predict_time = 0.0 + time_min = float("inf") + time_max = float("-inf") + input_names = predictor.get_input_names() + output_names = predictor.get_output_names() + boxes_tensor = predictor.get_output_handle(output_names[0]) + for batch_id, data in enumerate(val_loader): + data_all = {k: np.array(v) for k, v in data.items()} + inputs = {} + if FLAGS.arch == "YOLOv6": + inputs["x2paddle_image_arrays"] = data_all["image"] + else: + inputs["x2paddle_images"] = data_all["image"] + for i, _ in enumerate(input_names): + input_tensor = predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + start_time = time.time() + predictor.run() + outs = boxes_tensor.copy_to_cpu() + end_time = time.time() + timed = end_time - start_time + time_min = min(time_min, timed) + time_max = max(time_max, timed) + predict_time += timed + if rerun_flag: + return + postprocess = YOLOPostProcess( + score_threshold=0.001, nms_threshold=0.65, multi_label=True) + res = postprocess(np.array(outs), data_all["scale_factor"]) + bboxes_list.append(res["bbox"]) + bbox_nums_list.append(res["bbox_num"]) + image_id_list.append(np.array(data_all["im_id"])) + cpu_mem, gpu_mem = get_current_memory_mb() + cpu_mems += cpu_mem + gpu_mems += gpu_mem + if batch_id % 100 == 0: + print("Eval iter:", batch_id) + sys.stdout.flush() + print("[Benchmark]Avg cpu_mem:{} MB, avg gpu_mem: {} MB".format( + cpu_mems / sample_nums, gpu_mems / sample_nums)) + time_avg = predict_time / sample_nums + print("[Benchmark]Inference time(ms): min={}, max={}, avg={}".format( + round(time_min * 1000, 2), + round(time_max * 1000, 1), round(time_avg * 1000, 1))) + + map_res = coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list) + print("[Benchmark] COCO mAP: {}".format(map_res[0])) + sys.stdout.flush() + + +def infer(predictor): + """ + infer image main func + """ + warmup, 
repeats = 1, 1 + if FLAGS.benchmark: + warmup, repeats = 50, 100 + origin_img = cv2.imread(FLAGS.image_file) + input_image, scale_factor = preprocess(origin_img, + [FLAGS.img_shape, FLAGS.img_shape]) + input_image = np.expand_dims(input_image, axis=0) + scale_factor = np.array([[scale_factor, scale_factor]]) + inputs = {} + if FLAGS.arch == "YOLOv6": + inputs["x2paddle_image_arrays"] = input_image + else: + inputs["x2paddle_images"] = input_image + input_names = predictor.get_input_names() + for i, _ in enumerate(input_names): + input_tensor = predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + + for i in range(warmup): + predictor.run() + + np_boxes = None + predict_time = 0.0 + time_min = float("inf") + time_max = float("-inf") + cpu_mems, gpu_mems = 0, 0 + for i in range(repeats): + start_time = time.time() + predictor.run() + output_names = predictor.get_output_names() + boxes_tensor = predictor.get_output_handle(output_names[0]) + np_boxes = boxes_tensor.copy_to_cpu() + end_time = time.time() + timed = end_time - start_time + time_min = min(time_min, timed) + time_max = max(time_max, timed) + predict_time += timed + cpu_mem, gpu_mem = get_current_memory_mb() + cpu_mems += cpu_mem + gpu_mems += gpu_mem + print("[Benchmark]Avg cpu_mem:{} MB, avg gpu_mem: {} MB".format( + cpu_mems / repeats, gpu_mems / repeats)) + + time_avg = predict_time / repeats + print("[Benchmark]Inference time(ms): min={}, max={}, avg={}".format( + round(time_min * 1000, 2), + round(time_max * 1000, 1), round(time_avg * 1000, 1))) + postprocess = YOLOPostProcess( + score_threshold=0.001, nms_threshold=0.65, multi_label=True) + res = postprocess(np_boxes, scale_factor) + # Draw rectangles and labels on the original image + dets = res["bbox"] + if dets is not None: + final_boxes, final_scores, final_class = dets[:, 2:], dets[:, + 1], dets[:, + 0] + res_img = draw_box( + origin_img, + final_boxes, + final_scores, + final_class, + conf=0.5, + class_names=CLASS_LABEL) + cv2.imwrite("output.jpg", res_img) + print("The prediction results are saved in output.jpg.") + + +def main(): + """ + main func + """ + predictor, rerun_flag = load_predictor( + FLAGS.model_path, + device=FLAGS.device, + use_trt=FLAGS.use_trt, + use_mkldnn=FLAGS.use_mkldnn, + precision=FLAGS.precision, + use_dynamic_shape=FLAGS.use_dynamic_shape, + cpu_threads=FLAGS.cpu_threads, ) + + if FLAGS.image_file: + infer(predictor) + else: + dataset = COCOValDataset( + dataset_dir=FLAGS.dataset_dir, + image_dir=FLAGS.val_image_dir, + anno_path=FLAGS.val_anno_path) + anno_file = dataset.ann_file + val_loader = paddle.io.DataLoader( + dataset, batch_size=FLAGS.batch_size, drop_last=True) + eval(predictor, val_loader, anno_file, rerun_flag=rerun_flag) + + if rerun_flag: + print( + "***** Collect dynamic shape done, Please rerun the program to get correct results. 
*****" + ) + + +if __name__ == "__main__": + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + + # DataLoader need run on cpu + paddle.set_device("cpu") + + main() diff --git a/example/auto_compression/semantic_segmentation/README.md b/example/auto_compression/semantic_segmentation/README.md index 1e25cb75..6adec2f0 100644 --- a/example/auto_compression/semantic_segmentation/README.md +++ b/example/auto_compression/semantic_segmentation/README.md @@ -8,8 +8,7 @@ - [3.2 准备数据集](#32-准备数据集) - [3.3 准备预测模型](#33-准备预测模型) - [3.4 自动压缩并产出模型](#34-自动压缩并产出模型) -- [4.评估精度](#4评估精度) -- [5.预测部署](#5预测部署) +- [4.预测部署](#4预测部署) - [5.FAQ](5FAQ) ## 1.简介 @@ -156,104 +155,68 @@ python -m paddle.distributed.launch run.py --config_path='./configs/pp_humanseg/ 压缩完成后会在`save_dir`中产出压缩好的预测模型,可直接预测部署。 -## 4.评估精度 +## 4.预测部署 -本小节以人像分割模型和小数据集为例, 介绍如何在测试集上评估压缩后的模型. +#### 4.1 Paddle Inference 验证性能 -下载经过量化训练压缩后的推理模型: -``` -wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip -unzip pp_humanseg_qat.zip -``` +量化模型在GPU上可以使用TensorRT进行加速,在CPU上可以使用MKLDNN进行加速。 -通过以下命令下载人像分割示例数据: +以下字段用于配置预测参数: -```shell -cd ./data -python download_data.py mini_humanseg -cd - - -``` +| 参数名 | 含义 | +|:------:|:------:| +| model_path | inference 模型文件所在目录,该目录下需要有文件 .pdmodel 和 .pdiparams 两个文件 | +| model_filename | inference_model_dir文件夹下的模型文件名称 | +| params_filename | inference_model_dir文件夹下的参数文件名称 | +| dataset | 选择数据集的类型,可选:`human`, `cityscape`。 | +| dataset_config | 数据集配置的config | +| image_file | 待测试单张图片的路径,如果设置image_file,则dataset_config将无效。 | +| device | 预测时的设备,可选:`CPU`, `GPU`。 | +| use_trt | 是否使用 TesorRT 预测引擎,在device为```GPU```时生效。 | +| use_mkldnn | 是否启用```MKL-DNN```加速库,注意```use_mkldnn```,在device为```CPU```时生效。 | +| cpu_threads | CPU预测时,使用CPU线程数量,默认10 | +| precision | 预测时精度,可选:`fp32`, `fp16`, `int8`。 | -执行以下命令评估模型在测试集上的精度: - -``` -python eval.py \ ---model_dir ./pp_humanseg_qat \ ---model_filename model.pdmodel \ ---params_filename model.pdiparams \ ---dataset_config configs/dataset/humanseg_dataset.yaml - -``` -## 5.预测部署 +- TensorRT预测: -本小节以人像分割为例, 介绍如何使用Paddle Inference推理库执行压缩后的模型. 
+环境配置:如果使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python) -### 5.1 安装推理库 +准备好预测模型,并且修改dataset_config中数据集路径为正确的路径后,启动测试: -请参考该链接安装Python版本的PaddleInference推理库: [推理库安装教程](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python) - -### 5.2 准备模型和数据 - -从 [2.Benchmark](#2Benchmark) 的表格中获得压缩前后的推理模型的下载链接,执行以下命令下载并解压推理模型: - -下载Float32数值类型的模型: -``` -wget https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz -tar -xzf ppseg_lite_portrait_398x224_with_softmax.tar.gz -mv ppseg_lite_portrait_398x224_with_softmax pp_humanseg_fp32 -``` - -下载经过量化训练压缩后的推理模型: -``` -wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip -unzip pp_humanseg_qat.zip +```shell +python paddle_inference_eval.py \ + --model_path=pp_liteseg_qat \ + --dataset='cityscape' \ + --dataset_config=configs/dataset/cityscapes_1024x512_scale1.0.yml \ + --use_trt=True \ + --precision=int8 ``` -准备好需要处理的图片,这里直接使用人像示例图片 `./data/human_demo.jpg`。 - -### 5.3 执行推理 - -执行以下命令,直接使用飞桨框架的原生推理(仅支持Float32, 无需依赖TensorRT): +- MKLDNN预测: -``` -export CUDA_VISIBLE_DEVICES=0 -python infer.py \ ---image_file "./data/human_demo.jpg" \ ---model_path "./pp_humanseg_fp32/model.pdmodel" \ ---params_path "./pp_humanseg_fp32/model.pdiparams" \ ---save_file "./humanseg_result_fp32.png" \ ---dataset "human" \ ---benchmark True \ ---precision "fp32" +```shell +python paddle_inference_eval.py \ + --model_path=pp_liteseg_qat \ + --dataset='cityscape' \ + --dataset_config=configs/dataset/cityscapes_1024x512_scale1.0.yml \ + --device=CPU \ + --use_mkldnn=True \ + --precision=int8 \ + --cpu_threads=10 ``` -执行以下命令,使用Int8推理: +#### 4.2 Paddle Inference 测试单张图片 -``` -export CUDA_VISIBLE_DEVICES=0 -python infer.py \ ---image_file "./data/human_demo.jpg" \ ---model_path "./pp_humanseg_qat/model.pdmodel" \ ---params_path "./pp_humanseg_qat/model.pdiparams" \ ---save_file "./humanseg_result_qat.png" \ ---dataset "human" \ ---benchmark True \ ---use_trt True \ ---precision "int8" -``` +利用人像分割测试单张图片: -执行以下命令,使用Paddle Inference在相应数据集上测试精度: - -``` -export CUDA_VISIBLE_DEVICES=0 -python infer.py \ ---model_path "./pp_humanseg_qat/model.pdmodel" \ ---params_path "./pp_humanseg_qat/model.pdiparams" \ ---dataset_config configs/dataset/humanseg_dataset.yaml \ ---use_trt True \ ---precision "int8" +```shell +python paddle_inference_eval.py \ + --model_path=pp_humanseg_qat \ + --dataset='human' \ + --image_file=./data/human_demo.jpg \ + --use_trt=True \ + --precision=int8 ``` @@ -287,17 +250,11 @@ Int8推理结果
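+**注意**:使用 TensorRT 预测引擎时,首次运行脚本只会收集动态 shape 并保存到 model_path 下的 dynamic_shape.txt,随后提示重新运行;再次执行相同命令即可得到真实的精度和耗时结果。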
-执行以下命令查看更多关于 `infer.py` 使用说明: - -``` -python infer.py --help -``` - -### 5.4 更多部署教程 +### 4.3 更多部署教程 - [Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/python_inference.md) - [Paddle Inference C++部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/cpp_inference.md) - [Paddle Lite部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/lite/lite.md) -## 6.FAQ +## 5.FAQ diff --git a/example/auto_compression/semantic_segmentation/infer.py b/example/auto_compression/semantic_segmentation/paddle_inference_eval.py similarity index 53% rename from example/auto_compression/semantic_segmentation/infer.py rename to example/auto_compression/semantic_segmentation/paddle_inference_eval.py index f806b576..f9066389 100644 --- a/example/auto_compression/semantic_segmentation/infer.py +++ b/example/auto_compression/semantic_segmentation/paddle_inference_eval.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import numpy as np import argparse import time -from tqdm import tqdm +import os +import sys +import cv2 +import numpy as np import paddle import paddleseg.transforms as T from paddleseg.cvlibs import Config as PaddleSegDataConfig @@ -38,62 +39,72 @@ def _transforms(dataset): return transforms -def auto_tune_trt(args, data): - auto_tuned_shape_file = "./auto_tuning_shape" - pred_cfg = PredictConfig(args.model_path, args.params_path) - pred_cfg.enable_use_gpu(100, 0) - pred_cfg.collect_shape_range_info("./auto_tuning_shape") - predictor = create_predictor(pred_cfg) - input_names = predictor.get_input_names() - input_handle = predictor.get_input_handle(input_names[0]) - input_handle.reshape(data.shape) - input_handle.copy_from_cpu(data) - predictor.run() - return auto_tuned_shape_file - - -def load_predictor(args, data): - pred_cfg = PredictConfig(args.model_path, args.params_path) - pred_cfg.disable_glog_info() +def load_predictor(args): + """ + load predictor func + """ + rerun_flag = False + model_file = os.path.join(args.model_path, args.model_filename) + params_file = os.path.join(args.model_path, args.params_filename) + pred_cfg = PredictConfig(model_file, params_file) pred_cfg.enable_memory_optim() pred_cfg.switch_ir_optim(True) if args.device == "GPU": pred_cfg.enable_use_gpu(100, 0) + else: + pred_cfg.disable_gpu() + pred_cfg.set_cpu_math_library_num_threads(args.cpu_threads) + if args.use_mkldnn: + pred_cfg.enable_mkldnn() + if args.precision == "int8": + pred_cfg.enable_mkldnn_int8({ + "conv2d", "depthwise_conv2d", "pool2d", "elementwise_mul" + }) if args.use_trt: # To collect the dynamic shapes of inputs for TensorRT engine - auto_tuned_shape_file = auto_tune_trt(args, data) - precision_map = { - "fp16": PrecisionType.Half, - "fp32": PrecisionType.Float32, - "int8": PrecisionType.Int8 - } - pred_cfg.enable_tensorrt_engine( - workspace_size=1 << 30, - max_batch_size=1, - min_subgraph_size=4, - precision_mode=precision_map[args.precision], - use_static=False, - use_calib_mode=False) - allow_build_at_runtime = True - pred_cfg.enable_tuned_tensorrt_dynamic_shape(auto_tuned_shape_file, - allow_build_at_runtime) + dynamic_shape_file = os.path.join(args.model_path, "dynamic_shape.txt") + if os.path.exists(dynamic_shape_file): + pred_cfg.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, + True) + print("trt set dynamic shape done!") + precision_map = { + "fp16": PrecisionType.Half, + 
"fp32": PrecisionType.Float32, + "int8": PrecisionType.Int8 + } + pred_cfg.enable_tensorrt_engine( + workspace_size=1 << 30, + max_batch_size=1, + min_subgraph_size=4, + precision_mode=precision_map[args.precision], + use_static=True, + use_calib_mode=False, ) + else: + pred_cfg.disable_gpu() + pred_cfg.set_cpu_math_library_num_threads(10) + pred_cfg.collect_shape_range_info(dynamic_shape_file) + print("Start collect dynamic shape...") + rerun_flag = True + predictor = create_predictor(pred_cfg) - return predictor + return predictor, rerun_flag def predict_image(args): - + """ + predict image func + """ transforms = _transforms(args.dataset) transform = T.Compose(transforms) # Step1: Load image and preprocess - im = cv2.imread(args.image_file).astype('float32') + im = cv2.imread(args.image_file).astype("float32") data, _ = transform(im) data = np.array(data)[np.newaxis, :] # Step2: Prepare prdictor - predictor = load_predictor(args, data) + predictor, rerun_flag = load_predictor(args) # Step3: Inference input_names = predictor.get_input_names() @@ -114,14 +125,21 @@ def predict_image(args): for i in range(repeats): predictor.run() results = output_handle.copy_to_cpu() + if rerun_flag: + print( + "***** Collect dynamic shape done, Please rerun the program to get correct results. *****" + ) + return total_time = time.time() - start_time avg_time = float(total_time) / repeats - print(f"Average inference time: \033[91m{round(avg_time*1000, 2)}ms\033[0m") + print( + f"[Benchmark]Average inference time: \033[91m{round(avg_time*1000, 2)}ms\033[0m" + ) # Step4: Post process if args.dataset == "human": results = reverse_transform( - paddle.to_tensor(results), im.shape, transforms, mode='bilinear') + paddle.to_tensor(results), im.shape, transforms, mode="bilinear") results = np.argmax(results, axis=1) result = get_pseudo_color_map(results[0]) @@ -132,8 +150,11 @@ def predict_image(args): def eval(args): + """ + eval mIoU func + """ # DataLoader need run on cpu - paddle.set_device('cpu') + paddle.set_device("cpu") data_cfg = PaddleSegDataConfig(args.dataset_config) eval_dataset = data_cfg.val_dataset @@ -142,48 +163,56 @@ def eval(args): loader = paddle.io.DataLoader( eval_dataset, batch_sampler=batch_sampler, - num_workers=1, + num_workers=0, return_list=True) - total_iters = len(loader) + predictor, rerun_flag = load_predictor(args) + intersect_area_all = 0 pred_area_all = 0 label_area_all = 0 - - print("Start evaluating (total_samples: {}, total_iters: {})...".format( - len(eval_dataset), total_iters)) - - init_predictor = False - for (image, label) in tqdm(loader): - label = np.array(label).astype('int64') + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + total_samples = len(eval_dataset) + sample_nums = len(loader) + batch_size = int(total_samples / sample_nums) + predict_time = 0.0 + time_min = float("inf") + time_max = float("-inf") + print("Start evaluating (total_samples: {}, total_iters: {}).".format( + total_samples, sample_nums)) + for batch_id, data in enumerate(loader): + image = np.array(data[0]) + label = np.array(data[1]).astype("int64") ori_shape = np.array(label).shape[-2:] - data = np.array(image) - - if not init_predictor: - predictor = load_predictor(args, data) - init_predictor = True - - input_names = predictor.get_input_names() - input_handle = predictor.get_input_handle(input_names[0]) - 
input_handle.reshape(data.shape) - input_handle.copy_from_cpu(data) - + input_handle.reshape(image.shape) + input_handle.copy_from_cpu(image) + start_time = time.time() predictor.run() - - output_names = predictor.get_output_names() - output_handle = predictor.get_output_handle(output_names[0]) results = output_handle.copy_to_cpu() + end_time = time.time() + timed = end_time - start_time + time_min = min(time_min, timed) + time_max = max(time_max, timed) + predict_time += timed + if rerun_flag: + print( + "***** Collect dynamic shape done, Please rerun the program to get correct results. *****" + ) + return logit = reverse_transform( paddle.to_tensor(results), ori_shape, eval_dataset.transforms.transforms, - mode='bilinear') + mode="bilinear") pred = paddle.to_tensor(logit) if len( pred.shape ) == 4: # for humanseg model whose prediction is distribution but not class id - pred = paddle.argmax(pred, axis=1, keepdim=True, dtype='int32') + pred = paddle.argmax(pred, axis=1, keepdim=True, dtype="int32") intersect_area, pred_area, label_area = metrics.calculate_area( pred, @@ -193,71 +222,95 @@ def eval(args): intersect_area_all = intersect_area_all + intersect_area pred_area_all = pred_area_all + pred_area label_area_all = label_area_all + label_area + if batch_id % 100 == 0: + print("Eval iter:", batch_id) + sys.stdout.flush() - class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all, - label_area_all) - class_acc, acc = metrics.accuracy(intersect_area_all, pred_area_all) + _, miou = metrics.mean_iou(intersect_area_all, pred_area_all, + label_area_all) + _, acc = metrics.accuracy(intersect_area_all, pred_area_all) kappa = metrics.kappa(intersect_area_all, pred_area_all, label_area_all) - class_dice, mdice = metrics.dice(intersect_area_all, pred_area_all, - label_area_all) - - infor = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format( - len(eval_dataset), miou, acc, kappa, mdice) + _, mdice = metrics.dice(intersect_area_all, pred_area_all, label_area_all) + + time_avg = predict_time / sample_nums + print( + "[Benchmark]Batch size: {}, Inference time(ms): min={}, max={}, avg={}". 
+ format(batch_size, + round(time_min * 1000, 2), + round(time_max * 1000, 1), round(time_avg * 1000, 1))) + infor = "[Benchmark] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format( + total_samples, miou, acc, kappa, mdice) print(infor) + sys.stdout.flush() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - '--image_file', + "--model_path", type=str, help="inference model filepath") + parser.add_argument( + "--model_filename", + type=str, + default="model.pdmodel", + help="model file name") + parser.add_argument( + "--params_filename", + type=str, + default="model.pdiparams", + help="params file name") + parser.add_argument( + "--image_file", type=str, default=None, help="Image path to be processed.") parser.add_argument( - '--save_file', + "--save_file", type=str, default=None, help="The path to save the processed image.") parser.add_argument( - '--model_path', type=str, help="Inference model filepath.") - parser.add_argument( - '--params_path', type=str, help="Inference parameters filepath.") - parser.add_argument( - '--dataset', + "--dataset", type=str, default="human", choices=["human", "cityscape"], - help="The type of given image which can be 'human' or 'cityscape'.") + help="The type of given image which can be 'human' or 'cityscape'.", ) parser.add_argument( - '--dataset_config', + "--dataset_config", type=str, default=None, help="path of dataset config.") parser.add_argument( - '--benchmark', + "--benchmark", type=bool, default=False, help="Whether to run benchmark or not.") parser.add_argument( - '--use_trt', + "--use_trt", type=bool, default=False, help="Whether to use tensorrt engine or not.") parser.add_argument( - '--device', + "--device", type=str, - default='GPU', + default="GPU", choices=["CPU", "GPU"], - help="Choose the device you want to run, it can be: CPU/GPU, default is GPU" + help="Choose the device you want to run, it can be: CPU/GPU, default is GPU", ) parser.add_argument( - '--precision', + "--precision", type=str, - default='fp32', + default="fp32", choices=["fp32", "fp16", "int8"], - help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'." + help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'.", ) + parser.add_argument( + "--use_mkldnn", + type=bool, + default=False, + help="Whether use mkldnn or not.") + parser.add_argument( + "--cpu_threads", type=int, default=1, help="Num of cpu threads.") args = parser.parse_args() if args.image_file: predict_image(args) diff --git a/paddleslim/common/load_model.py b/paddleslim/common/load_model.py index f4f3be28..2f31f97d 100644 --- a/paddleslim/common/load_model.py +++ b/paddleslim/common/load_model.py @@ -230,7 +230,6 @@ def export_onnx(model_dir, opset_version=opset_version, enable_onnx_checker=True, deploy_backend=deploy_backend, - scale_file=os.path.join(model_dir, 'calibration_table.txt'), calibration_file=os.path.join( save_file_path.rstrip(os.path.split(save_file_path)[-1]), 'calibration.cache')) -- GitLab