未验证 提交 06c8cf7e 编写于 作者: W wangguanzhong 提交者: GitHub

fix voc save_result in infer (#6547)

上级 27930651
...@@ -208,6 +208,10 @@ class ImageFolder(DetDataset): ...@@ -208,6 +208,10 @@ class ImageFolder(DetDataset):
self.image_dir = images self.image_dir = images
self.roidbs = self._load_images() self.roidbs = self._load_images()
def get_label_list(self):
# Only VOC dataset needs label list in ImageFold
return self.anno_path
@register @register
class CommonDataset(object): class CommonDataset(object):
......
...@@ -287,12 +287,18 @@ class Trainer(object): ...@@ -287,12 +287,18 @@ class Trainer(object):
save_prediction_only=save_prediction_only) save_prediction_only=save_prediction_only)
] ]
elif self.cfg.metric == 'VOC': elif self.cfg.metric == 'VOC':
output_eval = self.cfg['output_eval'] \
if 'output_eval' in self.cfg else None
save_prediction_only = self.cfg.get('save_prediction_only', False)
self._metrics = [ self._metrics = [
VOCMetric( VOCMetric(
label_list=self.dataset.get_label_list(), label_list=self.dataset.get_label_list(),
class_num=self.cfg.num_classes, class_num=self.cfg.num_classes,
map_type=self.cfg.map_type, map_type=self.cfg.map_type,
classwise=classwise) classwise=classwise,
output_eval=output_eval,
save_prediction_only=save_prediction_only)
] ]
elif self.cfg.metric == 'WiderFace': elif self.cfg.metric == 'WiderFace':
multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
......
...@@ -225,7 +225,9 @@ class VOCMetric(Metric): ...@@ -225,7 +225,9 @@ class VOCMetric(Metric):
map_type='11point', map_type='11point',
is_bbox_normalized=False, is_bbox_normalized=False,
evaluate_difficult=False, evaluate_difficult=False,
classwise=False): classwise=False,
output_eval=None,
save_prediction_only=False):
assert os.path.isfile(label_list), \ assert os.path.isfile(label_list), \
"label_list {} not a file".format(label_list) "label_list {} not a file".format(label_list)
self.clsid2catid, self.catid2name = get_categories('VOC', label_list) self.clsid2catid, self.catid2name = get_categories('VOC', label_list)
...@@ -233,6 +235,8 @@ class VOCMetric(Metric): ...@@ -233,6 +235,8 @@ class VOCMetric(Metric):
self.overlap_thresh = overlap_thresh self.overlap_thresh = overlap_thresh
self.map_type = map_type self.map_type = map_type
self.evaluate_difficult = evaluate_difficult self.evaluate_difficult = evaluate_difficult
self.output_eval = output_eval
self.save_prediction_only = save_prediction_only
self.detection_map = DetectionMAP( self.detection_map = DetectionMAP(
class_num=class_num, class_num=class_num,
overlap_thresh=overlap_thresh, overlap_thresh=overlap_thresh,
...@@ -245,6 +249,7 @@ class VOCMetric(Metric): ...@@ -245,6 +249,7 @@ class VOCMetric(Metric):
self.reset() self.reset()
def reset(self): def reset(self):
self.results = {'bbox': [], 'score': [], 'label': []}
self.detection_map.reset() self.detection_map.reset()
def update(self, inputs, outputs): def update(self, inputs, outputs):
...@@ -256,8 +261,15 @@ class VOCMetric(Metric): ...@@ -256,8 +261,15 @@ class VOCMetric(Metric):
bbox_lengths = outputs['bbox_num'].numpy() if isinstance( bbox_lengths = outputs['bbox_num'].numpy() if isinstance(
outputs['bbox_num'], paddle.Tensor) else outputs['bbox_num'] outputs['bbox_num'], paddle.Tensor) else outputs['bbox_num']
self.results['bbox'].append(bboxes.tolist())
self.results['score'].append(scores.tolist())
self.results['label'].append(labels.tolist())
if bboxes.shape == (1, 1) or bboxes is None: if bboxes.shape == (1, 1) or bboxes is None:
return return
if self.save_prediction_only:
return
gt_boxes = inputs['gt_bbox'] gt_boxes = inputs['gt_bbox']
gt_labels = inputs['gt_class'] gt_labels = inputs['gt_class']
difficults = inputs['difficult'] if not self.evaluate_difficult \ difficults = inputs['difficult'] if not self.evaluate_difficult \
...@@ -294,6 +306,15 @@ class VOCMetric(Metric): ...@@ -294,6 +306,15 @@ class VOCMetric(Metric):
bbox_idx += bbox_num bbox_idx += bbox_num
def accumulate(self): def accumulate(self):
output = "bbox.json"
if self.output_eval:
output = os.path.join(self.output_eval, output)
with open(output, 'w') as f:
json.dump(self.results, f)
logger.info('The bbox result is saved to bbox.json.')
if self.save_prediction_only:
return
logger.info("Accumulating evaluatation results...") logger.info("Accumulating evaluatation results...")
self.detection_map.accumulate() self.detection_map.accumulate()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册