From 90d4a0659bf3e3b806b8f7fa7e081142715b1e2c Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Thu, 2 Sep 2021 14:21:41 +0800 Subject: [PATCH] modify VOCDataSet and default value of allow_empty (#4096) --- ppdet/data/source/coco.py | 4 ++-- ppdet/data/source/voc.py | 33 ++++++++++++++++++--------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/ppdet/data/source/coco.py b/ppdet/data/source/coco.py index a3cc3e0c4..6d8fdcc9e 100644 --- a/ppdet/data/source/coco.py +++ b/ppdet/data/source/coco.py @@ -48,7 +48,7 @@ class COCODataSet(DetDataset): data_fields=['image'], sample_num=-1, load_crowd=False, - allow_empty=False, + allow_empty=True, empty_ratio=1.): super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path, data_fields, sample_num) @@ -243,7 +243,7 @@ class COCODataSet(DetDataset): break assert ct > 0, 'not found any coco record in %s' % (anno_path) logger.debug('{} samples in file {}'.format(ct, anno_path)) - if len(empty_records) > 0: + if self.allow_empty and len(empty_records) > 0: empty_records = self._sample_empty(empty_records, len(records)) records += empty_records self.roidbs = records diff --git a/ppdet/data/source/voc.py b/ppdet/data/source/voc.py index eeef989f7..b836b4d05 100644 --- a/ppdet/data/source/voc.py +++ b/ppdet/data/source/voc.py @@ -55,7 +55,7 @@ class VOCDataSet(DetDataset): data_fields=['image'], sample_num=-1, label_list=None, - allow_empty=False, + allow_empty=True, empty_ratio=1.): super(VOCDataSet, self).__init__( dataset_dir=dataset_dir, @@ -131,11 +131,13 @@ class VOCDataSet(DetDataset): 'Illegal width: {} or height: {} in annotation, ' 'and {} will be ignored'.format(im_w, im_h, xml_file)) continue - gt_bbox = [] - gt_class = [] - gt_score = [] - difficult = [] - for i, obj in enumerate(objs): + + num_bbox, i = len(objs), 0 + gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32) + gt_class = np.zeros((num_bbox, 1), dtype=np.int32) + gt_score = np.zeros((num_bbox, 1), dtype=np.float32) + difficult = np.zeros((num_bbox, 1), dtype=np.int32) + for obj in objs: cname = obj.find('name').text # user dataset may not contain difficult field @@ -152,19 +154,20 @@ class VOCDataSet(DetDataset): x2 = min(im_w - 1, x2) y2 = min(im_h - 1, y2) if x2 > x1 and y2 > y1: - gt_bbox.append([x1, y1, x2, y2]) - gt_class.append([cname2cid[cname]]) - gt_score.append([1.]) - difficult.append([_difficult]) + gt_bbox[i, :] = [x1, y1, x2, y2] + gt_class[i, 0] = cname2cid[cname] + gt_score[i, 0] = 1. + difficult[i, 0] = _difficult + i += 1 else: logger.warning( 'Found an invalid bbox in annotations: xml_file: {}' ', x1: {}, y1: {}, x2: {}, y2: {}.'.format( xml_file, x1, y1, x2, y2)) - gt_bbox = np.array(gt_bbox).astype('float32') - gt_class = np.array(gt_class).astype('int32') - gt_score = np.array(gt_score).astype('float32') - difficult = np.array(difficult).astype('int32') + gt_bbox = gt_bbox[:i, :] + gt_class = gt_class[:i, :] + gt_score = gt_score[:i, :] + difficult = difficult[:i, :] voc_rec = { 'im_file': img_file, @@ -193,7 +196,7 @@ class VOCDataSet(DetDataset): break assert ct > 0, 'not found any voc record in %s' % (self.anno_path) logger.debug('{} samples in file {}'.format(ct, anno_path)) - if len(empty_records) > 0: + if self.allow_empty and len(empty_records) > 0: empty_records = self._sample_empty(empty_records, len(records)) records += empty_records self.roidbs, self.cname2cid = records, cname2cid -- GitLab