未验证 提交 10e7fe23 编写于 作者: S shangliang Xu 提交者: GitHub

[deploy] alter save coco format json in deploy/python/infer.py (#6705)

上级 1d867e82
...@@ -75,23 +75,24 @@ python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inferenc ...@@ -75,23 +75,24 @@ python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inferenc
参数说明如下: 参数说明如下:
| 参数 | 是否必须|含义 | | 参数 | 是否必须| 含义 |
|-------|-------|----------| |-------|-------|---------------------------------------------------------------------------------------------|
| --model_dir | Yes| 上述导出的模型路径 | | --model_dir | Yes| 上述导出的模型路径 |
| --image_file | Option | 需要预测的图片 | | --image_file | Option | 需要预测的图片 |
| --image_dir | Option | 要预测的图片文件夹路径 | | --image_dir | Option | 要预测的图片文件夹路径 |
| --video_file | Option | 需要预测的视频 | | --video_file | Option | 需要预测的视频 |
| --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测,可设置为:0 - (摄像头数目-1) ),预测过程中在可视化界面按`q`退出输出预测结果到:output/output.mp4| | --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测,可设置为:0 - (摄像头数目-1) ),预测过程中在可视化界面按`q`退出输出预测结果到:output/output.mp4 |
| --device | Option | 运行时的设备,可选择`CPU/GPU/XPU`,默认为`CPU`| | --device | Option | 运行时的设备,可选择`CPU/GPU/XPU`,默认为`CPU` |
| --run_mode | Option |使用GPU时,默认为paddle, 可选(paddle/trt_fp32/trt_fp16/trt_int8)| | --run_mode | Option | 使用GPU时,默认为paddle, 可选(paddle/trt_fp32/trt_fp16/trt_int8) |
| --batch_size | Option |预测时的batch size,在指定`image_dir`时有效,默认为1 | | --batch_size | Option | 预测时的batch size,在指定`image_dir`时有效,默认为1 |
| --threshold | Option|预测得分的阈值,默认为0.5| | --threshold | Option| 预测得分的阈值,默认为0.5 |
| --output_dir | Option|可视化结果保存的根目录,默认为output/| | --output_dir | Option| 可视化结果保存的根目录,默认为output/ |
| --run_benchmark | Option| 是否运行benchmark,同时需指定`--image_file``--image_dir`,默认为False | | --run_benchmark | Option| 是否运行benchmark,同时需指定`--image_file``--image_dir`,默认为False |
| --enable_mkldnn | Option | CPU预测中是否开启MKLDNN加速,默认为False | | --enable_mkldnn | Option | CPU预测中是否开启MKLDNN加速,默认为False |
| --cpu_threads | Option| 设置cpu线程数,默认为1 | | --cpu_threads | Option| 设置cpu线程数,默认为1 |
| --trt_calib_mode | Option| TensorRT是否使用校准功能,默认为False。使用TensorRT的int8功能时,需设置为True,使用PaddleSlim量化后的模型时需要设置为False | | --trt_calib_mode | Option| TensorRT是否使用校准功能,默认为False。使用TensorRT的int8功能时,需设置为True,使用PaddleSlim量化后的模型时需要设置为False |
| --save_results | Option| 是否在文件夹下将图片的预测结果以JSON的形式保存 | | --save_images | Option| 是否保存可视化结果 |
| --save_results | Option| 是否在文件夹下将图片的预测结果以JSON的形式保存 |
说明: 说明:
...@@ -100,3 +101,4 @@ python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inferenc ...@@ -100,3 +101,4 @@ python deploy/python/mot_keypoint_unite_infer.py --mot_model_dir=output_inferenc
- run_mode:paddle代表使用AnalysisPredictor,精度float32来推理,其他参数指用AnalysisPredictor,TensorRT不同精度来推理。 - run_mode:paddle代表使用AnalysisPredictor,精度float32来推理,其他参数指用AnalysisPredictor,TensorRT不同精度来推理。
- 如果安装的PaddlePaddle不支持基于TensorRT进行预测,需要自行编译,详细可参考[预测库编译教程](https://paddleinference.paddlepaddle.org.cn/user_guides/source_compile.html) - 如果安装的PaddlePaddle不支持基于TensorRT进行预测,需要自行编译,详细可参考[预测库编译教程](https://paddleinference.paddlepaddle.org.cn/user_guides/source_compile.html)
- --run_benchmark如果设置为True,则需要安装依赖`pip install pynvml psutil GPUtil` - --run_benchmark如果设置为True,则需要安装依赖`pip install pynvml psutil GPUtil`
- 如果需要使用导出模型在coco数据集上进行评估,请在推理时添加`--save_results``--use_coco_category`参数用以保存coco评估所需要的json文件
...@@ -36,7 +36,7 @@ from picodet_postprocess import PicoDetPostProcess ...@@ -36,7 +36,7 @@ from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from visualize import visualize_box_mask from visualize import visualize_box_mask
from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid
# Global dictionary # Global dictionary
SUPPORT_MODELS = { SUPPORT_MODELS = {
...@@ -226,7 +226,7 @@ class Detector(object): ...@@ -226,7 +226,7 @@ class Detector(object):
match_threshold=0.6, match_threshold=0.6,
match_metric='iou', match_metric='iou',
visual=True, visual=True,
save_file=None): save_results=False):
# slice infer only support bs=1 # slice infer only support bs=1
results = [] results = []
try: try:
...@@ -295,14 +295,13 @@ class Detector(object): ...@@ -295,14 +295,13 @@ class Detector(object):
threshold=self.threshold) threshold=self.threshold)
results.append(merged_results) results.append(merged_results)
if visual: print('Test iter {}'.format(i))
print('Test iter {}'.format(i))
if save_file is not None:
Path(self.output_dir).mkdir(exist_ok=True)
self.format_coco_results(image_list, results, save_file=save_file)
results = self.merge_batch_result(results) results = self.merge_batch_result(results)
if save_results:
Path(self.output_dir).mkdir(exist_ok=True)
self.save_coco_results(
img_list, results, use_coco_category=FLAGS.use_coco_category)
return results return results
def predict_image(self, def predict_image(self,
...@@ -310,7 +309,7 @@ class Detector(object): ...@@ -310,7 +309,7 @@ class Detector(object):
run_benchmark=False, run_benchmark=False,
repeats=1, repeats=1,
visual=True, visual=True,
save_file=None): save_results=False):
batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size) batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
results = [] results = []
for i in range(batch_loop_cnt): for i in range(batch_loop_cnt):
...@@ -367,14 +366,13 @@ class Detector(object): ...@@ -367,14 +366,13 @@ class Detector(object):
threshold=self.threshold) threshold=self.threshold)
results.append(result) results.append(result)
if visual: print('Test iter {}'.format(i))
print('Test iter {}'.format(i))
if save_file is not None:
Path(self.output_dir).mkdir(exist_ok=True)
self.format_coco_results(image_list, results, save_file=save_file)
results = self.merge_batch_result(results) results = self.merge_batch_result(results)
if save_results:
Path(self.output_dir).mkdir(exist_ok=True)
self.save_coco_results(
image_list, results, use_coco_category=FLAGS.use_coco_category)
return results return results
def predict_video(self, video_file, camera_id): def predict_video(self, video_file, camera_id):
...@@ -394,7 +392,7 @@ class Detector(object): ...@@ -394,7 +392,7 @@ class Detector(object):
if not os.path.exists(self.output_dir): if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir) os.makedirs(self.output_dir)
out_path = os.path.join(self.output_dir, video_out_name) out_path = os.path.join(self.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(* 'mp4v') fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height)) writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1 index = 1
while (1): while (1):
...@@ -418,67 +416,62 @@ class Detector(object): ...@@ -418,67 +416,62 @@ class Detector(object):
break break
writer.release() writer.release()
@staticmethod def save_coco_results(self, image_list, results, use_coco_category=False):
def format_coco_results(image_list, results, save_file=None): bbox_results = []
coco_results = [] mask_results = []
image_id = 0 idx = 0
print("Start saving coco json files...")
for result in results: for i, box_num in enumerate(results['boxes_num']):
start_idx = 0 file_name = os.path.split(image_list[i])[-1]
for box_num in result['boxes_num']: if use_coco_category:
idx_slice = slice(start_idx, start_idx + box_num) img_id = int(os.path.splitext(file_name)[0])
start_idx += box_num else:
img_id = i
image_file = image_list[image_id]
image_id += 1 if 'boxes' in results:
boxes = results['boxes'][idx:idx + box_num].tolist()
if 'boxes' in result: bbox_results.extend([{
boxes = result['boxes'][idx_slice, :] 'image_id': img_id,
per_result = [ 'category_id': coco_clsid2catid[int(box[0])] \
{ if use_coco_category else int(box[0]),
'image_file': image_file, 'file_name': file_name,
'bbox': 'bbox': [box[2], box[3], box[4] - box[2],
[box[2], box[3], box[4] - box[2], box[5] - box[3]], # xyxy -> xywh
box[5] - box[3]], # xyxy -> xywh 'score': box[1]} for box in boxes])
'score': box[1],
'category_id': int(box[0]), if 'masks' in results:
} for k, box in enumerate(boxes.tolist()) import pycocotools.mask as mask_util
]
boxes = results['boxes'][idx:idx + box_num].tolist()
elif 'segm' in result: masks = results['masks'][i][:box_num].astype(np.uint8)
import pycocotools.mask as mask_util seg_res = []
for box, mask in zip(boxes, masks):
scores = result['score'][idx_slice].tolist() rle = mask_util.encode(
category_ids = result['label'][idx_slice].tolist() np.array(
segms = result['segm'][idx_slice, :] mask[:, :, None], dtype=np.uint8, order="F"))[0]
rles = [ if 'counts' in rle:
mask_util.encode( rle['counts'] = rle['counts'].decode("utf8")
np.array( seg_res.append({
mask[:, :, np.newaxis], 'image_id': img_id,
dtype=np.uint8, 'category_id': coco_clsid2catid[int(box[0])] \
order='F'))[0] for mask in segms if use_coco_category else int(box[0]),
] 'file_name': file_name,
for rle in rles:
rle['counts'] = rle['counts'].decode('utf-8')
per_result = [{
'image_file': image_file,
'segmentation': rle, 'segmentation': rle,
'score': scores[k], 'score': box[1]})
'category_id': category_ids[k], mask_results.extend(seg_res)
} for k, rle in enumerate(rles)]
else: idx += box_num
raise RuntimeError('')
# per_result = [item for item in per_result if item['score'] > threshold] if bbox_results:
coco_results.extend(per_result) bbox_file = os.path.join(self.output_dir, "bbox.json")
with open(bbox_file, 'w') as f:
if save_file: json.dump(bbox_results, f)
with open(os.path.join(save_file), 'w') as f: print(f"The bbox result is saved to {bbox_file}")
json.dump(coco_results, f) if mask_results:
mask_file = os.path.join(self.output_dir, "mask.json")
return coco_results with open(mask_file, 'w') as f:
json.dump(mask_results, f)
print(f"The mask result is saved to {mask_file}")
class DetectorSOLOv2(Detector): class DetectorSOLOv2(Detector):
...@@ -956,8 +949,6 @@ def main(): ...@@ -956,8 +949,6 @@ def main():
if FLAGS.image_dir is None and FLAGS.image_file is not None: if FLAGS.image_dir is None and FLAGS.image_file is not None:
assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None" assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file) img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
save_file = os.path.join(FLAGS.output_dir,
'results.json') if FLAGS.save_results else None
if FLAGS.slice_infer: if FLAGS.slice_infer:
detector.predict_image_slice( detector.predict_image_slice(
img_list, img_list,
...@@ -966,10 +957,15 @@ def main(): ...@@ -966,10 +957,15 @@ def main():
FLAGS.combine_method, FLAGS.combine_method,
FLAGS.match_threshold, FLAGS.match_threshold,
FLAGS.match_metric, FLAGS.match_metric,
save_file=save_file) visual=FLAGS.save_images,
save_results=FLAGS.save_results)
else: else:
detector.predict_image( detector.predict_image(
img_list, FLAGS.run_benchmark, repeats=100, save_file=save_file) img_list,
FLAGS.run_benchmark,
repeats=100,
visual=FLAGS.save_images,
save_results=FLAGS.save_results)
if not FLAGS.run_benchmark: if not FLAGS.run_benchmark:
detector.det_times.info(average=True) detector.det_times.info(average=True)
else: else:
......
...@@ -109,6 +109,7 @@ def argsparser(): ...@@ -109,6 +109,7 @@ def argsparser():
parser.add_argument( parser.add_argument(
'--save_images', '--save_images',
action='store_true', action='store_true',
default=False,
help='Save visualization image results.') help='Save visualization image results.')
parser.add_argument( parser.add_argument(
'--save_mot_txts', '--save_mot_txts',
...@@ -159,9 +160,14 @@ def argsparser(): ...@@ -159,9 +160,14 @@ def argsparser():
help="Whether do random padding for action recognition.") help="Whether do random padding for action recognition.")
parser.add_argument( parser.add_argument(
"--save_results", "--save_results",
type=bool, action='store_true',
default=False, default=False,
help="Whether save detection result to file using coco format") help="Whether save detection result to file using coco format")
parser.add_argument(
'--use_coco_category',
action='store_true',
default=False,
help='Whether to use the coco format dictionary `clsid2catid`')
parser.add_argument( parser.add_argument(
"--slice_infer", "--slice_infer",
action='store_true', action='store_true',
...@@ -386,3 +392,87 @@ def nms(dets, match_threshold=0.6, match_metric='iou'): ...@@ -386,3 +392,87 @@ def nms(dets, match_threshold=0.6, match_metric='iou'):
keep = np.where(suppressed == 0)[0] keep = np.where(suppressed == 0)[0]
dets = dets[keep, :] dets = dets[keep, :]
return dets return dets
coco_clsid2catid = {
0: 1,
1: 2,
2: 3,
3: 4,
4: 5,
5: 6,
6: 7,
7: 8,
8: 9,
9: 10,
10: 11,
11: 13,
12: 14,
13: 15,
14: 16,
15: 17,
16: 18,
17: 19,
18: 20,
19: 21,
20: 22,
21: 23,
22: 24,
23: 25,
24: 27,
25: 28,
26: 31,
27: 32,
28: 33,
29: 34,
30: 35,
31: 36,
32: 37,
33: 38,
34: 39,
35: 40,
36: 41,
37: 42,
38: 43,
39: 44,
40: 46,
41: 47,
42: 48,
43: 49,
44: 50,
45: 51,
46: 52,
47: 53,
48: 54,
49: 55,
50: 56,
51: 57,
52: 58,
53: 59,
54: 60,
55: 61,
56: 62,
57: 63,
58: 64,
59: 65,
60: 67,
61: 70,
62: 72,
63: 73,
64: 74,
65: 75,
66: 76,
67: 77,
68: 78,
69: 79,
70: 80,
71: 81,
72: 82,
73: 84,
74: 85,
75: 86,
76: 87,
77: 88,
78: 89,
79: 90
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册