未验证 提交 06c8cf7e 编写于 作者: W wangguanzhong 提交者: GitHub

fix voc save_result in infer (#6547)

上级 27930651
......@@ -208,6 +208,10 @@ class ImageFolder(DetDataset):
self.image_dir = images
self.roidbs = self._load_images()
def get_label_list(self):
# Only VOC dataset needs label list in ImageFold
return self.anno_path
@register
class CommonDataset(object):
......
......@@ -287,12 +287,18 @@ class Trainer(object):
save_prediction_only=save_prediction_only)
]
elif self.cfg.metric == 'VOC':
output_eval = self.cfg['output_eval'] \
if 'output_eval' in self.cfg else None
save_prediction_only = self.cfg.get('save_prediction_only', False)
self._metrics = [
VOCMetric(
label_list=self.dataset.get_label_list(),
class_num=self.cfg.num_classes,
map_type=self.cfg.map_type,
classwise=classwise)
classwise=classwise,
output_eval=output_eval,
save_prediction_only=save_prediction_only)
]
elif self.cfg.metric == 'WiderFace':
multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
......
......@@ -225,7 +225,9 @@ class VOCMetric(Metric):
map_type='11point',
is_bbox_normalized=False,
evaluate_difficult=False,
classwise=False):
classwise=False,
output_eval=None,
save_prediction_only=False):
assert os.path.isfile(label_list), \
"label_list {} not a file".format(label_list)
self.clsid2catid, self.catid2name = get_categories('VOC', label_list)
......@@ -233,6 +235,8 @@ class VOCMetric(Metric):
self.overlap_thresh = overlap_thresh
self.map_type = map_type
self.evaluate_difficult = evaluate_difficult
self.output_eval = output_eval
self.save_prediction_only = save_prediction_only
self.detection_map = DetectionMAP(
class_num=class_num,
overlap_thresh=overlap_thresh,
......@@ -245,6 +249,7 @@ class VOCMetric(Metric):
self.reset()
def reset(self):
self.results = {'bbox': [], 'score': [], 'label': []}
self.detection_map.reset()
def update(self, inputs, outputs):
......@@ -256,8 +261,15 @@ class VOCMetric(Metric):
bbox_lengths = outputs['bbox_num'].numpy() if isinstance(
outputs['bbox_num'], paddle.Tensor) else outputs['bbox_num']
self.results['bbox'].append(bboxes.tolist())
self.results['score'].append(scores.tolist())
self.results['label'].append(labels.tolist())
if bboxes.shape == (1, 1) or bboxes is None:
return
if self.save_prediction_only:
return
gt_boxes = inputs['gt_bbox']
gt_labels = inputs['gt_class']
difficults = inputs['difficult'] if not self.evaluate_difficult \
......@@ -294,6 +306,15 @@ class VOCMetric(Metric):
bbox_idx += bbox_num
def accumulate(self):
output = "bbox.json"
if self.output_eval:
output = os.path.join(self.output_eval, output)
with open(output, 'w') as f:
json.dump(self.results, f)
logger.info('The bbox result is saved to bbox.json.')
if self.save_prediction_only:
return
logger.info("Accumulating evaluatation results...")
self.detection_map.accumulate()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册