未验证 提交 1ee341b9 编写于 作者: W Wenyu 提交者: GitHub

save detection results to file using coco format #5782 (#5787)

* save detection results to file using coco format

* update save docs
上级 fa250ff1
...@@ -91,6 +91,8 @@ python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inferenc ...@@ -91,6 +91,8 @@ python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inferenc
| --enable_mkldnn | Option | CPU预测中是否开启MKLDNN加速,默认为False | | --enable_mkldnn | Option | CPU预测中是否开启MKLDNN加速,默认为False |
| --cpu_threads | Option| 设置cpu线程数,默认为1 | | --cpu_threads | Option| 设置cpu线程数,默认为1 |
| --trt_calib_mode | Option| TensorRT是否使用校准功能,默认为False。使用TensorRT的int8功能时,需设置为True,使用PaddleSlim量化后的模型时需要设置为False | | --trt_calib_mode | Option| TensorRT是否使用校准功能,默认为False。使用TensorRT的int8功能时,需设置为True,使用PaddleSlim量化后的模型时需要设置为False |
| --save_results | Option| 是否在文件夹下将图片的预测结果以JSON的形式保存 |
说明: 说明:
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import os import os
import yaml import yaml
import glob import glob
import json
from pathlib import Path
from functools import reduce from functools import reduce
import cv2 import cv2
...@@ -233,7 +235,8 @@ class Detector(object): ...@@ -233,7 +235,8 @@ class Detector(object):
image_list, image_list,
run_benchmark=False, run_benchmark=False,
repeats=1, repeats=1,
visual=True): visual=True,
save_file=None):
batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size) batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
results = [] results = []
for i in range(batch_loop_cnt): for i in range(batch_loop_cnt):
...@@ -293,6 +296,10 @@ class Detector(object): ...@@ -293,6 +296,10 @@ class Detector(object):
if visual: if visual:
print('Test iter {}'.format(i)) print('Test iter {}'.format(i))
if save_file is not None:
Path(self.output_dir).mkdir(exist_ok=True)
self.format_coco_results(image_list, results, save_file=save_file)
results = self.merge_batch_result(results) results = self.merge_batch_result(results)
return results return results
...@@ -313,7 +320,7 @@ class Detector(object): ...@@ -313,7 +320,7 @@ class Detector(object):
if not os.path.exists(self.output_dir): if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir) os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name) out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v') fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1 index = 1
while (1): while (1):
...@@ -337,6 +344,68 @@ class Detector(object): ...@@ -337,6 +344,68 @@ class Detector(object):
break break
writer.release() writer.release()
@staticmethod
def format_coco_results(image_list, results, save_file=None):
coco_results = []
image_id = 0
for result in results:
start_idx = 0
for box_num in result['boxes_num']:
idx_slice = slice(start_idx, start_idx + box_num)
start_idx += box_num
image_file = image_list[image_id]
image_id += 1
if 'boxes' in result:
boxes = result['boxes'][idx_slice, :]
per_result = [
{
'image_file': image_file,
'bbox':
[box[2], box[3], box[4] - box[2],
box[5] - box[3]], # xyxy -> xywh
'score': box[1],
'category_id': int(box[0]),
} for k, box in enumerate(boxes.tolist())
]
elif 'segm' in result:
import pycocotools.mask as mask_util
scores = result['score'][idx_slice].tolist()
category_ids = result['label'][idx_slice].tolist()
segms = result['segm'][idx_slice, :]
rles = [
mask_util.encode(
np.array(
mask[:, :, np.newaxis],
dtype=np.uint8,
order='F'))[0] for mask in segms
]
for rle in rles:
rle['counts'] = rle['counts'].decode('utf-8')
per_result = [{
'image_file': image_file,
'segmentation': rle,
'score': scores[k],
'category_id': category_ids[k],
} for k, rle in enumerate(rles)]
else:
raise RuntimeError('')
# per_result = [item for item in per_result if item['score'] > threshold]
coco_results.extend(per_result)
if save_file:
with open(os.path.join(save_file), 'w') as f:
json.dump(coco_results, f)
return coco_results
class DetectorSOLOv2(Detector): class DetectorSOLOv2(Detector):
""" """
...@@ -807,7 +876,10 @@ def main(): ...@@ -807,7 +876,10 @@ def main():
if FLAGS.image_dir is None and FLAGS.image_file is not None: if FLAGS.image_dir is None and FLAGS.image_file is not None:
assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None" assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
detector.predict_image(img_list, FLAGS.run_benchmark, repeats=100) save_file = os.path.join(FLAGS.output_dir,
'results.json') if FLAGS.save_results else None
detector.predict_image(
img_list, FLAGS.run_benchmark, repeats=100, save_file=save_file)
if not FLAGS.run_benchmark: if not FLAGS.run_benchmark:
detector.det_times.info(average=True) detector.det_times.info(average=True)
else: else:
......
...@@ -156,6 +156,12 @@ def argsparser(): ...@@ -156,6 +156,12 @@ def argsparser():
type=ast.literal_eval, type=ast.literal_eval,
default=False, default=False,
help="Whether do random padding for action recognition.") help="Whether do random padding for action recognition.")
parser.add_argument(
"--save_results",
type=bool,
default=False,
help="Whether save detection result to file using coco format")
return parser return parser
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册