未验证 提交 90d4a065 编写于 作者: W wangxinxin08 提交者: GitHub

modify VOCDataSet and default value of allow_empty (#4096)

上级 e4982704
......@@ -48,7 +48,7 @@ class COCODataSet(DetDataset):
data_fields=['image'],
sample_num=-1,
load_crowd=False,
allow_empty=False,
allow_empty=True,
empty_ratio=1.):
super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
data_fields, sample_num)
......@@ -243,7 +243,7 @@ class COCODataSet(DetDataset):
break
assert ct > 0, 'not found any coco record in %s' % (anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path))
if len(empty_records) > 0:
if self.allow_empty and len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records))
records += empty_records
self.roidbs = records
......@@ -55,7 +55,7 @@ class VOCDataSet(DetDataset):
data_fields=['image'],
sample_num=-1,
label_list=None,
allow_empty=False,
allow_empty=True,
empty_ratio=1.):
super(VOCDataSet, self).__init__(
dataset_dir=dataset_dir,
......@@ -131,11 +131,13 @@ class VOCDataSet(DetDataset):
'Illegal width: {} or height: {} in annotation, '
'and {} will be ignored'.format(im_w, im_h, xml_file))
continue
gt_bbox = []
gt_class = []
gt_score = []
difficult = []
for i, obj in enumerate(objs):
num_bbox, i = len(objs), 0
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
for obj in objs:
cname = obj.find('name').text
# user dataset may not contain difficult field
......@@ -152,19 +154,20 @@ class VOCDataSet(DetDataset):
x2 = min(im_w - 1, x2)
y2 = min(im_h - 1, y2)
if x2 > x1 and y2 > y1:
gt_bbox.append([x1, y1, x2, y2])
gt_class.append([cname2cid[cname]])
gt_score.append([1.])
difficult.append([_difficult])
gt_bbox[i, :] = [x1, y1, x2, y2]
gt_class[i, 0] = cname2cid[cname]
gt_score[i, 0] = 1.
difficult[i, 0] = _difficult
i += 1
else:
logger.warning(
'Found an invalid bbox in annotations: xml_file: {}'
', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
xml_file, x1, y1, x2, y2))
gt_bbox = np.array(gt_bbox).astype('float32')
gt_class = np.array(gt_class).astype('int32')
gt_score = np.array(gt_score).astype('float32')
difficult = np.array(difficult).astype('int32')
gt_bbox = gt_bbox[:i, :]
gt_class = gt_class[:i, :]
gt_score = gt_score[:i, :]
difficult = difficult[:i, :]
voc_rec = {
'im_file': img_file,
......@@ -193,7 +196,7 @@ class VOCDataSet(DetDataset):
break
assert ct > 0, 'not found any voc record in %s' % (self.anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path))
if len(empty_records) > 0:
if self.allow_empty and len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records))
records += empty_records
self.roidbs, self.cname2cid = records, cname2cid
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册