diff --git a/ppdet/data/source/coco.py b/ppdet/data/source/coco.py index c4fbcfc40c293d7a80a0402d6aa981273b56da87..6650e56e9955796897cce9f66ffaf298a76cfaae 100644 --- a/ppdet/data/source/coco.py +++ b/ppdet/data/source/coco.py @@ -242,7 +242,7 @@ class COCODataSet(DetDataset): break assert ct > 0, 'not found any coco record in %s' % (anno_path) logger.debug('{} samples in file {}'.format(ct, anno_path)) - if len(empty_records) > 0: + if self.allow_empty and len(empty_records) > 0: empty_records = self._sample_empty(empty_records, len(records)) records += empty_records self.roidbs = records diff --git a/ppdet/data/source/voc.py b/ppdet/data/source/voc.py index eeef989f7f2355a5c5be12615203dcd9d4ac6299..1c2a7ef98ccbac760430befc375a79cdebc51a7c 100644 --- a/ppdet/data/source/voc.py +++ b/ppdet/data/source/voc.py @@ -131,11 +131,13 @@ class VOCDataSet(DetDataset): 'Illegal width: {} or height: {} in annotation, ' 'and {} will be ignored'.format(im_w, im_h, xml_file)) continue - gt_bbox = [] - gt_class = [] - gt_score = [] - difficult = [] - for i, obj in enumerate(objs): + + num_bbox, i = len(objs), 0 + gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32) + gt_class = np.zeros((num_bbox, 1), dtype=np.int32) + gt_score = np.zeros((num_bbox, 1), dtype=np.float32) + difficult = np.zeros((num_bbox, 1), dtype=np.int32) + for obj in objs: cname = obj.find('name').text # user dataset may not contain difficult field @@ -152,19 +154,20 @@ class VOCDataSet(DetDataset): x2 = min(im_w - 1, x2) y2 = min(im_h - 1, y2) if x2 > x1 and y2 > y1: - gt_bbox.append([x1, y1, x2, y2]) - gt_class.append([cname2cid[cname]]) - gt_score.append([1.]) - difficult.append([_difficult]) + gt_bbox[i, :] = [x1, y1, x2, y2] + gt_class[i, 0] = cname2cid[cname] + gt_score[i, 0] = 1. + difficult[i, 0] = _difficult + i += 1 else: logger.warning( 'Found an invalid bbox in annotations: xml_file: {}' ', x1: {}, y1: {}, x2: {}, y2: {}.'.format( xml_file, x1, y1, x2, y2)) - gt_bbox = np.array(gt_bbox).astype('float32') - gt_class = np.array(gt_class).astype('int32') - gt_score = np.array(gt_score).astype('float32') - difficult = np.array(difficult).astype('int32') + gt_bbox = gt_bbox[:i, :] + gt_class = gt_class[:i, :] + gt_score = gt_score[:i, :] + difficult = difficult[:i, :] voc_rec = { 'im_file': img_file, @@ -193,7 +196,7 @@ class VOCDataSet(DetDataset): break assert ct > 0, 'not found any voc record in %s' % (self.anno_path) logger.debug('{} samples in file {}'.format(ct, anno_path)) - if len(empty_records) > 0: + if self.allow_empty and len(empty_records) > 0: empty_records = self._sample_empty(empty_records, len(records)) records += empty_records self.roidbs, self.cname2cid = records, cname2cid