提交 1292af43 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix yolov3 voc eval in train (#3053)

* fix yolov3 voc eval in train

* fix comment
上级 577b4ba9
...@@ -936,6 +936,7 @@ class YoloEvalFeed(DataFeed): ...@@ -936,6 +936,7 @@ class YoloEvalFeed(DataFeed):
with_background=with_background, with_background=with_background,
num_workers=num_workers, num_workers=num_workers,
use_process=use_process) use_process=use_process)
self.num_max_boxes = num_max_boxes
self.mode = 'VAL' self.mode = 'VAL'
self.bufsize = 128 self.bufsize = 128
......
...@@ -245,8 +245,8 @@ class ArrangeEvalYOLO(BaseOperator): ...@@ -245,8 +245,8 @@ class ArrangeEvalYOLO(BaseOperator):
context: a dict which contains additional info. context: a dict which contains additional info.
Returns: Returns:
sample: a tuple containing the following items: sample: a tuple containing the following items:
(image, gt_bbox, gt_class, gt_score, (image, im_shape, im_id, gt_bbox, gt_class,
is_crowd, im_info, gt_masks) difficult)
""" """
im = sample['image'] im = sample['image']
if len(sample['gt_bbox']) != len(sample['gt_class']): if len(sample['gt_bbox']) != len(sample['gt_class']):
...@@ -255,9 +255,14 @@ class ArrangeEvalYOLO(BaseOperator): ...@@ -255,9 +255,14 @@ class ArrangeEvalYOLO(BaseOperator):
h = sample['h'] h = sample['h']
w = sample['w'] w = sample['w']
im_shape = np.array((h, w)) im_shape = np.array((h, w))
gt_bbox = sample['gt_bbox'] gt_bbox = np.zeros((50, 4), dtype=im.dtype)
gt_class = sample['gt_class'] gt_class = np.zeros((50, ), dtype=np.int32)
difficult = sample['difficult'] difficult = np.zeros((50, ), dtype=np.int32)
gt_num = min(50, len(sample['gt_bbox']))
if gt_num > 0:
gt_bbox[:gt_num, :] = sample['gt_bbox'][:gt_num, :]
gt_class[:gt_num] = sample['gt_class'][:gt_num, 0]
difficult[:gt_num] = sample['difficult'][:gt_num, 0]
outs = (im, im_shape, im_id, gt_bbox, gt_class, difficult) outs = (im, im_shape, im_id, gt_bbox, gt_class, difficult)
return outs return outs
......
...@@ -54,9 +54,11 @@ def create_feed(feed, use_pyreader=True): ...@@ -54,9 +54,11 @@ def create_feed(feed, use_pyreader=True):
feed_var_map['gt_label']['shape'] = [feed.num_max_boxes] feed_var_map['gt_label']['shape'] = [feed.num_max_boxes]
feed_var_map['gt_score']['shape'] = [feed.num_max_boxes] feed_var_map['gt_score']['shape'] = [feed.num_max_boxes]
feed_var_map['gt_box']['shape'] = [feed.num_max_boxes, 4] feed_var_map['gt_box']['shape'] = [feed.num_max_boxes, 4]
feed_var_map['is_difficult']['shape'] = [feed.num_max_boxes]
feed_var_map['gt_label']['lod_level'] = 0 feed_var_map['gt_label']['lod_level'] = 0
feed_var_map['gt_score']['lod_level'] = 0 feed_var_map['gt_score']['lod_level'] = 0
feed_var_map['gt_box']['lod_level'] = 0 feed_var_map['gt_box']['lod_level'] = 0
feed_var_map['is_difficult']['lod_level'] = 0
feed_vars = OrderedDict([(key, fluid.layers.data( feed_vars = OrderedDict([(key, fluid.layers.data(
name=feed_var_map[key]['name'], name=feed_var_map[key]['name'],
......
...@@ -105,7 +105,7 @@ class DetectionMAP(object): ...@@ -105,7 +105,7 @@ class DetectionMAP(object):
# record class gt count # record class gt count
for gtl, diff in zip(gt_label, difficult): for gtl, diff in zip(gt_label, difficult):
if self.evaluate_difficult or int(diff) == 0: if self.evaluate_difficult or int(diff) == 0:
self.class_gt_counts[int(gtl[0])] += 1 self.class_gt_counts[int(np.array(gtl))] += 1
# record class score positive # record class score positive
visited = [False] * len(gt_label) visited = [False] * len(gt_label)
...@@ -124,7 +124,7 @@ class DetectionMAP(object): ...@@ -124,7 +124,7 @@ class DetectionMAP(object):
if max_overlap > self.overlap_thresh: if max_overlap > self.overlap_thresh:
if self.evaluate_difficult or \ if self.evaluate_difficult or \
int(difficult[max_idx]) == 0: int(np.array(difficult[max_idx])) == 0:
if not visited[max_idx]: if not visited[max_idx]:
self.class_score_poss[ self.class_score_poss[
int(label)].append([score, 1.0]) int(label)].append([score, 1.0])
......
...@@ -71,27 +71,39 @@ def bbox_eval(results, ...@@ -71,27 +71,39 @@ def bbox_eval(results,
continue continue
gt_boxes = t['gt_box'][0] gt_boxes = t['gt_box'][0]
gt_box_lengths = t['gt_box'][1][0]
gt_labels = t['gt_label'][0] gt_labels = t['gt_label'][0]
assert len(gt_boxes) == len(gt_labels)
difficults = t['is_difficult'][0] if not evaluate_difficult \ difficults = t['is_difficult'][0] if not evaluate_difficult \
else None else None
if not evaluate_difficult:
assert len(gt_labels) == len(difficults) if len(t['gt_box'][1]) == 0:
# gt_box, gt_label, difficult read as zero padded Tensor
bbox_idx = 0 bbox_idx = 0
gt_box_idx = 0 for i in range(len(gt_boxes)):
for i in range(len(bbox_lengths)): gt_box = gt_boxes[i]
bbox_num = bbox_lengths[i] gt_label = gt_labels[i]
gt_box_num = gt_box_lengths[i] difficult = difficults[i]
bbox = bboxes[bbox_idx: bbox_idx + bbox_num] bbox_num = bbox_lengths[i]
gt_box = gt_boxes[gt_box_idx: gt_box_idx + gt_box_num] bbox = bboxes[bbox_idx: bbox_idx + bbox_num]
gt_label = gt_labels[gt_box_idx: gt_box_idx + gt_box_num] gt_box, gt_label, difficult = prune_zero_padding(
difficult = None if difficults is None else \ gt_box, gt_label, difficult)
difficults[gt_box_idx: gt_box_idx + gt_box_num] detection_map.update(bbox, gt_box, gt_label, difficult)
detection_map.update(bbox, gt_box, gt_label, difficult) bbox_idx += bbox_num
bbox_idx += bbox_num else:
gt_box_idx += gt_box_num # gt_box, gt_label, difficult read as LoDTensor
gt_box_lengths = t['gt_box'][1][0]
bbox_idx = 0
gt_box_idx = 0
for i in range(len(bbox_lengths)):
bbox_num = bbox_lengths[i]
gt_box_num = gt_box_lengths[i]
bbox = bboxes[bbox_idx: bbox_idx + bbox_num]
gt_box = gt_boxes[gt_box_idx: gt_box_idx + gt_box_num]
gt_label = gt_labels[gt_box_idx: gt_box_idx + gt_box_num]
difficult = None if difficults is None else \
difficults[gt_box_idx: gt_box_idx + gt_box_num]
detection_map.update(bbox, gt_box, gt_label, difficult)
bbox_idx += bbox_num
gt_box_idx += gt_box_num
logger.info("Accumulating evaluatation results...") logger.info("Accumulating evaluatation results...")
detection_map.accumulate() detection_map.accumulate()
...@@ -99,6 +111,17 @@ def bbox_eval(results, ...@@ -99,6 +111,17 @@ def bbox_eval(results,
map_type, 100. * detection_map.get_map())) map_type, 100. * detection_map.get_map()))
def prune_zero_padding(gt_box, gt_label, difficult=None):
valid_cnt = 0
for i in range(len(gt_box)):
if gt_box[i, 0] == 0 and gt_box[i, 1] == 0 and \
gt_box[i, 2] == 0 and gt_box[i, 3] == 0:
break
valid_cnt += 1
return (gt_box[:valid_cnt], gt_label[:valid_cnt],
difficult[:valid_cnt] if difficult is not None else None)
def get_category_info(anno_file=None, def get_category_info(anno_file=None,
with_background=True, with_background=True,
use_default_label=False): use_default_label=False):
......
...@@ -123,8 +123,11 @@ def main(): ...@@ -123,8 +123,11 @@ def main():
eval_pyreader.decorate_sample_list_generator(eval_reader, place) eval_pyreader.decorate_sample_list_generator(eval_reader, place)
# parse eval fetches # parse eval fetches
extra_keys = ['im_info', 'im_id', extra_keys = []
'im_shape'] if cfg.metric == 'COCO' else [] if cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg.metric == 'VOC':
extra_keys = ['gt_box', 'gt_label', 'is_difficult']
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys) extra_keys)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册