未验证 提交 6bc578cc 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] add ppyoloe onnx-trt demo (#6743)

上级 486121ea
# PP-YOLOE 转ONNX-TRT教程
本教程内容为:使用PP-YOLOE模型导出转换为ONNX格式,并定制化修改网络,使用[EfficientNMS_TRT](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin) OP,
可成功运行在[TensorRT](https://github.com/NVIDIA/TensorRT)上,示例仅供参考
## 1. 环境依赖
CUDA 10.2 + [cudnn 8.2.1](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) + [TensorRT 8.2](https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-821/install-guide/index.htm)
```commandline
onnx
onnxruntime
paddle2onnx
```
## 2. Paddle模型导出
```commandline
python tools/export_model.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams trt=True exclude_nms=True
```
## 3. ONNX模型转换 + 定制化修改EfficientNMS_TRT
```commandline
python deploy/third_engine/demo_onnx_trt/onnx_custom.py --onnx_file=output_inference/ppyoloe_crn_l_300e_coco/ppyoloe_crn_l_300e_coco.onnx --model_dir=output_inference/ppyoloe_crn_l_300e_coco/ --opset_version=11
```
## 4. TensorRT Engine
```commandline
trtexec --onnx=output_inference/ppyoloe_crn_l_300e_coco/ppyoloe_crn_l_300e_coco.onnx --saveEngine=ppyoloe_crn_l_300e_coco.engine
```
**注意**:若运行报错,可尝试添加`--tacticSources=-cublasLt,+cublas`参数解决
## 5. 运行TensorRT推理
```commandline
python deploy/third_engine/demo_onnx_trt/trt_infer.py --infer_cfg=output_inference/ppyoloe_crn_l_300e_coco/infer_cfg.yml --trt_engine=ppyoloe_crn_l_300e_coco.engine --image_file=demo/000000014439.jpg
```
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import onnx
import onnx_graphsurgeon
import numpy as np
from collections import OrderedDict
from paddle2onnx.command import program2onnx
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--onnx_file', required=True, type=str, help='onnx model path')
parser.add_argument(
'--model_dir',
type=str,
default=None,
help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."))
parser.add_argument(
"--opset_version",
type=int,
default=11,
help="set onnx opset version to export")
parser.add_argument(
'--topk_all', type=int, default=300, help='topk objects for every images')
parser.add_argument(
'--iou_thres', type=float, default=0.7, help='iou threshold for NMS')
parser.add_argument(
'--conf_thres', type=float, default=0.01, help='conf threshold for NMS')
def main(FLAGS):
assert os.path.exists(FLAGS.onnx_file)
onnx_model = onnx.load(FLAGS.onnx_file)
graph = onnx_graphsurgeon.import_onnx(onnx_model)
graph.toposort()
graph.fold_constants()
graph.cleanup()
num_anchors = graph.outputs[1].shape[2]
num_classes = graph.outputs[1].shape[1]
scores = onnx_graphsurgeon.Variable(
name='scores', shape=[-1, num_anchors, num_classes], dtype=np.float32)
graph.layer(
op='Transpose',
name='lastTranspose',
inputs=[graph.outputs[1]],
outputs=[scores],
attrs=OrderedDict(perm=[0, 2, 1]))
attrs = OrderedDict(
plugin_version="1",
background_class=-1,
max_output_boxes=FLAGS.topk_all,
score_threshold=FLAGS.conf_thres,
iou_threshold=FLAGS.iou_thres,
score_activation=False,
box_coding=0, )
outputs = [
onnx_graphsurgeon.Variable("num_dets", np.int32, [-1, 1]),
onnx_graphsurgeon.Variable("det_boxes", np.float32,
[-1, FLAGS.topk_all, 4]),
onnx_graphsurgeon.Variable("det_scores", np.float32,
[-1, FLAGS.topk_all]),
onnx_graphsurgeon.Variable("det_classes", np.int32,
[-1, FLAGS.topk_all])
]
graph.layer(
op='EfficientNMS_TRT',
name="batched_nms",
inputs=[graph.outputs[0], scores],
outputs=outputs,
attrs=attrs)
graph.outputs = outputs
graph.cleanup().toposort()
onnx.save(onnx_graphsurgeon.export_onnx(graph), FLAGS.onnx_file)
print(f"The modified onnx model is saved in {FLAGS.onnx_file}")
if __name__ == '__main__':
FLAGS = parser.parse_args()
if FLAGS.model_dir is not None:
assert os.path.exists(FLAGS.model_dir)
program2onnx(
model_dir=FLAGS.model_dir,
save_file=FLAGS.onnx_file,
model_filename="model.pdmodel",
params_filename="model.pdiparams",
opset_version=FLAGS.opset_version,
enable_onnx_checker=True)
main(FLAGS)
import numpy as np
import cv2
import copy
def decode_image(img_path):
with open(img_path, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
img_info = {
"im_shape": np.array(
im.shape[:2], dtype=np.float32),
"scale_factor": np.array(
[1., 1.], dtype=np.float32)
}
return im, img_info
class Resize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class NormalizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, ):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride(object):
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
class LetterBoxResize(object):
def __init__(self, target_size):
"""
Resize image to target size, convert normalized xywh to pixel xyxy
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
Args:
target_size (int|list): image target size.
"""
super(LetterBoxResize, self).__init__()
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
# letterbox: resize a rectangular image to a padded rectangular
shape = img.shape[:2] # [height, width]
ratio_h = float(height) / shape[0]
ratio_w = float(width) / shape[1]
ratio = min(ratio_h, ratio_w)
new_shape = (round(shape[1] * ratio),
round(shape[0] * ratio)) # [width, height]
padw = (width - new_shape[0]) / 2
padh = (height - new_shape[1]) / 2
top, bottom = round(padh - 0.1), round(padh + 0.1)
left, right = round(padw - 0.1), round(padw + 0.1)
img = cv2.resize(
img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=color) # padded rectangular
return img, ratio, padw, padh
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
height, width = self.target_size
h, w = im.shape[:2]
im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
new_shape = [round(h * ratio), round(w * ratio)]
im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
return im, im_info
class Pad(object):
def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
"""
Pad image to a specified size.
Args:
size (list[int]): image target size
fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
"""
super(Pad, self).__init__()
if isinstance(size, int):
size = [size, size]
self.size = size
self.fill_value = fill_value
def __call__(self, im, im_info):
im_h, im_w = im.shape[:2]
h, w = self.size
if h == im_h and w == im_w:
im = im.astype(np.float32)
return im, im_info
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
im = canvas
return im, im_info
def rotate_point(pt, angle_rad):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle by radian
Returns:
list[float]: Rotated point.
"""
assert len(pt) == 2
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
new_x = pt[0] * cs - pt[1] * sn
new_y = pt[0] * sn + pt[1] * cs
rotated_pt = [new_x, new_y]
return rotated_pt
def _get_3rd_point(a, b):
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): point(x,y)
b (np.ndarray): point(x,y)
Returns:
np.ndarray: The 3rd point.
"""
assert len(a) == 2
assert len(b) == 2
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
def get_affine_transform(center,
input_size,
rot,
output_size,
shift=(0., 0.),
inv=False):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
scale (np.ndarray[2, ]): Scale of the bounding box
wrt [width, height].
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert len(center) == 2
assert len(output_size) == 2
assert len(shift) == 2
if not isinstance(input_size, (np.ndarray, list)):
input_size = np.array([input_size, input_size], dtype=np.float32)
scale_tmp = input_size
shift = np.array(shift)
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = rotate_point([0., src_w * -0.5], rot_rad)
dst_dir = np.array([0., dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
dst = np.zeros((3, 2), dtype=np.float32)
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
class WarpAffine(object):
"""Warp affine the image
"""
def __init__(self,
keep_res=False,
pad=31,
input_h=512,
input_w=512,
scale=0.4,
shift=0.1):
self.keep_res = keep_res
self.pad = pad
self.input_h = input_h
self.input_w = input_w
self.scale = scale
self.shift = shift
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
h, w = img.shape[:2]
if self.keep_res:
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
else:
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
return inp, im_info
# keypoint preprocess
def get_warp_matrix(theta, size_input, size_dst, size_target):
"""This code is based on
https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
Calculate the transformation matrix under the constraint of unbiased.
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
Data Processing for Human Pose Estimation (CVPR 2020).
Args:
theta (float): Rotation angle in degrees.
size_input (np.ndarray): Size of input image [w, h].
size_dst (np.ndarray): Size of output image [w, h].
size_target (np.ndarray): Size of ROI in input plane [w, h].
Returns:
matrix (np.ndarray): A matrix for transformation.
"""
theta = np.deg2rad(theta)
matrix = np.zeros((2, 3), dtype=np.float32)
scale_x = size_dst[0] / size_target[0]
scale_y = size_dst[1] / size_target[1]
matrix[0, 0] = np.cos(theta) * scale_x
matrix[0, 1] = -np.sin(theta) * scale_x
matrix[0, 2] = scale_x * (
-0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
np.sin(theta) + 0.5 * size_target[0])
matrix[1, 0] = np.sin(theta) * scale_y
matrix[1, 1] = np.cos(theta) * scale_y
matrix[1, 2] = scale_y * (
-0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
np.cos(theta) + 0.5 * size_target[1])
return matrix
class TopDownEvalAffine(object):
"""apply affine transform to image and coords
Args:
trainsize (list): [w, h], the standard size used to train
use_udp (bool): whether to use Unbiased Data Processing.
records(dict): the dict contained the image and coords
Returns:
records (dict): contain the image and coords after tranformed
"""
def __init__(self, trainsize, use_udp=False):
self.trainsize = trainsize
self.use_udp = use_udp
def __call__(self, image, im_info):
rot = 0
imshape = im_info['im_shape'][::-1]
center = im_info['center'] if 'center' in im_info else imshape / 2.
scale = im_info['scale'] if 'scale' in im_info else imshape
if self.use_udp:
trans = get_warp_matrix(
rot, center * 2.0,
[self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
image = cv2.warpAffine(
image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR)
else:
trans = get_affine_transform(center, scale, rot, self.trainsize)
image = cv2.warpAffine(
image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR)
return image, im_info
class Compose:
def __init__(self, transforms):
self.transforms = []
for op_info in transforms:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
self.transforms.append(eval(op_type)(**new_op_info))
def __call__(self, img_path):
img, im_info = decode_image(img_path)
for t in self.transforms:
img, im_info = t(img, im_info)
inputs = copy.deepcopy(im_info)
inputs['image'] = np.ascontiguousarray(img.astype('float32'))
return inputs
coco_clsid2catid = {
0: 1,
1: 2,
2: 3,
3: 4,
4: 5,
5: 6,
6: 7,
7: 8,
8: 9,
9: 10,
10: 11,
11: 13,
12: 14,
13: 15,
14: 16,
15: 17,
16: 18,
17: 19,
18: 20,
19: 21,
20: 22,
21: 23,
22: 24,
23: 25,
24: 27,
25: 28,
26: 31,
27: 32,
28: 33,
29: 34,
30: 35,
31: 36,
32: 37,
33: 38,
34: 39,
35: 40,
36: 41,
37: 42,
38: 43,
39: 44,
40: 46,
41: 47,
42: 48,
43: 49,
44: 50,
45: 51,
46: 52,
47: 53,
48: 54,
49: 55,
50: 56,
51: 57,
52: 58,
53: 59,
54: 60,
55: 61,
56: 62,
57: 63,
58: 64,
59: 65,
60: 67,
61: 70,
62: 72,
63: 73,
64: 74,
65: 75,
66: 76,
67: 77,
68: 78,
69: 79,
70: 80,
71: 81,
72: 82,
73: 84,
74: 85,
75: 86,
76: 87,
77: 88,
78: 89,
79: 90
}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt
from collections import OrderedDict
import os
import yaml
import json
import glob
import argparse
from preprocess import Compose
from preprocess import coco_clsid2catid
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml")
parser.add_argument(
"--trt_engine", required=True, type=str, help="trt engine path")
parser.add_argument("--image_dir", type=str)
parser.add_argument("--image_file", type=str)
parser.add_argument(
"--repeats",
type=int,
default=1,
help="Repeat the running test `repeats` times in benchmark")
parser.add_argument(
"--save_coco",
action='store_true',
default=False,
help="Whether to save coco results")
parser.add_argument(
"--coco_file", type=str, default="results.json", help="coco results path")
TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
}
def get_test_images(infer_dir, infer_img):
"""
Get image path list in TEST mode
"""
assert infer_img is not None or infer_dir is not None, \
"--image_file or --image_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), \
"{} is not a file".format(infer_img)
assert infer_dir is None or os.path.isdir(infer_dir), \
"{} is not a directory".format(infer_dir)
# infer_img has a higher priority
if infer_img and os.path.isfile(infer_img):
return [infer_img]
images = set()
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir)
exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts]
for ext in exts:
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
images = list(images)
assert len(images) > 0, "no image found in {}".format(infer_dir)
print("Found {} inference images in total.".format(len(images)))
return images
class PredictConfig(object):
"""set config of preprocess, postprocess and visualize
Args:
infer_config (str): path of infer_cfg.yml
"""
def __init__(self, infer_config):
# parsing Yaml config for Preprocess
with open(infer_config) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.label_list = yml_conf['label_list']
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
self.mask = yml_conf.get("mask", False)
self.tracker = yml_conf.get("tracker", None)
self.nms = yml_conf.get("NMS", None)
self.fpn_stride = yml_conf.get("fpn_stride", None)
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
print(
'The RCNN export model is used for ONNX and it only supports batch_size = 1'
)
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def load_trt_engine(engine_path):
assert os.path.exists(engine_path)
print("Reading engine from file {}".format(engine_path))
with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
return runtime.deserialize_cuda_engine(f.read())
def predict_image(infer_config, engine, img_list, save_coco=False, repeats=1):
# load preprocess transforms
transforms = Compose(infer_config.preprocess_infos)
stream = cuda.Stream()
coco_results = []
num_data = len(img_list)
avg_time = []
with engine.create_execution_context() as context:
# Allocate host and device buffers
bindings = create_trt_bindings(engine, context)
# warmup
run_trt_context(context, bindings, stream, repeats=10)
# predict image
for i, img_path in enumerate(img_list):
inputs = transforms(img_path)
inputs_name = [k for k, v in bindings.items() if v['is_input']]
inputs = {
k: inputs[k][None, ]
for k in inputs.keys() if k in inputs_name
}
# run infer
for k, v in inputs.items():
bindings[k]['cpu_data'][...] = v
output = run_trt_context(context, bindings, stream, repeats=repeats)
print(f"{i + 1}/{num_data} infer time: {output['infer_time']} ms.")
avg_time.append(output['infer_time'])
# get output
for k, v in output.items():
if k in bindings.keys():
output[k] = np.reshape(v, bindings[k]['shape'])
if save_coco:
coco_results.extend(
format_coco_results(os.path.split(img_path)[-1], output))
avg_time = np.mean(avg_time)
print(
f"Run on {num_data} data, repeats {repeats} times, avg time: {avg_time} ms."
)
if save_coco:
with open(FLAGS.coco_file, 'w') as f:
json.dump(coco_results, f)
print(f"save coco json to {FLAGS.coco_file}")
def create_trt_bindings(engine, context):
bindings = OrderedDict()
for name in engine:
binding_idx = engine.get_binding_index(name)
size = trt.volume(context.get_binding_shape(binding_idx))
dtype = trt.nptype(engine.get_binding_dtype(name))
shape = list(engine.get_binding_shape(binding_idx))
if shape[0] == -1:
shape[0] = 1
bindings[name] = {
"idx": binding_idx,
"size": size,
"dtype": dtype,
"shape": shape,
"cpu_data": None,
"cuda_ptr": None,
"is_input": True if engine.binding_is_input(name) else False
}
if engine.binding_is_input(name):
bindings[name]['cpu_data'] = np.random.randn(
*shape).astype(np.float32)
bindings[name]['cuda_ptr'] = cuda.mem_alloc(bindings[name][
'cpu_data'].nbytes)
else:
bindings[name]['cpu_data'] = cuda.pagelocked_empty(size, dtype)
bindings[name]['cuda_ptr'] = cuda.mem_alloc(bindings[name][
'cpu_data'].nbytes)
return bindings
def run_trt_context(context, bindings, stream, repeats=1):
# Transfer input data to the GPU.
for k, v in bindings.items():
if v['is_input']:
cuda.memcpy_htod_async(v['cuda_ptr'], v['cpu_data'], stream)
in_bindings = [int(v['cuda_ptr']) for k, v in bindings.items()]
output_data = {}
avg_time = []
for _ in range(repeats):
# Run inference
t1 = time.time()
context.execute_async_v2(
bindings=in_bindings, stream_handle=stream.handle)
# Transfer prediction output from the GPU.
for k, v in bindings.items():
if not v['is_input']:
cuda.memcpy_dtoh_async(v['cpu_data'], v['cuda_ptr'], stream)
output_data[k] = v['cpu_data']
# Synchronize the stream
stream.synchronize()
t2 = time.time()
avg_time.append(t2 - t1)
output_data['infer_time'] = np.mean(avg_time) * 1000
return output_data
def format_coco_results(file_name, result):
try:
image_id = int(os.path.splitext(file_name)[0])
except:
image_id = file_name
num_dets = result['num_dets'].tolist()
det_classes = result['det_classes'].tolist()
det_scores = result['det_scores'].tolist()
det_boxes = result['det_boxes'].tolist()
per_result = [
{
'image_id': image_id,
'category_id': coco_clsid2catid[int(det_classes[0][idx])],
'file_name': file_name,
'bbox': [
det_boxes[0][idx][0], det_boxes[0][idx][1],
det_boxes[0][idx][2] - det_boxes[0][idx][0],
det_boxes[0][idx][3] - det_boxes[0][idx][1]
], # xyxy -> xywh
'score': det_scores[0][idx]
} for idx in range(num_dets[0][0])
]
return per_result
if __name__ == '__main__':
FLAGS = parser.parse_args()
# load image list
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
# load trt engine
engine = load_trt_engine(FLAGS.trt_engine)
# load infer config
infer_config = PredictConfig(FLAGS.infer_cfg)
predict_image(infer_config, engine, img_list, FLAGS.save_coco,
FLAGS.repeats)
print('Done!')
......@@ -46,7 +46,9 @@ class ESEAttn(nn.Layer):
@register
class PPYOLOEHead(nn.Layer):
__shared__ = ['num_classes', 'eval_size', 'trt', 'exclude_nms']
__shared__ = [
'num_classes', 'eval_size', 'trt', 'exclude_nms', 'exclude_post_process'
]
__inject__ = ['static_assigner', 'assigner', 'nms']
def __init__(self,
......@@ -69,7 +71,8 @@ class PPYOLOEHead(nn.Layer):
'dfl': 0.5,
},
trt=False,
exclude_nms=False):
exclude_nms=False,
exclude_post_process=False):
super(PPYOLOEHead, self).__init__()
assert len(in_channels) > 0, "len(in_channels) should > 0"
self.in_channels = in_channels
......@@ -90,6 +93,7 @@ class PPYOLOEHead(nn.Layer):
if isinstance(self.nms, MultiClassNMS) and trt:
self.nms.trt = trt
self.exclude_nms = exclude_nms
self.exclude_post_process = exclude_post_process
# stem
self.stem_cls = nn.LayerList()
self.stem_reg = nn.LayerList()
......@@ -369,14 +373,19 @@ class PPYOLOEHead(nn.Layer):
pred_bboxes = batch_distance2bbox(anchor_points,
pred_dist.transpose([0, 2, 1]))
pred_bboxes *= stride_tensor
# scale bbox to origin
scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
scale_factor = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=-1).reshape([-1, 1, 4])
pred_bboxes /= scale_factor
if self.exclude_nms:
# `exclude_nms=True` just use in benchmark
return pred_bboxes.sum(), pred_scores.sum()
if self.exclude_post_process:
return paddle.concat(
[pred_bboxes, pred_scores.transpose([0, 2, 1])], axis=-1), None
else:
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
# scale bbox to origin
scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
scale_factor = paddle.concat(
[scale_x, scale_y, scale_x, scale_y],
axis=-1).reshape([-1, 1, 4])
pred_bboxes /= scale_factor
if self.exclude_nms:
# `exclude_nms=True` just use in benchmark
return pred_bboxes, pred_scores
else:
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册