未验证 提交 484fcbbe 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix yolov3 voc eval in train (#3053)

* fix yolov3 voc eval in train

* fix comment
上级 f8b2d8a4
......@@ -936,6 +936,7 @@ class YoloEvalFeed(DataFeed):
with_background=with_background,
num_workers=num_workers,
use_process=use_process)
self.num_max_boxes = num_max_boxes
self.mode = 'VAL'
self.bufsize = 128
......
......@@ -245,8 +245,8 @@ class ArrangeEvalYOLO(BaseOperator):
context: a dict which contains additional info.
Returns:
sample: a tuple containing the following items:
(image, gt_bbox, gt_class, gt_score,
is_crowd, im_info, gt_masks)
(image, im_shape, im_id, gt_bbox, gt_class,
difficult)
"""
im = sample['image']
if len(sample['gt_bbox']) != len(sample['gt_class']):
......@@ -255,9 +255,14 @@ class ArrangeEvalYOLO(BaseOperator):
h = sample['h']
w = sample['w']
im_shape = np.array((h, w))
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
difficult = sample['difficult']
gt_bbox = np.zeros((50, 4), dtype=im.dtype)
gt_class = np.zeros((50, ), dtype=np.int32)
difficult = np.zeros((50, ), dtype=np.int32)
gt_num = min(50, len(sample['gt_bbox']))
if gt_num > 0:
gt_bbox[:gt_num, :] = sample['gt_bbox'][:gt_num, :]
gt_class[:gt_num] = sample['gt_class'][:gt_num, 0]
difficult[:gt_num] = sample['difficult'][:gt_num, 0]
outs = (im, im_shape, im_id, gt_bbox, gt_class, difficult)
return outs
......
......@@ -54,9 +54,11 @@ def create_feed(feed, use_pyreader=True):
feed_var_map['gt_label']['shape'] = [feed.num_max_boxes]
feed_var_map['gt_score']['shape'] = [feed.num_max_boxes]
feed_var_map['gt_box']['shape'] = [feed.num_max_boxes, 4]
feed_var_map['is_difficult']['shape'] = [feed.num_max_boxes]
feed_var_map['gt_label']['lod_level'] = 0
feed_var_map['gt_score']['lod_level'] = 0
feed_var_map['gt_box']['lod_level'] = 0
feed_var_map['is_difficult']['lod_level'] = 0
feed_vars = OrderedDict([(key, fluid.layers.data(
name=feed_var_map[key]['name'],
......
......@@ -105,7 +105,7 @@ class DetectionMAP(object):
# record class gt count
for gtl, diff in zip(gt_label, difficult):
if self.evaluate_difficult or int(diff) == 0:
self.class_gt_counts[int(gtl[0])] += 1
self.class_gt_counts[int(np.array(gtl))] += 1
# record class score positive
visited = [False] * len(gt_label)
......@@ -124,7 +124,7 @@ class DetectionMAP(object):
if max_overlap > self.overlap_thresh:
if self.evaluate_difficult or \
int(difficult[max_idx]) == 0:
int(np.array(difficult[max_idx])) == 0:
if not visited[max_idx]:
self.class_score_poss[
int(label)].append([score, 1.0])
......
......@@ -71,14 +71,26 @@ def bbox_eval(results,
continue
gt_boxes = t['gt_box'][0]
gt_box_lengths = t['gt_box'][1][0]
gt_labels = t['gt_label'][0]
assert len(gt_boxes) == len(gt_labels)
difficults = t['is_difficult'][0] if not evaluate_difficult \
else None
if not evaluate_difficult:
assert len(gt_labels) == len(difficults)
if len(t['gt_box'][1]) == 0:
# gt_box, gt_label, difficult read as zero padded Tensor
bbox_idx = 0
for i in range(len(gt_boxes)):
gt_box = gt_boxes[i]
gt_label = gt_labels[i]
difficult = difficults[i]
bbox_num = bbox_lengths[i]
bbox = bboxes[bbox_idx: bbox_idx + bbox_num]
gt_box, gt_label, difficult = prune_zero_padding(
gt_box, gt_label, difficult)
detection_map.update(bbox, gt_box, gt_label, difficult)
bbox_idx += bbox_num
else:
# gt_box, gt_label, difficult read as LoDTensor
gt_box_lengths = t['gt_box'][1][0]
bbox_idx = 0
gt_box_idx = 0
for i in range(len(bbox_lengths)):
......@@ -99,6 +111,17 @@ def bbox_eval(results,
map_type, 100. * detection_map.get_map()))
def prune_zero_padding(gt_box, gt_label, difficult=None):
    """Drop zero-padding rows from the end of ground-truth annotations.

    Boxes are scanned from the first row onward. The first box whose
    leading four values are all zero is taken as the start of the
    padding; that box and everything after it are cut from all inputs.

    Args:
        gt_box: row-indexable sequence of boxes (e.g. a 2-D array or a
            list of 4-element lists). Only the first four values of each
            row are inspected.
        gt_label: per-box labels; truncated to the same number of rows.
        difficult: optional per-box flags; truncated the same way when
            given.

    Returns:
        A tuple ``(gt_box, gt_label, difficult)`` sliced to the number of
        boxes found before the first all-zero box. ``difficult`` is
        ``None`` when it was not supplied.
    """
    valid_cnt = 0
    for box in gt_box:
        # A box counts as padding only when all four values are zero; a
        # partially-zero box (e.g. one anchored at the origin) is kept.
        if all(coord == 0 for coord in box[:4]):
            break
        valid_cnt += 1
    return (gt_box[:valid_cnt], gt_label[:valid_cnt],
            difficult[:valid_cnt] if difficult is not None else None)
def get_category_info(anno_file=None,
with_background=True,
use_default_label=False):
......
......@@ -123,8 +123,11 @@ def main():
eval_pyreader.decorate_sample_list_generator(eval_reader, place)
# parse eval fetches
extra_keys = ['im_info', 'im_id',
'im_shape'] if cfg.metric == 'COCO' else []
extra_keys = []
if cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg.metric == 'VOC':
extra_keys = ['gt_box', 'gt_label', 'is_difficult']
eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog,
extra_keys)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册