Unverified · commit 6cee8749 · authored by Guanghua Yu · committed by GitHub

update YOLO series TRT eval demo (#1374)

Parent: 53cc3430
@@ -18,23 +18,23 @@
## 2.Benchmark
| Model | Strategy | Input Size | mAP<sup>val<br>0.5:0.95</sup> | Model Size | Latency<sup><small>FP32</small></sup> | Latency<sup><small>FP16</small></sup> | Latency<sup><small>INT8</small></sup> | CPU Memory | GPU Memory | Config | Inference Model |
| :-------- | :-------- | :--------: | :--------: | :--------: | :--------: | :--------: | :--------: | :--------: | :--------: | :--------: | :--------: |
| YOLOv5s | Baseline | 640*640 | 37.4 | 28.1MB | 5.95ms | 2.44ms | - | 1718MB | 705MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) |
| YOLOv5s | Post-training quantization | 640*640 | 36.0 | 7.4MB | - | - | 1.87ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - |
| YOLOv5s | ACT quantization-aware training | 640*640 | **36.9** | 7.4MB | - | - | **1.87ms** | 736MB | 315MB | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant_onnx.tar) |
| | | | | | | | | | | | |
| YOLOv6s | Baseline | 640*640 | 42.4 | 65.9MB | 9.06ms | 2.90ms | - | 1208MB | 555MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) |
| YOLOv6s | KL post-training quantization | 640*640 | 30.3 | 16.8MB | - | - | 1.83ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - |
| YOLOv6s | Quantization-distillation training | 640*640 | **41.3** | 16.8MB | - | - | **1.83ms** | 736MB | 315MB | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant_onnx.tar) |
| | | | | | | | | | | | |
| YOLOv7 | Baseline | 640*640 | 51.1 | 141MB | 26.84ms | 7.44ms | - | 1722MB | 917MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) |
| YOLOv7 | Post-training quantization | 640*640 | 50.2 | 36MB | - | - | 4.55ms | 827MB | 363MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - |
| YOLOv7 | ACT quantization-aware training | 640*640 | **50.9** | 36MB | - | - | **4.55ms** | 827MB | 363MB | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant_onnx.tar) |
| | | | | | | | | | | | |
| YOLOv7-Tiny | Baseline | 640*640 | 37.3 | 24MB | 5.06ms | 2.32ms | - | 738MB | 349MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx) |
| YOLOv7-Tiny | Post-training quantization | 640*640 | 35.8 | 6.1MB | - | - | 1.68ms | 729MB | 315MB | - | - |
| YOLOv7-Tiny | ACT quantization-aware training | 640*640 | **37.0** | 6.1MB | - | - | **1.68ms** | 729MB | 315MB | [config](./configs/yolov7_tiny_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.tar) &#124; [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant_onnx.tar) |
Notes:
- All mAP figures are evaluated on the COCO val2017 dataset.
@@ -136,13 +136,27 @@ python eval.py --config_path=./configs/yolov7_tiny_qat_dis.yaml
## 4.Inference Deployment
After automatic compression finishes, the following files are generated:
```shell
├── model.pdiparams         # Paddle inference model weights
├── model.pdmodel           # Paddle inference model file
├── calibration_table.txt   # calibration table for the quantized Paddle model
├── ONNX
│   ├── quant_model.onnx      # quantized model exported to ONNX
│   ├── calibration.cache     # calibration table that TensorRT can load directly
```
#### Export to ONNX and deploy with TensorRT
Load `quant_model.onnx` together with `calibration.cache` and verify the quantized model directly with the TensorRT test script; see [./TensorRT](./TensorRT) for the full code.
- Run the test:
```shell
cd TensorRT
python trt_eval.py --onnx_model_file=output/ONNX/quant_model.onnx \
        --calibration_file=output/ONNX/calibration.cache \
        --image_file=../images/000000570688.jpg \
        --precision_mode=int8
```
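
The same engine can also be driven from Python through the `TrtEngine` wrapper in `TensorRT/trt_backend.py` (added in this commit). The snippet below is a minimal sketch, assuming the files generated above sit under `output/ONNX/`; the dummy input only stands in for the real `preprocess()` output from `trt_eval.py`.

```python
import numpy as np
import trt_backend  # TensorRT/trt_backend.py from this example

# Build (or load a cached) INT8 engine from the quantized ONNX model.
engine = trt_backend.TrtEngine(
    "output/ONNX/quant_model.onnx",
    max_batch_size=1,
    precision_mode="int8",
    engine_file_path="quant_model.trt",  # serialized here and reused on later runs
    calibration_cache_file="output/ONNX/calibration.cache")

# Placeholder input; in practice use preprocess() from trt_eval.py to get a
# letterboxed, normalized 1x3x640x640 float32 tensor.
blob = np.zeros((1, 3, 640, 640), dtype=np.float32)
outs = engine.infer([blob])  # one numpy array per model output
print([o.shape for o in outs])
```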
#### Paddle-TensorRT deployment
...
# TensorRT Python Inference
### Evaluate COCO mAP
- FP16
```shell
python trt_eval.py --onnx_model_file=yolov7_tiny_quant_onnx/yolov7-tiny.onnx \
--precision_mode=fp16 \
--dataset_dir=dataset/coco/ \
--val_image_dir=val2017 \
--val_anno_path=annotations/instances_val2017.json
```
- INT8
```shell
python trt_eval.py --onnx_model_file=yolov7_tiny_quant_onnx/yolov7_tiny_quant.onnx \
--calibration_file=yolov7_tiny_quant_onnx/calibration.cache \
--precision_mode=int8 \
--dataset_dir=dataset/coco/ \
--val_image_dir=val2017 \
--val_anno_path=annotations/instances_val2017.json
```
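
The `--dataset_dir`, `--val_image_dir` and `--val_anno_path` flags above match the standard COCO layout; a typical directory tree (an assumption based on the default flag values, adjust to your own paths) looks like:

```shell
dataset/coco/
├── annotations/
│   └── instances_val2017.json   # --val_anno_path
└── val2017/                     # --val_image_dir
    ├── 000000000139.jpg
    └── ...
```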
### Evaluate a single image
- FP16
```shell
python trt_eval.py --onnx_model_file=yolov7-tiny.onnx --image_file=../images/000000570688.jpg --precision_mode=fp16
```
- INT8
```shell
python trt_eval.py --onnx_model_file=yolov7_tiny_quant_onnx/yolov7_tiny_quant.onnx \
--calibration_file=yolov7_tiny_quant_onnx/calibration.cache \
--image_file=../images/000000570688.jpg \
--precision_mode=int8
```
### FAQ
- When measuring CPU and GPU memory usage, note that the first run also converts the ONNX model into a TensorRT engine, so its numbers are not representative; run trt_eval.py a second time to get the true memory footprint.
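
The caching works because `TrtEngine` serializes the engine to `engine_file_path` after the first build and simply deserializes that file on later runs (`trt_eval.py` derives the INT8 engine name as `<model>_quant_model.trt`). A minimal sketch of this behaviour, with assumed file names:

```python
import os
import trt_backend

engine_file = "yolov7_tiny_quant_quant_model.trt"  # name trt_eval.py would use for INT8
# If engine_file already exists, TrtEngine deserializes it directly and skips the
# slow ONNX parsing / engine building step; otherwise it builds and writes the file.
engine = trt_backend.TrtEngine(
    "yolov7_tiny_quant_onnx/yolov7_tiny_quant.onnx",
    precision_mode="int8",
    engine_file_path=engine_file,
    calibration_cache_file="yolov7_tiny_quant_onnx/calibration.cache")
print("engine cached:", os.path.exists(engine_file))
```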
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import sys
import os
import copy
import numpy as np

EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
EXPLICIT_PRECISION = 1 << (
    int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)


class LoadCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, cache_file="calibration.cache"):
        super().__init__()
        self.cache_file = cache_file

    def get_batch_size(self):
        return 1

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                print("Using calibration cache to save time: {:}".format(
                    self.cache_file))
                return f.read()


def remove_initializer_from_input(ori_model):
    model = copy.deepcopy(ori_model)
    if model.ir_version < 4:
        print(
            'Model with ir_version below 4 requires to include initilizer in graph input'
        )
        return
    inputs = model.graph.input
    name_to_input = {}
    for input in inputs:
        name_to_input[input.name] = input
    for initializer in model.graph.initializer:
        if initializer.name in name_to_input:
            inputs.remove(name_to_input[initializer.name])
    return model


# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem
        if host_mem:
            self.nbytes = host_mem.nbytes
        else:
            self.nbytes = 0

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()


class TrtEngine:
    def __init__(self,
                 onnx_model_file,
                 shape_info=None,
                 max_batch_size=None,
                 precision_mode="fp32",
                 engine_file_path=None,
                 calibration_cache_file="calibration.cache",
                 verbose=False):
        self.max_batch_size = 1 if max_batch_size is None else max_batch_size
        precision_mode = precision_mode.lower()
        assert precision_mode in [
            "fp32", "fp16", "int8"
        ], "precision_mode must be fp32, fp16 or int8, but your precision_mode is: {}".format(
            precision_mode)
        use_int8 = precision_mode == "int8"
        use_fp16 = precision_mode == "fp16"
        TRT_LOGGER = trt.Logger()
        if verbose:
            TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
        if engine_file_path is not None and os.path.exists(engine_file_path):
            # If a serialized engine exists, use it instead of building an engine.
            print("[TRT Backend] Reading engine from file {}".format(
                engine_file_path))
            with open(engine_file_path,
                      "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
                self.engine = runtime.deserialize_cuda_engine(f.read())
        else:
            builder = trt.Builder(TRT_LOGGER)
            config = builder.create_builder_config()
            network = None
            if use_int8 and not builder.platform_has_fast_int8:
                print("[TRT Backend] INT8 not supported on this platform.")
            if use_fp16 and not builder.platform_has_fast_fp16:
                print("[TRT Backend] FP16 not supported on this platform.")
            if use_int8 and builder.platform_has_fast_int8:
                print("[TRT Backend] Use INT8.")
                network = builder.create_network(EXPLICIT_BATCH |
                                                 EXPLICIT_PRECISION)
                config.int8_calibrator = LoadCalibrator(calibration_cache_file)
                config.set_flag(trt.BuilderFlag.INT8)
            elif use_fp16 and builder.platform_has_fast_fp16:
                print("[TRT Backend] Use FP16.")
                network = builder.create_network(EXPLICIT_BATCH)
                config.set_flag(trt.BuilderFlag.FP16)
            else:
                print("[TRT Backend] Use FP32.")
                network = builder.create_network(EXPLICIT_BATCH)
            parser = trt.OnnxParser(network, TRT_LOGGER)
            runtime = trt.Runtime(TRT_LOGGER)
            config.max_workspace_size = 1 << 28

            import onnx
            print("[TRT Backend] Loading ONNX model ...")
            onnx_model = onnx_model_file
            if not isinstance(onnx_model_file, onnx.ModelProto):
                onnx_model = onnx.load(onnx_model_file)
            onnx_model = remove_initializer_from_input(onnx_model)
            if not parser.parse(onnx_model.SerializeToString()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                raise Exception("ERROR: Failed to parse the ONNX file.")

            if shape_info is None:
                builder.max_batch_size = 1
                for i in range(len(onnx_model.graph.input)):
                    input_shape = [
                        x.dim_value
                        for x in onnx_model.graph.input[i]
                        .type.tensor_type.shape.dim
                    ]
                    for s in input_shape:
                        assert s > 0, "In static shape mode, the input of onnx model should be fixed, but now it's {}".format(
                            onnx_model.graph.input[i])
            else:
                max_batch_size = 1
                if shape_info is not None:
                    assert len(
                        shape_info
                    ) == network.num_inputs, "Length of shape_info: {} is not same with length of model input: {}".format(
                        len(shape_info), network.num_inputs)
                    profile = builder.create_optimization_profile()
                    for k, v in shape_info.items():
                        if v[2][0] > max_batch_size:
                            max_batch_size = v[2][0]
                        print("[TRT Backend] optimize shape: ", k, v[0], v[1],
                              v[2])
                        profile.set_shape(k, v[0], v[1], v[2])
                    config.add_optimization_profile(profile)
                if max_batch_size > self.max_batch_size:
                    self.max_batch_size = max_batch_size
                builder.max_batch_size = self.max_batch_size
            print("[TRT Backend] Completed parsing of ONNX file.")
            print(
                "[TRT Backend] Building an engine from onnx model may take a while..."
            )
            plan = builder.build_serialized_network(network, config)
            print("[TRT Backend] Start Creating Engine.")
            self.engine = runtime.deserialize_cuda_engine(plan)
            print("[TRT Backend] Completed Creating Engine.")
            if engine_file_path is not None:
                with open(engine_file_path, "wb") as f:
                    f.write(self.engine.serialize())
        self.context = self.engine.create_execution_context()
        if shape_info is not None:
            self.context.active_optimization_profile = 0
        self.stream = cuda.Stream()
        self.bindings = []
        self.inputs = []
        self.outputs = []
        for binding in self.engine:
            self.bindings.append(0)
            if self.engine.binding_is_input(binding):
                self.inputs.append(HostDeviceMem(None, None))
            else:
                self.outputs.append(HostDeviceMem(None, None))
        print("[TRT Backend] Completed TrtEngine init ...")

    def infer(self, input_data):
        assert len(self.inputs) == len(
            input_data
        ), "Length of input_data: {} is not same with length of input: {}".format(
            len(input_data), len(self.inputs))
        self.allocate_buffers(input_data)
        return self.do_inference_v2(
            self.context,
            bindings=self.bindings,
            inputs=self.inputs,
            outputs=self.outputs,
            stream=self.stream)

    def do_inference_v2(self, context, bindings, inputs, outputs, stream):
        # Transfer input data to the GPU.
        [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
        # Run inference.
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        # Transfer predictions back from the GPU.
        [
            cuda.memcpy_dtoh_async(out.host, out.device, stream)
            for out in outputs
        ]
        # Synchronize the stream
        stream.synchronize()
        # Return only the host outputs.
        return [out.host for out in outputs]

    def allocate_buffers(self, input_data):
        input_idx = 0
        output_idx = 0
        for binding in self.engine:
            idx = self.engine.get_binding_index(binding)
            if self.engine.binding_is_input(binding):
                if not input_data[input_idx].flags['C_CONTIGUOUS']:
                    input_data[input_idx] = np.ascontiguousarray(input_data[
                        input_idx])
                self.context.set_binding_shape(idx,
                                               (input_data[input_idx].shape))
                self.inputs[input_idx].host = input_data[input_idx]
                nbytes = input_data[input_idx].nbytes
                if self.inputs[input_idx].nbytes < nbytes:
                    self.inputs[input_idx].nbytes = nbytes
                    self.inputs[input_idx].device = cuda.mem_alloc(nbytes)
                    self.bindings[idx] = int(self.inputs[input_idx].device)
                input_idx += 1
            else:
                dtype = trt.nptype(self.engine.get_binding_dtype(binding))
                shape = self.context.get_binding_shape(idx)
                self.outputs[output_idx].host = np.ascontiguousarray(
                    np.empty(
                        shape, dtype=dtype))
                nbytes = self.outputs[output_idx].host.nbytes
                if self.outputs[output_idx].nbytes < nbytes:
                    self.outputs[output_idx].nbytes = nbytes
                    self.outputs[output_idx].device = cuda.mem_alloc(
                        self.outputs[output_idx].host.nbytes)
                    self.bindings[idx] = int(self.outputs[output_idx].device)
                output_idx += 1
@@ -12,19 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.filterwarnings("ignore")
import os
import sys
import numpy as np
import argparse
from tqdm import tqdm
import pkg_resources as pkg
import time
import cv2
import paddle
import onnx
sys.path.append("../")
from post_process import YOLOPostProcess, coco_metric
from dataset import COCOValDataset
import trt_backend


def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--onnx_model_file',
        type=str,
        default='yolov7_tiny_quant_onnx/quant_model.onnx',
        help="onnx model file path.")
    parser.add_argument(
        '--calibration_file',
        type=str,
        default='yolov7_tiny_quant_onnx/calibration.cache',
        help="quant onnx model calibration cache file.")
    parser.add_argument(
        '--image_file',
        type=str,
        default=None,
        help="image path, if set image_file, it will not eval coco.")
    parser.add_argument(
        '--dataset_dir',
        type=str,
        default='dataset/coco',
        help="COCO dataset dir.")
    parser.add_argument(
        '--val_image_dir',
        type=str,
        default='val2017',
        help="COCO dataset val image dir.")
    parser.add_argument(
        '--val_anno_path',
        type=str,
        default='annotations/instances_val2017.json',
        help="COCO dataset anno path.")
    parser.add_argument(
        '--precision_mode',
        type=str,
        default='fp32',
        help="support fp32/fp16/int8.")
    parser.add_argument(
        '--batch_size', type=int, default=1, help="Batch size of model input.")
    return parser


# load coco labels
CLASS_LABEL = [
@@ -68,72 +117,6 @@ def preprocess(image, input_size, mean=None, std=None, swap=(2, 0, 1)):
    return padded_img, r


def get_color_map_list(num_classes):
    color_map = num_classes * [0, 0, 0]
    for i in range(0, num_classes):
@@ -181,198 +164,158 @@ def draw_box(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
    return img


def get_current_memory_mb():
    """
    Obtain the CPU and GPU memory usage of the current process.
    Note that this function is itself time-consuming.
    """
    try:
        pkg.require('pynvml')
    except:
        from pip._internal import main
        main(['install', 'pynvml'])
    try:
        pkg.require('psutil')
    except:
        from pip._internal import main
        main(['install', 'psutil'])
    try:
        pkg.require('GPUtil')
    except:
        from pip._internal import main
        main(['install', 'GPUtil'])
    import pynvml
    import psutil
    import GPUtil
    gpu_id = int(os.environ.get('CUDA_VISIBLE_DEVICES', 0))

    pid = os.getpid()
    p = psutil.Process(pid)
    info = p.memory_full_info()
    cpu_mem = info.uss / 1024. / 1024.
    gpu_mem = 0
    gpu_percent = 0
    gpus = GPUtil.getGPUs()
    if gpu_id is not None and len(gpus) > 0:
        gpu_percent = gpus[gpu_id].load
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        gpu_mem = meminfo.used / 1024. / 1024.
    return round(cpu_mem, 4), round(gpu_mem, 4)


def load_trt_engine():
    model = onnx.load(FLAGS.onnx_model_file)
    model_name = os.path.split(FLAGS.onnx_model_file)[-1].rstrip('.onnx')
    if FLAGS.precision_mode == "int8":
        engine_file = "{}_quant_model.trt".format(model_name)
        assert os.path.exists(FLAGS.calibration_file)
        trt_engine = trt_backend.TrtEngine(
            model,
            max_batch_size=1,
            precision_mode=FLAGS.precision_mode,
            engine_file_path=engine_file,
            calibration_cache_file=FLAGS.calibration_file)
    else:
        engine_file = "{}_{}_model.trt".format(model_name, FLAGS.precision_mode)
        trt_engine = trt_backend.TrtEngine(
            model,
            max_batch_size=1,
            precision_mode=FLAGS.precision_mode,
            engine_file_path=engine_file)
    return trt_engine


def eval():
    trt_engine = load_trt_engine()
    bboxes_list, bbox_nums_list, image_id_list = [], [], []
    cpu_mems, gpu_mems = 0, 0
    sample_nums = len(val_loader)
    with tqdm(
            total=sample_nums,
            bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80) as t:
        for data in val_loader:
            data_all = {k: np.array(v) for k, v in data.items()}
            outs = trt_engine.infer([data_all['image']])
            outs = np.array(outs).reshape(1, -1, 85)
            postprocess = YOLOPostProcess(
                score_threshold=0.001, nms_threshold=0.65, multi_label=True)
            res = postprocess(np.array(outs), data_all['scale_factor'])
            bboxes_list.append(res['bbox'])
            bbox_nums_list.append(res['bbox_num'])
            image_id_list.append(np.array(data_all['im_id']))
            cpu_mem, gpu_mem = get_current_memory_mb()
            cpu_mems += cpu_mem
            gpu_mems += gpu_mem
            t.update()
    print('Avg cpu_mem:{} MB, avg gpu_mem: {} MB'.format(
        cpu_mems / sample_nums, gpu_mems / sample_nums))

    coco_metric(anno_file, bboxes_list, bbox_nums_list, image_id_list)


def infer():
    origin_img = cv2.imread(FLAGS.image_file)
    input_shape = [640, 640]
    input_image, scale_factor = preprocess(origin_img, input_shape)
    input_image = np.expand_dims(input_image, axis=0)
    scale_factor = np.array([[scale_factor, scale_factor]])
    trt_engine = load_trt_engine()

    repeat = 100
    cpu_mems, gpu_mems = 0, 0
    for _ in range(0, repeat):
        outs = trt_engine.infer([input_image])
        cpu_mem, gpu_mem = get_current_memory_mb()
        cpu_mems += cpu_mem
        gpu_mems += gpu_mem
    print('Avg cpu_mem:{} MB, avg gpu_mem: {} MB'.format(cpu_mems / repeat,
                                                         gpu_mems / repeat))

    # Do postprocess
    outs = np.array(outs).reshape(1, -1, 85)
    postprocess = YOLOPostProcess(
        score_threshold=0.1, nms_threshold=0.45, multi_label=False)
    res = postprocess(np.array(outs), scale_factor)
    # Draw rectangles and labels on the original image
    dets = res['bbox']
    if dets is not None:
        final_boxes, final_scores, final_class = dets[:, 2:], dets[:, 1], dets[:, 0]
        origin_img = draw_box(
            origin_img,
            final_boxes,
            final_scores,
            final_class,
            conf=0.5,
            class_names=CLASS_LABEL)
    cv2.imwrite('output.jpg', origin_img)
    print('The prediction results are saved in output.jpg.')


def main():
    if FLAGS.image_file:
        infer()
    else:
        global val_loader
        dataset = COCOValDataset(
            dataset_dir=FLAGS.dataset_dir,
            image_dir=FLAGS.val_image_dir,
            anno_path=FLAGS.val_anno_path)
        global anno_file
        anno_file = dataset.ann_file
        val_loader = paddle.io.DataLoader(
            dataset, batch_size=FLAGS.batch_size, drop_last=True)
        eval()


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    paddle.set_device('cpu')

    main()