From 6cee8749c083a852c53ad51817518fc8d65a9fc0 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Mon, 22 Aug 2022 13:40:56 +0800 Subject: [PATCH] update YOLO series TRT eval demo (#1374) --- .../pytorch_yolo_series/README.md | 46 ++- .../pytorch_yolo_series/TensorRT/README.md | 41 ++ .../TensorRT/trt_backend.py | 264 ++++++++++++ .../pytorch_yolo_series/TensorRT/trt_eval.py | 321 +++++++++++++++ .../pytorch_yolo_series/onnx_trt_infer.py | 378 ------------------ 5 files changed, 656 insertions(+), 394 deletions(-) create mode 100644 example/auto_compression/pytorch_yolo_series/TensorRT/README.md create mode 100644 example/auto_compression/pytorch_yolo_series/TensorRT/trt_backend.py create mode 100644 example/auto_compression/pytorch_yolo_series/TensorRT/trt_eval.py delete mode 100644 example/auto_compression/pytorch_yolo_series/onnx_trt_infer.py diff --git a/example/auto_compression/pytorch_yolo_series/README.md b/example/auto_compression/pytorch_yolo_series/README.md index fbbd79d4..052daf27 100644 --- a/example/auto_compression/pytorch_yolo_series/README.md +++ b/example/auto_compression/pytorch_yolo_series/README.md @@ -18,23 +18,23 @@ ## 2.Benchmark -| 模型 | 策略 | 输入尺寸 | mAPval
0.5:0.95 | 模型体积 | 预测时延FP32 | 预测时延FP16 | 预测时延INT8
| 配置文件 | Inference模型 | -| :-------- |:-------- |:--------: | :--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | -| YOLOv5s | Base模型 | 640*640 | 37.4 | 28.1MB | 5.95ms | 2.44ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) | -| YOLOv5s | 离线量化 | 640*640 | 36.0 | 7.4MB | - | - | 1.87ms | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | -| YOLOv5s | ACT量化训练 | 640*640 | **36.9** | 7.4MB | - | - | **1.87ms** | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.onnx) | +| 模型 | 策略 | 输入尺寸 | mAPval
0.5:0.95 | 模型体积 | 预测时延FP32 | 预测时延FP16 | 预测时延INT8
| 内存占用 | 显存占用 | 配置文件 | Inference模型 | +| :-------- |:-------- |:--------: | :--------: | :---------------------: | :----------------: | :----------------: |:----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | +| YOLOv5s | Base模型 | 640*640 | 37.4 | 28.1MB | 5.95ms | 2.44ms | - | 1718MB | 705MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) | +| YOLOv5s | 离线量化 | 640*640 | 36.0 | 7.4MB | - | - | 1.87ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | +| YOLOv5s | ACT量化训练 | 640*640 | **36.9** | 7.4MB | - | - | **1.87ms** | 736MB | 315MB | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant_onnx.tar) | | | | | | | | | | | -| YOLOv6s | Base模型 | 640*640 | 42.4 | 65.9MB | 9.06ms | 2.90ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) | -| YOLOv6s | KL离线量化 | 640*640 | 30.3 | 16.8MB | - | - | 1.83ms | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | -| YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | 16.8MB | - | - | **1.83ms** | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.onnx) | +| YOLOv6s | Base模型 | 640*640 | 42.4 | 65.9MB | 9.06ms | 2.90ms | - | 1208MB | 555MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) | +| YOLOv6s | KL离线量化 | 640*640 | 30.3 | 16.8MB | - | - | 1.83ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | +| YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | 16.8MB | - | - | **1.83ms** | 736MB | 315MB | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant_onnx.tar) | | | | | | | | | | | -| YOLOv7 | Base模型 | 640*640 | 51.1 | 141MB | 26.84ms | 7.44ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) | -| YOLOv7 | 离线量化 | 640*640 | 50.2 | 36MB | - | - | 4.55ms | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | -| YOLOv7 | ACT量化训练 | 640*640 | **50.9** | 36MB | - | - | **4.55ms** | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.onnx) | +| YOLOv7 | Base模型 | 640*640 | 51.1 | 141MB | 26.84ms | 7.44ms | - | 1722MB | 917MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) | +| YOLOv7 | 离线量化 | 640*640 | 50.2 | 36MB | - | - | 4.55ms | 827MB | 363MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | +| YOLOv7 | ACT量化训练 | 640*640 | **50.9** | 36MB | - | - | **4.55ms** | 827MB | 363MB | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant_onnx.tar) | | | | | | | 
| | | |
-| YOLOv7-Tiny | Base模型 | 640*640 | 37.3 | 24MB | 5.06ms | 2.32ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx) |
-| YOLOv7-Tiny | 离线量化 | 640*640 | 35.8 | 6.1MB | - | - | 1.68ms | - | - |
-| YOLOv7-Tiny | ACT量化训练 | 640*640 | **37.0** | 6.1MB | - | - | **1.68ms** | [config](./configs/yolov7_tiny_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.onnx) |
+| YOLOv7-Tiny | Base模型 | 640*640 | 37.3 | 24MB | 5.06ms | 2.32ms | - | 738MB | 349MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx) |
+| YOLOv7-Tiny | 离线量化 | 640*640 | 35.8 | 6.1MB | - | - | 1.68ms | 729MB | 315MB | - | - |
+| YOLOv7-Tiny | ACT量化训练 | 640*640 | **37.0** | 6.1MB | - | - | **1.68ms** | 729MB | 315MB | [config](./configs/yolov7_tiny_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant_onnx.tar) |
 说明:
 - mAP的指标均在COCO val2017数据集中评测得到。
@@ -136,13 +136,27 @@ python eval.py --config_path=./configs/yolov7_tiny_qat_dis.yaml
 ## 4.预测部署
 
+执行完自动压缩后会生成:
+```shell
+├── model.pdiparams         # Paddle预测模型权重
+├── model.pdmodel           # Paddle预测模型文件
+├── calibration_table.txt   # Paddle量化后校准表
+├── ONNX
+│   ├── quant_model.onnx    # 量化后转出的ONNX模型
+│   ├── calibration.cache   # TensorRT可以直接加载的校准表
+```
+
 #### 导出至ONNX使用TensorRT部署
-执行完自动压缩后会默认在`save_dir`中生成`quant_model.onnx`的ONNX模型文件，可以直接使用TensorRT测试脚本进行验证。
+加载`quant_model.onnx`和`calibration.cache`，可以直接使用TensorRT测试脚本进行验证，详细代码可参考[./TensorRT](./TensorRT)。
 - 进行测试:
 ```shell
-python yolov7_onnx_trt.py --model_path=output/quant_model.onnx --image_file=images/000000570688.jpg --precision=int8
+cd TensorRT
+python trt_eval.py --onnx_model_file=output/ONNX/quant_model.onnx \
+                   --calibration_file=output/ONNX/calibration.cache \
+                   --image_file=../images/000000570688.jpg \
+                   --precision_mode=int8
 ```
 
 #### Paddle-TensorRT部署
diff --git a/example/auto_compression/pytorch_yolo_series/TensorRT/README.md b/example/auto_compression/pytorch_yolo_series/TensorRT/README.md
new file mode 100644
index 00000000..a755cb4d
--- /dev/null
+++ b/example/auto_compression/pytorch_yolo_series/TensorRT/README.md
@@ -0,0 +1,41 @@
+# TensorRT Python预测
+
+### 验证COCO mAP
+
+- FP16
+```shell
+python trt_eval.py --onnx_model_file=yolov7_tiny_quant_onnx/yolov7-tiny.onnx \
+                   --precision_mode=fp16 \
+                   --dataset_dir=dataset/coco/ \
+                   --val_image_dir=val2017 \
+                   --val_anno_path=annotations/instances_val2017.json
+```
+
+- INT8
+```shell
+python trt_eval.py --onnx_model_file=yolov7_tiny_quant_onnx/yolov7_tiny_quant.onnx \
+                   --calibration_file=yolov7_tiny_quant_onnx/calibration.cache \
+                   --precision_mode=int8 \
+                   --dataset_dir=dataset/coco/ \
+                   --val_image_dir=val2017 \
+                   --val_anno_path=annotations/instances_val2017.json
+```
+
+### 验证单张图片
+
+- FP16
+```shell
+python trt_eval.py --onnx_model_file=yolov7-tiny.onnx --image_file=../images/000000570688.jpg --precision_mode=fp16
+```
+
+- INT8
+```shell
+python trt_eval.py --onnx_model_file=yolov7_tiny_quant_onnx/yolov7_tiny_quant.onnx \
+                   --calibration_file=yolov7_tiny_quant_onnx/calibration.cache \
+                   --image_file=../images/000000570688.jpg \
+                   --precision_mode=int8
+```
+
+### FAQ
+
+- 测试内存和显存占用时，首次运行会将ONNX模型转换成TRT模型，耗时不准确，再次运行trt_eval.py可获取真实的内存和显存占用。
diff --git a/example/auto_compression/pytorch_yolo_series/TensorRT/trt_backend.py b/example/auto_compression/pytorch_yolo_series/TensorRT/trt_backend.py
new
file mode 100644 index 00000000..614884cf --- /dev/null +++ b/example/auto_compression/pytorch_yolo_series/TensorRT/trt_backend.py @@ -0,0 +1,264 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorrt as trt +import pycuda.driver as cuda +import pycuda.autoinit +import sys +import os +import copy +import numpy as np + +EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) +EXPLICIT_PRECISION = 1 << ( + int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION) + + +class LoadCalibrator(trt.IInt8EntropyCalibrator2): + def __init__(self, cache_file="calibration.cache"): + super().__init__() + self.cache_file = cache_file + + def get_batch_size(self): + return 1 + + def read_calibration_cache(self): + # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None. + if os.path.exists(self.cache_file): + with open(self.cache_file, "rb") as f: + print("Using calibration cache to save time: {:}".format( + self.cache_file)) + return f.read() + + +def remove_initializer_from_input(ori_model): + model = copy.deepcopy(ori_model) + if model.ir_version < 4: + print( + 'Model with ir_version below 4 requires to include initilizer in graph input' + ) + return + + inputs = model.graph.input + name_to_input = {} + for input in inputs: + name_to_input[input.name] = input + + for initializer in model.graph.initializer: + if initializer.name in name_to_input: + inputs.remove(name_to_input[initializer.name]) + return model + + +# Simple helper data class that's a little nicer to use than a 2-tuple. +class HostDeviceMem(object): + def __init__(self, host_mem, device_mem): + self.host = host_mem + self.device = device_mem + if host_mem: + self.nbytes = host_mem.nbytes + else: + self.nbytes = 0 + + def __str__(self): + return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) + + def __repr__(self): + return self.__str__() + + +class TrtEngine: + def __init__(self, + onnx_model_file, + shape_info=None, + max_batch_size=None, + precision_mode="fp32", + engine_file_path=None, + calibration_cache_file="calibration.cache", + verbose=False): + self.max_batch_size = 1 if max_batch_size is None else max_batch_size + precision_mode = precision_mode.lower() + assert precision_mode in [ + "fp32", "fp16", "int8" + ], "precision_mode must be fp32, fp16 or int8, but your precision_mode is: {}".format( + precision_mode) + use_int8 = precision_mode == "int8" + use_fp16 = precision_mode == "fp16" + TRT_LOGGER = trt.Logger() + if verbose: + TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) + if engine_file_path is not None and os.path.exists(engine_file_path): + # If a serialized engine exists, use it instead of building an engine. 
+ print("[TRT Backend] Reading engine from file {}".format( + engine_file_path)) + with open(engine_file_path, + "rb") as f, trt.Runtime(TRT_LOGGER) as runtime: + self.engine = runtime.deserialize_cuda_engine(f.read()) + else: + builder = trt.Builder(TRT_LOGGER) + config = builder.create_builder_config() + network = None + + if use_int8 and not builder.platform_has_fast_int8: + print("[TRT Backend] INT8 not supported on this platform.") + if use_fp16 and not builder.platform_has_fast_fp16: + print("[TRT Backend] FP16 not supported on this platform.") + + if use_int8 and builder.platform_has_fast_int8: + print("[TRT Backend] Use INT8.") + network = builder.create_network(EXPLICIT_BATCH | + EXPLICIT_PRECISION) + config.int8_calibrator = LoadCalibrator(calibration_cache_file) + config.set_flag(trt.BuilderFlag.INT8) + elif use_fp16 and builder.platform_has_fast_fp16: + print("[TRT Backend] Use FP16.") + network = builder.create_network(EXPLICIT_BATCH) + config.set_flag(trt.BuilderFlag.FP16) + else: + print("[TRT Backend] Use FP32.") + network = builder.create_network(EXPLICIT_BATCH) + parser = trt.OnnxParser(network, TRT_LOGGER) + runtime = trt.Runtime(TRT_LOGGER) + config.max_workspace_size = 1 << 28 + + import onnx + print("[TRT Backend] Loading ONNX model ...") + onnx_model = onnx_model_file + if not isinstance(onnx_model_file, onnx.ModelProto): + onnx_model = onnx.load(onnx_model_file) + onnx_model = remove_initializer_from_input(onnx_model) + if not parser.parse(onnx_model.SerializeToString()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise Exception("ERROR: Failed to parse the ONNX file.") + + if shape_info is None: + builder.max_batch_size = 1 + for i in range(len(onnx_model.graph.input)): + input_shape = [ + x.dim_value + for x in onnx_model.graph.input[0] + .type.tensor_type.shape.dim + ] + for s in input_shape: + assert s > 0, "In static shape mode, the input of onnx model should be fixed, but now it's {}".format( + onnx_model.graph.input[i]) + else: + max_batch_size = 1 + if shape_info is not None: + assert len( + shape_info + ) == network.num_inputs, "Length of shape_info: {} is not same with length of model input: {}".format( + len(shape_info), network.num_inputs) + profile = builder.create_optimization_profile() + for k, v in shape_info.items(): + if v[2][0] > max_batch_size: + max_batch_size = v[2][0] + print("[TRT Backend] optimize shape: ", k, v[0], v[1], + v[2]) + profile.set_shape(k, v[0], v[1], v[2]) + config.add_optimization_profile(profile) + if max_batch_size > self.max_batch_size: + self.max_batch_size = max_batch_size + builder.max_batch_size = self.max_batch_size + + print("[TRT Backend] Completed parsing of ONNX file.") + print( + "[TRT Backend] Building an engine from onnx model may take a while..." 
+ ) + plan = builder.build_serialized_network(network, config) + print("[TRT Backend] Start Creating Engine.") + self.engine = runtime.deserialize_cuda_engine(plan) + print("[TRT Backend] Completed Creating Engine.") + if engine_file_path is not None: + with open(engine_file_path, "wb") as f: + f.write(self.engine.serialize()) + + self.context = self.engine.create_execution_context() + if shape_info is not None: + self.context.active_optimization_profile = 0 + self.stream = cuda.Stream() + self.bindings = [] + self.inputs = [] + self.outputs = [] + for binding in self.engine: + self.bindings.append(0) + if self.engine.binding_is_input(binding): + self.inputs.append(HostDeviceMem(None, None)) + else: + self.outputs.append(HostDeviceMem(None, None)) + + print("[TRT Backend] Completed TrtEngine init ...") + + def infer(self, input_data): + assert len(self.inputs) == len( + input_data + ), "Length of input_data: {} is not same with length of input: {}".format( + len(input_data), len(self.inputs)) + + self.allocate_buffers(input_data) + + return self.do_inference_v2( + self.context, + bindings=self.bindings, + inputs=self.inputs, + outputs=self.outputs, + stream=self.stream) + + def do_inference_v2(self, context, bindings, inputs, outputs, stream): + # Transfer input data to the GPU. + [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] + # Run inference. + context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) + # Transfer predictions back from the GPU. + [ + cuda.memcpy_dtoh_async(out.host, out.device, stream) + for out in outputs + ] + # Synchronize the stream + stream.synchronize() + # Return only the host outputs. + return [out.host for out in outputs] + + def allocate_buffers(self, input_data): + input_idx = 0 + output_idx = 0 + for binding in self.engine: + idx = self.engine.get_binding_index(binding) + if self.engine.binding_is_input(binding): + if not input_data[input_idx].flags['C_CONTIGUOUS']: + input_data[input_idx] = np.ascontiguousarray(input_data[ + input_idx]) + self.context.set_binding_shape(idx, + (input_data[input_idx].shape)) + self.inputs[input_idx].host = input_data[input_idx] + nbytes = input_data[input_idx].nbytes + if self.inputs[input_idx].nbytes < nbytes: + self.inputs[input_idx].nbytes = nbytes + self.inputs[input_idx].device = cuda.mem_alloc(nbytes) + self.bindings[idx] = int(self.inputs[input_idx].device) + input_idx += 1 + else: + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + shape = self.context.get_binding_shape(idx) + self.outputs[output_idx].host = np.ascontiguousarray( + np.empty( + shape, dtype=dtype)) + nbytes = self.outputs[output_idx].host.nbytes + if self.outputs[output_idx].nbytes < nbytes: + self.outputs[output_idx].nbytes = nbytes + self.outputs[output_idx].device = cuda.mem_alloc( + self.outputs[output_idx].host.nbytes) + self.bindings[idx] = int(self.outputs[output_idx].device) + output_idx += 1 diff --git a/example/auto_compression/pytorch_yolo_series/TensorRT/trt_eval.py b/example/auto_compression/pytorch_yolo_series/TensorRT/trt_eval.py new file mode 100644 index 00000000..cacbe7b4 --- /dev/null +++ b/example/auto_compression/pytorch_yolo_series/TensorRT/trt_eval.py @@ -0,0 +1,321 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +warnings.filterwarnings("ignore") +import os +import sys +import numpy as np +import argparse +from tqdm import tqdm +import pkg_resources as pkg +import time +import cv2 + +import paddle +import onnx + +sys.path.append("../") +from post_process import YOLOPostProcess, coco_metric +from dataset import COCOValDataset +import trt_backend + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--onnx_model_file', + type=str, + default='yolov7_tiny_quant_onnx/quant_model.onnx', + help="onnx model file path.") + parser.add_argument( + '--calibration_file', + type=str, + default='yolov7_tiny_quant_onnx/calibration.cache', + help="quant onnx model calibration cache file.") + parser.add_argument( + '--image_file', + type=str, + default=None, + help="image path, if set image_file, it will not eval coco.") + parser.add_argument( + '--dataset_dir', + type=str, + default='dataset/coco', + help="COCO dataset dir.") + parser.add_argument( + '--val_image_dir', + type=str, + default='val2017', + help="COCO dataset val image dir.") + parser.add_argument( + '--val_anno_path', + type=str, + default='annotations/instances_val2017.json', + help="COCO dataset anno path.") + parser.add_argument( + '--precision_mode', + type=str, + default='fp32', + help="support fp32/fp16/int8.") + parser.add_argument( + '--batch_size', type=int, default=1, help="Batch size of model input.") + + return parser + + +# load coco labels +CLASS_LABEL = [ + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", + "truck", "boat", "traffic light", "fire hydrant", "stop sign", + "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", + "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", + "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", + "baseball bat", "baseball glove", "skateboard", "surfboard", + "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", + "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", + "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", + "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", + "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", + "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", + "hair drier", "toothbrush" +] + + +def preprocess(image, input_size, mean=None, std=None, swap=(2, 0, 1)): + if len(image.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0 + else: + padded_img = np.ones(input_size) * 114.0 + img = np.array(image) + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, ).astype(np.float32) + padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img + + padded_img = padded_img[:, :, ::-1] + padded_img /= 255.0 + if mean is not None: + padded_img -= mean + if std is not None: + padded_img /= std + padded_img = padded_img.transpose(swap) + padded_img = 
np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + + +def get_color_map_list(num_classes): + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] + return color_map + + +def draw_box(img, boxes, scores, cls_ids, conf=0.5, class_names=None): + color_list = get_color_map_list(len(class_names)) + for i in range(len(boxes)): + box = boxes[i] + cls_id = int(cls_ids[i]) + color = tuple(color_list[cls_id]) + score = scores[i] + if score < conf: + continue + x0 = int(box[0]) + y0 = int(box[1]) + x1 = int(box[2]) + y1 = int(box[3]) + + text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100) + font = cv2.FONT_HERSHEY_SIMPLEX + + txt_size = cv2.getTextSize(text, font, 0.4, 1)[0] + cv2.rectangle(img, (x0, y0), (x1, y1), color, 2) + cv2.rectangle(img, (x0, y0 + 1), + (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])), + color, -1) + cv2.putText( + img, + text, (x0, y0 + txt_size[1]), + font, + 0.8, (0, 255, 0), + thickness=2) + + return img + + +def get_current_memory_mb(): + """ + It is used to Obtain the memory usage of the CPU and GPU during the running of the program. + And this function Current program is time-consuming. + """ + try: + pkg.require('pynvml') + except: + from pip._internal import main + main(['install', 'pynvml']) + try: + pkg.require('psutil') + except: + from pip._internal import main + main(['install', 'psutil']) + try: + pkg.require('GPUtil') + except: + from pip._internal import main + main(['install', 'GPUtil']) + import pynvml + import psutil + import GPUtil + gpu_id = int(os.environ.get('CUDA_VISIBLE_DEVICES', 0)) + + pid = os.getpid() + p = psutil.Process(pid) + info = p.memory_full_info() + cpu_mem = info.uss / 1024. / 1024. + gpu_mem = 0 + gpu_percent = 0 + gpus = GPUtil.getGPUs() + if gpu_id is not None and len(gpus) > 0: + gpu_percent = gpus[gpu_id].load + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) + gpu_mem = meminfo.used / 1024. / 1024. 
+ return round(cpu_mem, 4), round(gpu_mem, 4) + + +def load_trt_engine(): + model = onnx.load(FLAGS.onnx_model_file) + model_name = os.path.split(FLAGS.onnx_model_file)[-1].rstrip('.onnx') + if FLAGS.precision_mode == "int8": + engine_file = "{}_quant_model.trt".format(model_name) + assert os.path.exists(FLAGS.calibration_file) + trt_engine = trt_backend.TrtEngine( + model, + max_batch_size=1, + precision_mode=FLAGS.precision_mode, + engine_file_path=engine_file, + calibration_cache_file=FLAGS.calibration_file) + else: + engine_file = "{}_{}_model.trt".format(model_name, FLAGS.precision_mode) + trt_engine = trt_backend.TrtEngine( + model, + max_batch_size=1, + precision_mode=FLAGS.precision_mode, + engine_file_path=engine_file) + return trt_engine + + +def eval(): + trt_engine = load_trt_engine() + bboxes_list, bbox_nums_list, image_id_list = [], [], [] + cpu_mems, gpu_mems = 0, 0 + sample_nums = len(val_loader) + with tqdm( + total=sample_nums, + bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for data in val_loader: + data_all = {k: np.array(v) for k, v in data.items()} + outs = trt_engine.infer([data_all['image']]) + outs = np.array(outs).reshape(1, -1, 85) + postprocess = YOLOPostProcess( + score_threshold=0.001, nms_threshold=0.65, multi_label=True) + res = postprocess(np.array(outs), data_all['scale_factor']) + bboxes_list.append(res['bbox']) + bbox_nums_list.append(res['bbox_num']) + image_id_list.append(np.array(data_all['im_id'])) + cpu_mem, gpu_mem = get_current_memory_mb() + cpu_mems += cpu_mem + gpu_mems += gpu_mem + t.update() + print('Avg cpu_mem:{} MB, avg gpu_mem: {} MB'.format( + cpu_mems / sample_nums, gpu_mems / sample_nums)) + + coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list) + + +def infer(): + origin_img = cv2.imread(FLAGS.image_file) + input_shape = [640, 640] + input_image, scale_factor = preprocess(origin_img, input_shape) + input_image = np.expand_dims(input_image, axis=0) + scale_factor = np.array([[scale_factor, scale_factor]]) + trt_engine = load_trt_engine() + + repeat = 100 + cpu_mems, gpu_mems = 0, 0 + for _ in range(0, repeat): + outs = trt_engine.infer(input_image) + cpu_mem, gpu_mem = get_current_memory_mb() + cpu_mems += cpu_mem + gpu_mems += gpu_mem + print('Avg cpu_mem:{} MB, avg gpu_mem: {} MB'.format(cpu_mems / repeat, + gpu_mems / repeat)) + # Do postprocess + outs = np.array(outs).reshape(1, -1, 85) + postprocess = YOLOPostProcess( + score_threshold=0.1, nms_threshold=0.45, multi_label=False) + res = postprocess(np.array(outs), scale_factor) + # Draw rectangles and labels on the original image + dets = res['bbox'] + if dets is not None: + final_boxes, final_scores, final_class = dets[:, 2:], dets[:, + 1], dets[:, + 0] + origin_img = draw_box( + origin_img, + final_boxes, + final_scores, + final_class, + conf=0.5, + class_names=CLASS_LABEL) + cv2.imwrite('output.jpg', origin_img) + print('The prediction results are saved in output.jpg.') + + +def main(): + if FLAGS.image_file: + infer() + else: + global val_loader + dataset = COCOValDataset( + dataset_dir=FLAGS.dataset_dir, + image_dir=FLAGS.val_image_dir, + anno_path=FLAGS.val_anno_path) + global anno_file + anno_file = dataset.ann_file + val_loader = paddle.io.DataLoader( + dataset, batch_size=FLAGS.batch_size, drop_last=True) + eval() + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + + paddle.set_device('cpu') + + main() diff --git 
a/example/auto_compression/pytorch_yolo_series/onnx_trt_infer.py b/example/auto_compression/pytorch_yolo_series/onnx_trt_infer.py deleted file mode 100644 index 3540c33d..00000000 --- a/example/auto_compression/pytorch_yolo_series/onnx_trt_infer.py +++ /dev/null @@ -1,378 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import cv2 -import tensorrt as trt -import pycuda.driver as cuda -import pycuda.autoinit -import os -import time -import random -import argparse - -EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) -EXPLICIT_PRECISION = 1 << ( - int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION) - -# load coco labels -CLASS_LABEL = [ - "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", - "truck", "boat", "traffic light", "fire hydrant", "stop sign", - "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", - "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", - "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", - "baseball bat", "baseball glove", "skateboard", "surfboard", - "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", - "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", - "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", - "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", - "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", - "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", - "hair drier", "toothbrush" -] - - -def preprocess(image, input_size, mean=None, std=None, swap=(2, 0, 1)): - if len(image.shape) == 3: - padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0 - else: - padded_img = np.ones(input_size) * 114.0 - img = np.array(image) - r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) - resized_img = cv2.resize( - img, - (int(img.shape[1] * r), int(img.shape[0] * r)), - interpolation=cv2.INTER_LINEAR, ).astype(np.float32) - padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img - - padded_img = padded_img[:, :, ::-1] - padded_img /= 255.0 - if mean is not None: - padded_img -= mean - if std is not None: - padded_img /= std - padded_img = padded_img.transpose(swap) - padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) - return padded_img, r - - -def postprocess(predictions, ratio): - boxes = predictions[:, :4] - scores = predictions[:, 4:5] * predictions[:, 5:] - boxes_xyxy = np.ones_like(boxes) - boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2. - boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2. - boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2. - boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2. 
- boxes_xyxy /= ratio - dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) - return dets - - -def nms(boxes, scores, nms_thr): - """Single class NMS implemented in Numpy.""" - x1 = boxes[:, 0] - y1 = boxes[:, 1] - x2 = boxes[:, 2] - y2 = boxes[:, 3] - - areas = (x2 - x1 + 1) * (y2 - y1 + 1) - order = scores.argsort()[::-1] - - keep = [] - while order.size > 0: - i = order[0] - keep.append(i) - xx1 = np.maximum(x1[i], x1[order[1:]]) - yy1 = np.maximum(y1[i], y1[order[1:]]) - xx2 = np.minimum(x2[i], x2[order[1:]]) - yy2 = np.minimum(y2[i], y2[order[1:]]) - - w = np.maximum(0.0, xx2 - xx1 + 1) - h = np.maximum(0.0, yy2 - yy1 + 1) - inter = w * h - ovr = inter / (areas[i] + areas[order[1:]] - inter) - - inds = np.where(ovr <= nms_thr)[0] - order = order[inds + 1] - - return keep - - -def multiclass_nms(boxes, scores, nms_thr, score_thr): - """Multiclass NMS implemented in Numpy""" - final_dets = [] - num_classes = scores.shape[1] - for cls_ind in range(num_classes): - cls_scores = scores[:, cls_ind] - valid_score_mask = cls_scores > score_thr - if valid_score_mask.sum() == 0: - continue - else: - valid_scores = cls_scores[valid_score_mask] - valid_boxes = boxes[valid_score_mask] - keep = nms(valid_boxes, valid_scores, nms_thr) - if len(keep) > 0: - cls_inds = np.ones((len(keep), 1)) * cls_ind - dets = np.concatenate( - [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1) - final_dets.append(dets) - if len(final_dets) == 0: - return None - return np.concatenate(final_dets, 0) - - -def get_color_map_list(num_classes): - color_map = num_classes * [0, 0, 0] - for i in range(0, num_classes): - j = 0 - lab = i - while lab: - color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) - color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) - color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) - j += 1 - lab >>= 3 - color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] - return color_map - - -def draw_box(img, boxes, scores, cls_ids, conf=0.5, class_names=None): - color_list = get_color_map_list(len(class_names)) - for i in range(len(boxes)): - box = boxes[i] - cls_id = int(cls_ids[i]) - color = tuple(color_list[cls_id]) - score = scores[i] - if score < conf: - continue - x0 = int(box[0]) - y0 = int(box[1]) - x1 = int(box[2]) - y1 = int(box[3]) - - text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100) - font = cv2.FONT_HERSHEY_SIMPLEX - - txt_size = cv2.getTextSize(text, font, 0.4, 1)[0] - cv2.rectangle(img, (x0, y0), (x1, y1), color, 2) - cv2.rectangle(img, (x0, y0 + 1), - (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])), - color, -1) - cv2.putText( - img, - text, (x0, y0 + txt_size[1]), - font, - 0.8, (0, 255, 0), - thickness=2) - - return img - - -def get_engine(precision, model_file_path): - # TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) - TRT_LOGGER = trt.Logger() - builder = trt.Builder(TRT_LOGGER) - config = builder.create_builder_config() - if precision == 'int8': - network = builder.create_network(EXPLICIT_BATCH | EXPLICIT_PRECISION) - else: - network = builder.create_network(EXPLICIT_BATCH) - parser = trt.OnnxParser(network, TRT_LOGGER) - - runtime = trt.Runtime(TRT_LOGGER) - if model_file_path.endswith('.trt'): - # If a serialized engine exists, use it instead of building an engine. 
- print("Reading engine from file {}".format(model_file_path)) - with open(model_file_path, - "rb") as f, trt.Runtime(TRT_LOGGER) as runtime: - engine = runtime.deserialize_cuda_engine(f.read()) - for i in range(network.num_layers): - layer = network.get_layer(i) - print(i, layer.name) - return engine - else: - config.max_workspace_size = 1 << 30 - - if precision == "fp16": - if not builder.platform_has_fast_fp16: - print("FP16 is not supported natively on this platform/device") - else: - config.set_flag(trt.BuilderFlag.FP16) - elif precision == "int8": - if not builder.platform_has_fast_int8: - print("INT8 is not supported natively on this platform/device") - else: - if builder.platform_has_fast_fp16: - # Also enable fp16, as some layers may be even more efficient in fp16 than int8 - config.set_flag(trt.BuilderFlag.FP16) - config.set_flag(trt.BuilderFlag.INT8) - - builder.max_batch_size = 1 - print('Loading ONNX file from path {}...'.format(model_file_path)) - with open(model_file_path, 'rb') as model: - print('Beginning ONNX file parsing') - if not parser.parse(model.read()): - print('ERROR: Failed to parse the ONNX file.') - for error in range(parser.num_errors): - print(parser.get_error(error)) - return None - - print('Completed parsing of ONNX file') - print('Building an engine from file {}; this may take a while...'. - format(model_file_path)) - plan = builder.build_serialized_network(network, config) - engine = runtime.deserialize_cuda_engine(plan) - print("Completed creating Engine") - with open(model_file_path, "wb") as f: - f.write(engine.serialize()) - for i in range(network.num_layers): - layer = network.get_layer(i) - print(i, layer.name) - return engine - - -# Simple helper data class that's a little nicer to use than a 2-tuple. -class HostDeviceMem(object): - def __init__(self, host_mem, device_mem): - self.host = host_mem - self.device = device_mem - - def __str__(self): - return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) - - def __repr__(self): - return self.__str__() - - -def allocate_buffers(engine): - inputs = [] - outputs = [] - bindings = [] - stream = cuda.Stream() - for binding in engine: - size = trt.volume(engine.get_binding_shape( - binding)) * engine.max_batch_size - dtype = trt.nptype(engine.get_binding_dtype(binding)) - # Allocate host and device buffers - host_mem = cuda.pagelocked_empty(size, dtype) - device_mem = cuda.mem_alloc(host_mem.nbytes) - # Append the device buffer to device bindings. - bindings.append(int(device_mem)) - # Append to the appropriate list. - if engine.binding_is_input(binding): - inputs.append(HostDeviceMem(host_mem, device_mem)) - else: - outputs.append(HostDeviceMem(host_mem, device_mem)) - return inputs, outputs, bindings, stream - - -def run_inference(context, bindings, inputs, outputs, stream): - # Transfer input data to the GPU. - [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] - # Run inference. - context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) - # Transfer predictions back from the GPU. - [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] - # Synchronize the stream - stream.synchronize() - # Return only the host outputs. 
- return [out.host for out in outputs] - - -def main(args): - onnx_model = args.model_path - img_path = args.image_file - num_class = len(CLASS_LABEL) - repeat = 1000 - engine = get_engine(args.precision, onnx_model) - - model_all_names = [] - for idx in range(engine.num_bindings): - is_input = engine.binding_is_input(idx) - name = engine.get_binding_name(idx) - op_type = engine.get_binding_dtype(idx) - model_all_names.append(name) - shape = engine.get_binding_shape(idx) - print('input id:', idx, ' is input: ', is_input, ' binding name:', - name, ' shape:', shape, 'type: ', op_type) - - context = engine.create_execution_context() - print('Allocate buffers ...') - inputs, outputs, bindings, stream = allocate_buffers(engine) - print("TRT set input ...") - - origin_img = cv2.imread(img_path) - input_shape = [args.img_shape, args.img_shape] - input_image, ratio = preprocess(origin_img, input_shape) - - inputs[0].host = np.expand_dims(input_image, axis=0) - - for _ in range(0, 50): - trt_outputs = run_inference( - context, - bindings=bindings, - inputs=inputs, - outputs=outputs, - stream=stream) - - time1 = time.time() - for _ in range(0, repeat): - trt_outputs = run_inference( - context, - bindings=bindings, - inputs=inputs, - outputs=outputs, - stream=stream) - time2 = time.time() - # total time cost(ms) - total_inference_cost = (time2 - time1) * 1000 - print("model path: ", onnx_model, " precision: ", args.precision) - print("In TensorRT, ", - "average latency is : {} ms".format(total_inference_cost / repeat)) - # Do postprocess - output = trt_outputs[0] - predictions = np.reshape(output, (1, -1, int(5 + num_class)))[0] - dets = postprocess(predictions, ratio) - # Draw rectangles and labels on the original image - if dets is not None: - final_boxes, final_scores, final_cls_inds = dets[:, : - 4], dets[:, 4], dets[:, - 5] - origin_img = draw_box( - origin_img, - final_boxes, - final_scores, - final_cls_inds, - conf=0.5, - class_names=CLASS_LABEL) - cv2.imwrite('output.jpg', origin_img) - print('The prediction results are saved in output.jpg.') - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - '--model_path', - type=str, - default="quant_model.onnx", - help="inference model filepath") - parser.add_argument( - '--image_file', type=str, default="bus.jpg", help="image path") - parser.add_argument( - '--precision', type=str, default='fp32', help="support fp32/fp16/int8.") - parser.add_argument('--img_shape', type=int, default=640, help="input_size") - args = parser.parse_args() - main(args) -- GitLab
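For quick reference, a minimal sketch (not part of the patch) of driving the new `trt_backend.TrtEngine` directly, assuming TensorRT and pycuda are installed, the working directory is `TensorRT/`, and the quantized YOLOv7-Tiny artifacts referenced in the README above have been downloaded; `trt_eval.py` wraps the same calls with full pre/post-processing:

```python
# Illustrative sketch only: push one NCHW float32 tensor through the INT8 engine
# built by trt_backend.TrtEngine. Paths follow the README examples above and are
# assumptions; the calibration cache exported by ACT must exist for INT8.
import numpy as np
import trt_backend

engine = trt_backend.TrtEngine(
    "yolov7_tiny_quant_onnx/yolov7_tiny_quant.onnx",
    max_batch_size=1,
    precision_mode="int8",
    engine_file_path="yolov7_tiny_quant_model.trt",  # cache the built engine for reuse
    calibration_cache_file="yolov7_tiny_quant_onnx/calibration.cache")

# Input layout matches preprocess() in trt_eval.py: 1x3x640x640, float32, 0-1 range.
dummy_image = np.random.rand(1, 3, 640, 640).astype(np.float32)
outputs = engine.infer([dummy_image])  # list of host ndarrays, one per output binding
print([out.shape for out in outputs])
```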