未验证 提交 90d4a065 编写于 作者: W wangxinxin08 提交者: GitHub

modify VOCDataSet and default value of allow_empty (#4096)

上级 e4982704
...@@ -48,7 +48,7 @@ class COCODataSet(DetDataset): ...@@ -48,7 +48,7 @@ class COCODataSet(DetDataset):
data_fields=['image'], data_fields=['image'],
sample_num=-1, sample_num=-1,
load_crowd=False, load_crowd=False,
allow_empty=False, allow_empty=True,
empty_ratio=1.): empty_ratio=1.):
super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path, super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
data_fields, sample_num) data_fields, sample_num)
...@@ -243,7 +243,7 @@ class COCODataSet(DetDataset): ...@@ -243,7 +243,7 @@ class COCODataSet(DetDataset):
break break
assert ct > 0, 'not found any coco record in %s' % (anno_path) assert ct > 0, 'not found any coco record in %s' % (anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path)) logger.debug('{} samples in file {}'.format(ct, anno_path))
if len(empty_records) > 0: if self.allow_empty and len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records)) empty_records = self._sample_empty(empty_records, len(records))
records += empty_records records += empty_records
self.roidbs = records self.roidbs = records
...@@ -55,7 +55,7 @@ class VOCDataSet(DetDataset): ...@@ -55,7 +55,7 @@ class VOCDataSet(DetDataset):
data_fields=['image'], data_fields=['image'],
sample_num=-1, sample_num=-1,
label_list=None, label_list=None,
allow_empty=False, allow_empty=True,
empty_ratio=1.): empty_ratio=1.):
super(VOCDataSet, self).__init__( super(VOCDataSet, self).__init__(
dataset_dir=dataset_dir, dataset_dir=dataset_dir,
...@@ -131,11 +131,13 @@ class VOCDataSet(DetDataset): ...@@ -131,11 +131,13 @@ class VOCDataSet(DetDataset):
'Illegal width: {} or height: {} in annotation, ' 'Illegal width: {} or height: {} in annotation, '
'and {} will be ignored'.format(im_w, im_h, xml_file)) 'and {} will be ignored'.format(im_w, im_h, xml_file))
continue continue
gt_bbox = []
gt_class = [] num_bbox, i = len(objs), 0
gt_score = [] gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
difficult = [] gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
for i, obj in enumerate(objs): gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
for obj in objs:
cname = obj.find('name').text cname = obj.find('name').text
# user dataset may not contain difficult field # user dataset may not contain difficult field
...@@ -152,19 +154,20 @@ class VOCDataSet(DetDataset): ...@@ -152,19 +154,20 @@ class VOCDataSet(DetDataset):
x2 = min(im_w - 1, x2) x2 = min(im_w - 1, x2)
y2 = min(im_h - 1, y2) y2 = min(im_h - 1, y2)
if x2 > x1 and y2 > y1: if x2 > x1 and y2 > y1:
gt_bbox.append([x1, y1, x2, y2]) gt_bbox[i, :] = [x1, y1, x2, y2]
gt_class.append([cname2cid[cname]]) gt_class[i, 0] = cname2cid[cname]
gt_score.append([1.]) gt_score[i, 0] = 1.
difficult.append([_difficult]) difficult[i, 0] = _difficult
i += 1
else: else:
logger.warning( logger.warning(
'Found an invalid bbox in annotations: xml_file: {}' 'Found an invalid bbox in annotations: xml_file: {}'
', x1: {}, y1: {}, x2: {}, y2: {}.'.format( ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
xml_file, x1, y1, x2, y2)) xml_file, x1, y1, x2, y2))
gt_bbox = np.array(gt_bbox).astype('float32') gt_bbox = gt_bbox[:i, :]
gt_class = np.array(gt_class).astype('int32') gt_class = gt_class[:i, :]
gt_score = np.array(gt_score).astype('float32') gt_score = gt_score[:i, :]
difficult = np.array(difficult).astype('int32') difficult = difficult[:i, :]
voc_rec = { voc_rec = {
'im_file': img_file, 'im_file': img_file,
...@@ -193,7 +196,7 @@ class VOCDataSet(DetDataset): ...@@ -193,7 +196,7 @@ class VOCDataSet(DetDataset):
break break
assert ct > 0, 'not found any voc record in %s' % (self.anno_path) assert ct > 0, 'not found any voc record in %s' % (self.anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path)) logger.debug('{} samples in file {}'.format(ct, anno_path))
if len(empty_records) > 0: if self.allow_empty and len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records)) empty_records = self._sample_empty(empty_records, len(records))
records += empty_records records += empty_records
self.roidbs, self.cname2cid = records, cname2cid self.roidbs, self.cname2cid = records, cname2cid
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册