diff --git a/deploy/third_engine/demo_openvino/python/README.md b/deploy/third_engine/demo_openvino/python/README.md index 12792abc206648206c98842650a0c235a4819cec..34ac8cae8b41a122390d07d7430fea7c81417b4a 100644 --- a/deploy/third_engine/demo_openvino/python/README.md +++ b/deploy/third_engine/demo_openvino/python/README.md @@ -1,6 +1,6 @@ # PicoDet OpenVINO Benchmark Demo -本文件夹提供利用[Intel's OpenVINO Toolkit](https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit.html)进行PicoDet测速的Benchmark Demo +本文件夹提供利用[Intel's OpenVINO Toolkit](https://software.intel.com/content/www/us/en/develop/tools/openvino-toolkit.html)进行PicoDet测速的Benchmark Demo与带后处理的模型Inference Demo。 ## 安装 OpenVINO Toolkit @@ -13,9 +13,9 @@ pip install openvino==2022.1.0 详细安装步骤,可参考[OpenVINO官网](https://docs.openvinotoolkit.org/latest/get_started_guides.html) -## 测试 +## Benchmark测试 -- 准备测试模型:根据[PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)中【导出及转换模型】步骤,采用不包含后处理的方式导出模型(`-o export.benchmark=True` ),并生成待测试模型简化后的onnx模型(可在下文链接中可直接下载)。同时在本目录下新建```out_onnxsim```文件夹,将导出的onnx模型放在该目录下。 +- 准备测试模型:根据[PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)中【导出及转换模型】步骤,采用不包含后处理的方式导出模型(`-o export.benchmark=True` ),并生成待测试模型简化后的onnx模型(可在下文链接中直接下载)。同时在本目录下新建```out_onnxsim```文件夹,将导出的onnx模型放在该目录下。 - 准备测试所用图片:本demo默认利用PaddleDetection/demo/[000000014439.jpg](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/demo/000000014439.jpg) @@ -30,7 +30,7 @@ python openvino_benchmark.py --img_path ..\..\..\..\demo\000000014439.jpg --onnx ``` - 注意:```--in_shape```为对应模型输入size,默认为320 -### Inference images +### Inference images(w/o 后处理) ```shell # Linux @@ -38,20 +38,38 @@ python openvino_benchmark.py --benchmark 0 --img_path ../../../../demo/000000014 # Windows python openvino_benchmark.py --benchmark 0 --img_path ..\..\..\..\demo\000000014439.jpg --onnx_path out_onnxsim\picodet_s_320_coco_lcnet.onnx --in_shape 320 ``` + +## 
Inference images(w/ 后处理, w/o NMS) + +- 准备测试模型:根据[PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet)中【导出及转换模型】步骤,采用**包含后处理**但**不包含NMS**的方式导出模型(`-o export.benchmark=False export.nms=False` ),并生成待测试模型简化后的onnx模型(可在下文链接中直接下载)。同时在本目录下新建```out_onnxsim_infer```文件夹,将导出的onnx模型放在该目录下。 + +- 准备测试所用图片:默认利用../../demo_onnxruntime/imgs/[bus.jpg](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/third_engine/demo_onnxruntime/imgs/bus.jpg) + +```shell +# Linux +python openvino_infer.py --img_path ../../demo_onnxruntime/imgs/bus.jpg --onnx_path out_onnxsim_infer/picodet_s_320_postproccesed_woNMS.onnx --in_shape 320 +# Windows +python openvino_infer.py --img_path ..\..\demo_onnxruntime\imgs\bus.jpg --onnx_path out_onnxsim_infer\picodet_s_320_postproccesed_woNMS.onnx --in_shape 320 +``` +- 结果: +
+ +
+ ## 结果 -测试结果如下: +- 测速结果如下: | 模型 | 输入尺寸 | ONNX | 预测时延[CPU](#latency)| | :-------- | :--------: | :---------------------: | :----------------: | -| PicoDet-XS | 320*320 | [model](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_xs_320_coco_lcnet.onnx) | 3.9ms | -| PicoDet-XS | 416*416 | [model](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_xs_416_coco_lcnet.onnx) | 6.1ms | -| PicoDet-S | 320*320 | [model](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_coco_lcnet.onnx) | 4.8ms | -| PicoDet-S | 416*416 | [model](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_416_coco_lcnet.onnx) | 6.6ms | -| PicoDet-M | 320*320 | [model](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_320_coco_lcnet.onnx) | 8.2ms | -| PicoDet-M | 416*416 | [model](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_416_coco_lcnet.onnx) | 12.7ms | -| PicoDet-L | 320*320 | [model](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_l_320_coco_lcnet.onnx) | 11.5ms | -| PicoDet-L | 416*416 | [model](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_l_416_coco_lcnet.onnx) | 20.7ms | -| PicoDet-L | 640*640 | [model](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_l_640_coco.onnx) | 62.5ms | +| PicoDet-XS | 320*320 | [( w/ 后处理;w/o NMS)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_xs_320_lcnet_postproccesed_woNMS.onnx) | [( w/o 后处理)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_xs_320_coco_lcnet.onnx) | 3.9ms | +| PicoDet-XS | 416*416 | [( w/ 后处理;w/o NMS)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_xs_416_lcnet_postproccesed_woNMS.onnx) | [( w/o 后处理)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_xs_416_coco_lcnet.onnx) | 6.1ms | +| PicoDet-S | 320*320 | [( w/ 后处理;w/o NMS)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_lcnet_postproccesed_woNMS.onnx) | [( w/o 
后处理)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_320_coco_lcnet.onnx) | 4.8ms | +| PicoDet-S | 416*416 | [( w/ 后处理;w/o NMS)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_416_lcnet_postproccesed_woNMS.onnx) | [( w/o 后处理)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_s_416_coco_lcnet.onnx) | 6.6ms | +| PicoDet-M | 320*320 | [( w/ 后处理;w/o NMS)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_320_lcnet_postproccesed_woNMS.onnx) | [( w/o 后处理)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_320_coco_lcnet.onnx) | 8.2ms | +| PicoDet-M | 416*416 | [( w/ 后处理;w/o NMS)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_416_lcnet_postproccesed_woNMS.onnx) | [( w/o 后处理)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_m_416_coco_lcnet.onnx) | 12.7ms | +| PicoDet-L | 320*320 | [( w/ 后处理;w/o NMS)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_l_320_lcnet_postproccesed_woNMS.onnx) | [( w/o 后处理)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_l_320_coco_lcnet.onnx) | 11.5ms | +| PicoDet-L | 416*416 | [( w/ 后处理;w/o NMS)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_l_416_lcnet_postproccesed_woNMS.onnx) | [( w/o 后处理)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_l_416_coco_lcnet.onnx) | 20.7ms | +| PicoDet-L | 640*640 | [( w/ 后处理;w/o NMS)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_l_640_lcnet_postproccesed_woNMS.onnx) | [( w/o 后处理)](https://paddledet.bj.bcebos.com/deploy/third_engine/picodet_l_640_coco_lcnet.onnx) | 62.5ms | - 测试环境: 英特尔酷睿i7 10750H CPU。 diff --git a/deploy/third_engine/demo_openvino/python/openvino_benchmark.py b/deploy/third_engine/demo_openvino/python/openvino_benchmark.py index c83605d895953d0b8d5101ff1c9e1dadf31de306..f21a8d5d1ed83c159818d2b405d1b5c9e5daa927 100644 --- a/deploy/third_engine/demo_openvino/python/openvino_benchmark.py +++ 
b/deploy/third_engine/demo_openvino/python/openvino_benchmark.py @@ -339,7 +339,7 @@ if __name__ == '__main__': parser.add_argument( '--img_path', type=str, - default='demo/000000014439.jpg', + default='../../../../demo/000000014439.jpg', help="image path") parser.add_argument( '--onnx_path', diff --git a/deploy/third_engine/demo_openvino/python/openvino_infer.py b/deploy/third_engine/demo_openvino/python/openvino_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad51022b1793e7b6430025a7c71cc0de7658c8c --- /dev/null +++ b/deploy/third_engine/demo_openvino/python/openvino_infer.py @@ -0,0 +1,267 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""PicoDet OpenVINO inference demo for models exported with post-processing
but without NMS: run the network, apply class-wise hard NMS on the host and
draw the kept detections on the input image.
"""

import argparse

import cv2
import numpy as np
from scipy.special import softmax  # not referenced in this script; import kept
from openvino.runtime import Core


def image_preprocess(img_path, re_shape):
    """Read an image and convert it into the network input tensor.

    Args:
        img_path (str): path of the image, read with ``cv2.imread``.
        re_shape (int): side length of the square the image is resized to.

    Returns:
        np.ndarray: float32 array of shape (1, 3, re_shape, re_shape) in RGB
        channel order, divided by 255 and normalized with the per-channel
        mean/std constants below.

    Raises:
        FileNotFoundError: if OpenCV could not read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than inside cv2.resize.
        raise FileNotFoundError('could not read image: {}'.format(img_path))
    img = cv2.resize(
        img, (re_shape, re_shape), interpolation=cv2.INTER_LANCZOS4)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC -> CHW; the division promotes the array to floating point.
    img = np.transpose(img, [2, 0, 1]) / 255
    img = np.expand_dims(img, 0)
    img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    img -= img_mean
    img /= img_std
    return img.astype(np.float32)


def get_color_map_list(num_classes):
    """Build one 3-component color per class id.

    The bits of the class id are dealt out three at a time to the three
    color components, starting at each component's most significant bit.

    Args:
        num_classes (int): number of colors to generate.

    Returns:
        list[list[int]]: ``num_classes`` entries of three integers each.
    """
    color_map = num_classes * [0, 0, 0]
    for i in range(0, num_classes):
        j = 0
        lab = i
        while lab:
            color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
            color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
            color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
            j += 1
            lab >>= 3
    return [color_map[i:i + 3] for i in range(0, len(color_map), 3)]


def draw_box(srcimg, results, class_label):
    """Draw detections on ``srcimg`` in place and print them.

    Args:
        srcimg (np.ndarray): image to draw on (modified in place).
        results (np.ndarray): one row per detection laid out as
            (class id, score, xmin, ymin, xmax, ymax).
        class_label (str): path of a text file with one class name per line.

    Returns:
        np.ndarray: the same ``srcimg`` object, with boxes and labels drawn.
    """
    # Close the label file deterministically instead of leaking the handle.
    with open(class_label, 'r') as label_file:
        label_list = [line.strip() for line in label_file]

    # The palette depends only on the number of classes: build it once
    # instead of once per detection.
    color_list = get_color_map_list(len(label_list))

    for row in results:
        classid, conf = int(row[0]), row[1]
        xmin, ymin, xmax, ymax = (int(row[2]), int(row[3]), int(row[4]),
                                  int(row[5]))
        color = tuple(color_list[classid])

        cv2.rectangle(srcimg, (xmin, ymin), (xmax, ymax), color, thickness=2)
        print(label_list[classid] + ': ' + str(round(conf, 3)))
        cv2.putText(
            srcimg,
            label_list[classid] + ':' + str(round(conf, 3)), (xmin, ymin - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8, (0, 255, 0),
            thickness=2)
    return srcimg


def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
    """Greedy (hard) non-maximum suppression.

    Args:
        box_scores (N, 5): boxes in corner-form followed by their score.
        iou_threshold: boxes whose IoU with an already kept box exceeds this
            value are discarded.
        top_k: keep at most top_k results. If top_k <= 0, keep all results.
        candidate_size: only the candidate_size highest scoring boxes are
            considered.

    Returns:
        (K, 5) array: the rows of ``box_scores`` that were kept, ordered by
        decreasing score.
    """
    scores = box_scores[:, -1]
    boxes = box_scores[:, :-1]
    picked = []
    # argsort is ascending, so the best remaining candidate is always last.
    indexes = np.argsort(scores)
    indexes = indexes[-candidate_size:]
    while len(indexes) > 0:
        current = indexes[-1]
        picked.append(current)
        if 0 < top_k == len(picked) or len(indexes) == 1:
            break
        current_box = boxes[current, :]
        indexes = indexes[:-1]
        rest_boxes = boxes[indexes, :]
        iou = iou_of(rest_boxes, np.expand_dims(current_box, axis=0))
        indexes = indexes[iou <= iou_threshold]

    return box_scores[picked, :]


def iou_of(boxes0, boxes1, eps=1e-5):
    """Return intersection-over-union (Jaccard index) of boxes.

    Args:
        boxes0 (N, 4): boxes in corner-form.
        boxes1 (N or 1, 4): boxes in corner-form, broadcast against boxes0.
        eps: a small number to avoid 0 as denominator.

    Returns:
        iou (N): IoU values.
    """
    overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
    overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])

    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
    return overlap_area / (area0 + area1 - overlap_area + eps)


def area_of(left_top, right_bottom):
    """Compute the areas of rectangles given two corners.

    Args:
        left_top (N, 2): left top corner.
        right_bottom (N, 2): right bottom corner.

    Returns:
        area (N): the areas; extents are clipped at 0, so corners that do not
        describe a valid rectangle give an area of 0.
    """
    hw = np.clip(right_bottom - left_top, 0.0, None)
    return hw[..., 0] * hw[..., 1]


class PicoDetNMS(object):
    """Class-wise hard NMS followed by rescaling to the source image.

    Args:
        input_shape: network input size (stored, not used by ``__call__``).
        scale_x (float): factor applied to the x coordinates of kept boxes.
        scale_y (float): factor applied to the y coordinates of kept boxes.
        strides (list[int] | None): stored only; defaults to [8, 16, 32, 64].
        score_threshold (float): scores at or below this value are dropped
            before NMS.
        nms_threshold (float): IoU threshold handed to ``hard_nms``.
        nms_top_k (int): stored only; ``__call__`` does not use it.
        keep_top_k (int): maximum number of boxes kept per class.
    """

    def __init__(self,
                 input_shape,
                 scale_x,
                 scale_y,
                 strides=None,
                 score_threshold=0.4,
                 nms_threshold=0.5,
                 nms_top_k=1000,
                 keep_top_k=100):
        self.input_shape = input_shape
        self.scale_x = scale_x
        self.scale_y = scale_y
        # A None sentinel avoids sharing one mutable default list between
        # all instances.
        self.strides = [8, 16, 32, 64] if strides is None else strides
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k

    def __call__(self, decode_boxes, select_scores):
        """Run NMS for every class and return the surviving detections.

        Args:
            decode_boxes (list[np.ndarray]): box arrays, concatenated along
                axis 0. Assumed to be (N_i, 4) corner-form each; not checked
                here, confirm against the exported model.
            select_scores (list[np.ndarray]): score arrays, concatenated
                along axis 0; axis 1 is iterated as the class axis.

        Returns:
            np.ndarray: (K, 6) rows of (class id, score, x0, y0, x1, y1) with
            coordinates multiplied by scale_x / scale_y. K is 0 when nothing
            passes the score threshold.
        """
        bboxes = np.concatenate(decode_boxes, axis=0)
        confidences = np.concatenate(select_scores, axis=0)
        picked_box_probs = []
        picked_labels = []
        for class_index in range(confidences.shape[1]):
            probs = confidences[:, class_index]
            mask = probs > self.score_threshold
            probs = probs[mask]
            if probs.shape[0] == 0:
                continue
            subset_boxes = bboxes[mask, :]
            box_probs = np.concatenate(
                [subset_boxes, probs.reshape(-1, 1)], axis=1)
            box_probs = hard_nms(
                box_probs,
                iou_threshold=self.nms_threshold,
                top_k=self.keep_top_k)
            picked_box_probs.append(box_probs)
            picked_labels.extend([class_index] * box_probs.shape[0])

        if not picked_box_probs:
            # Same 6-column layout as the non-empty result so callers can
            # rely on a single shape.
            return np.empty((0, 6))

        picked = np.concatenate(picked_box_probs)

        # Map boxes from network-input coordinates to the source image.
        picked[:, 0] *= self.scale_x
        picked[:, 2] *= self.scale_x
        picked[:, 1] *= self.scale_y
        picked[:, 3] *= self.scale_y

        # Column order: class, score, box.
        return np.concatenate(
            [
                np.expand_dims(np.array(picked_labels), axis=-1),
                np.expand_dims(picked[:, 4], axis=-1),
                picked[:, :4],
            ],
            axis=1)


def detect(img_file, compiled_model, class_label, test_image=None,
           in_shape=320):
    """Run the model on one image, draw the detections, save and show them.

    Writes the annotated image to ``res.jpg`` in the current directory and
    opens an OpenCV window that waits for a key press.

    Args:
        img_file (str): path of the image to run detection on.
        compiled_model: OpenVINO compiled model; it is called through
            ``infer_new_request`` with the tensor bound to input index 0.
        class_label (str): path of the class-name file used by ``draw_box``.
        test_image (np.ndarray | None): preprocessed NCHW input tensor. When
            omitted it is built from ``img_file`` with ``image_preprocess``.
        in_shape (int): resize target used only when ``test_image`` is
            omitted.

    Raises:
        FileNotFoundError: if ``img_file`` cannot be read.
    """
    if test_image is None:
        # Previously this function silently depended on a module-level
        # ``test_image`` that only exists when the file runs as a script.
        test_image = image_preprocess(img_file, in_shape)

    output = compiled_model.infer_new_request({0: test_image})
    result_ie = list(output.values())

    # The first half of the outputs is used as boxes, the second half as the
    # matching scores.
    num_outs = len(result_ie) // 2
    decode_boxes = result_ie[:num_outs]
    select_scores = result_ie[num_outs:2 * num_outs]

    image = cv2.imread(img_file, 1)
    if image is None:
        raise FileNotFoundError('could not read image: {}'.format(img_file))
    scale_x = image.shape[1] / test_image.shape[3]
    scale_y = image.shape[0] / test_image.shape[2]

    nms = PicoDetNMS(test_image.shape[2:], scale_x, scale_y)
    np_boxes = nms(decode_boxes, select_scores)

    res_image = draw_box(image, np_boxes, class_label)

    cv2.imwrite('res.jpg', res_image)
    cv2.imshow("res", res_image)
    cv2.waitKey()


def main():
    """Parse the command line, compile the ONNX model for CPU and run it."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--img_path',
        type=str,
        default='../../demo_onnxruntime/imgs/bus.jpg',
        help="image path")
    parser.add_argument(
        '--onnx_path',
        type=str,
        default='out_onnxsim_infer/picodet_s_320_postproccesed_woNMS.onnx',
        help="onnx filepath")
    parser.add_argument('--in_shape', type=int, default=320, help="input_size")
    parser.add_argument(
        '--class_label',
        type=str,
        default='coco_label.txt',
        help="class label file")
    args = parser.parse_args()

    ie = Core()
    net = ie.read_model(args.onnx_path)
    test_image = image_preprocess(args.img_path, args.in_shape)
    compiled_model = ie.compile_model(net, 'CPU')

    detect(
        args.img_path,
        compiled_model,
        args.class_label,
        test_image=test_image)


if __name__ == '__main__':
    main()