diff --git a/deploy/python/README.md b/deploy/python/README.md index 4b28f0cf83da7d6195aad574ccdd5a3560d614eb..cb32ad468b58ff1d0cf0f9cb5f625822c8959393 100644 --- a/deploy/python/README.md +++ b/deploy/python/README.md @@ -91,6 +91,8 @@ python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inferenc | --enable_mkldnn | Option | CPU预测中是否开启MKLDNN加速,默认为False | | --cpu_threads | Option| 设置cpu线程数,默认为1 | | --trt_calib_mode | Option| TensorRT是否使用校准功能,默认为False。使用TensorRT的int8功能时,需设置为True,使用PaddleSlim量化后的模型时需要设置为False | +| --save_results | Option| 是否在文件夹下将图片的预测结果以JSON的形式保存 | + 说明: diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 09eec1153796ae43f41f0881db22029ed6056e45..fd43874c3d883a39d68a435ee43667b18b254d1c 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -15,6 +15,8 @@ import os import yaml import glob +import json +from pathlib import Path from functools import reduce import cv2 @@ -233,7 +235,8 @@ class Detector(object): image_list, run_benchmark=False, repeats=1, - visual=True): + visual=True, + save_file=None): batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size) results = [] for i in range(batch_loop_cnt): @@ -293,6 +296,10 @@ class Detector(object): if visual: print('Test iter {}'.format(i)) + if save_file is not None: + Path(self.output_dir).mkdir(exist_ok=True) + self.format_coco_results(image_list, results, save_file=save_file) + results = self.merge_batch_result(results) return results @@ -313,7 +320,7 @@ class Detector(object): if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) out_path = os.path.join(self.output_dir, video_out_name) - fourcc = cv2.VideoWriter_fourcc(* 'mp4v') + fourcc = cv2.VideoWriter_fourcc(*'mp4v') writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) index = 1 while (1): @@ -337,6 +344,68 @@ class Detector(object): break writer.release() + @staticmethod + def format_coco_results(image_list, results, save_file=None): + coco_results = [] + image_id = 0 + + for result in results: + start_idx = 0 + for box_num in result['boxes_num']: + idx_slice = slice(start_idx, start_idx + box_num) + start_idx += box_num + + image_file = image_list[image_id] + image_id += 1 + + if 'boxes' in result: + boxes = result['boxes'][idx_slice, :] + per_result = [ + { + 'image_file': image_file, + 'bbox': + [box[2], box[3], box[4] - box[2], + box[5] - box[3]], # xyxy -> xywh + 'score': box[1], + 'category_id': int(box[0]), + } for k, box in enumerate(boxes.tolist()) + ] + + elif 'segm' in result: + import pycocotools.mask as mask_util + + scores = result['score'][idx_slice].tolist() + category_ids = result['label'][idx_slice].tolist() + segms = result['segm'][idx_slice, :] + rles = [ + mask_util.encode( + np.array( + mask[:, :, np.newaxis], + dtype=np.uint8, + order='F'))[0] for mask in segms + ] + for rle in rles: + rle['counts'] = rle['counts'].decode('utf-8') + + per_result = [{ + 'image_file': image_file, + 'segmentation': rle, + 'score': scores[k], + 'category_id': category_ids[k], + } for k, rle in enumerate(rles)] + + else: + raise RuntimeError('') + + # per_result = [item for item in per_result if item['score'] > threshold] + coco_results.extend(per_result) + + if save_file: + with open(os.path.join(save_file), 'w') as f: + json.dump(coco_results, f) + + return coco_results + class DetectorSOLOv2(Detector): """ @@ -807,7 +876,10 @@ def main(): if FLAGS.image_dir is None and FLAGS.image_file is not None: assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None" img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) - detector.predict_image(img_list, FLAGS.run_benchmark, repeats=100) + save_file = os.path.join(FLAGS.output_dir, + 'results.json') if FLAGS.save_results else None + detector.predict_image( + img_list, FLAGS.run_benchmark, repeats=100, save_file=save_file) if not FLAGS.run_benchmark: detector.det_times.info(average=True) else: diff --git a/deploy/python/utils.py b/deploy/python/utils.py index c542f0176494e03312516574077815fbdd2d6d4c..41dc7ae9e81f49bdd08e0917d50b21ac00f2e527 100644 --- a/deploy/python/utils.py +++ b/deploy/python/utils.py @@ -156,6 +156,12 @@ def argsparser(): type=ast.literal_eval, default=False, help="Whether do random padding for action recognition.") + parser.add_argument( + "--save_results", + type=bool, + default=False, + help="Whether save detection result to file using coco format") + return parser