# voc.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np

import xml.etree.ElementTree as ET

from ppdet.core.workspace import register, serializable

from .dataset import DetDataset

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)


@register
@serializable
class VOCDataSet(DetDataset):
    """
    Load dataset with PascalVOC format.

    Notes:
    `anno_path` must contain one sample per line: the image file path and
    the xml annotation file path, separated by whitespace. Both paths are
    joined onto `dataset_dir`/`image_dir`. Blank lines are ignored.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): voc annotation file path.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        label_list (str): file (relative to `dataset_dir`) listing one class
            name per line; class ids are assigned in line order starting
            from 0. If None, the 20 default PascalVOC classes are used.
        allow_empty (bool): whether to load empty entries (samples without
            any valid bounding box). False as default.
        empty_ratio (float): the ratio of empty records to all records. If
            empty_ratio is out of [0., 1.), the empty records are not
            sampled and all of them are used. 1. as default.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 label_list=None,
                 allow_empty=False,
                 empty_ratio=1.):
        # NOTE: `data_fields` keeps its list default for config/serialization
        # compatibility; it is only read here, never mutated.
        super(VOCDataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            data_fields=data_fields,
            sample_num=sample_num)
        self.label_list = label_list
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        """Randomly subsample empty records.

        Args:
            records (list): the empty records.
            num (int): number of non-empty records.

        Returns:
            list: at most ``num * empty_ratio / (1 - empty_ratio)`` randomly
            chosen records, so that empty records make up about
            ``empty_ratio`` of the combined set. All records are returned
            unchanged when empty_ratio is outside [0, 1).
        """
        # if empty_ratio is out of [0. ,1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
        import random
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    def _load_cname2cid(self):
        """Build the mapping from category name to class id.

        When `label_list` is set, read it from `dataset_dir` (one class name
        per line, ids 0, 1, ... in line order). Otherwise return the default
        PascalVOC mapping.

        Raises:
            ValueError: if the label list file does not exist.
        """
        if not self.label_list:
            return pascalvoc_label()
        label_path = os.path.join(self.dataset_dir, self.label_list)
        if not os.path.exists(label_path):
            raise ValueError("label_list {} does not exists".format(
                label_path))
        cname2cid = {}
        with open(label_path, 'r') as fr:
            # first_class:0, second_class:1, ...
            for label_id, line in enumerate(fr):
                cname2cid[line.strip()] = label_id
        return cname2cid

    def _parse_voc_xml(self, img_file, xml_file, default_im_id, cname2cid):
        """Parse one VOC xml annotation file into a record dict.

        Args:
            img_file (str): image path stored in the record.
            xml_file (str): xml annotation path.
            default_im_id (int): image id used when the xml has no <id> tag.
            cname2cid (dict): category name -> class id.

        Returns:
            tuple(dict, bool) or None: ``(record, is_empty)``, where
            ``is_empty`` is True when no valid box was kept. Returns None
            when the annotated image size is illegal.

        Raises:
            KeyError: if a valid box has a category missing from cname2cid.
        """
        tree = ET.parse(xml_file)
        if tree.find('id') is None:
            im_id = np.array([default_im_id])
        else:
            im_id = np.array([int(tree.find('id').text)])

        objs = tree.findall('object')
        im_w = float(tree.find('size').find('width').text)
        im_h = float(tree.find('size').find('height').text)
        if im_w < 0 or im_h < 0:
            logger.warning(
                'Illegal width: {} or height: {} in annotation, '
                'and {} will be ignored'.format(im_w, im_h, xml_file))
            return None

        # Preallocate for every object; rows of rejected boxes are trimmed
        # off below using the count of kept boxes `i`.
        num_bbox, i = len(objs), 0
        gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
        gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
        gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
        difficult = np.zeros((num_bbox, 1), dtype=np.int32)
        for obj in objs:
            cname = obj.find('name').text

            # user dataset may not contain difficult field
            _difficult = obj.find('difficult')
            _difficult = int(
                _difficult.text) if _difficult is not None else 0

            bndbox = obj.find('bndbox')
            x1 = float(bndbox.find('xmin').text)
            y1 = float(bndbox.find('ymin').text)
            x2 = float(bndbox.find('xmax').text)
            y2 = float(bndbox.find('ymax').text)
            # clip the box into [0, w - 1] x [0, h - 1]
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(im_w - 1, x2)
            y2 = min(im_h - 1, y2)
            if x2 > x1 and y2 > y1:
                if cname not in cname2cid:
                    raise KeyError(
                        'Category {} in {} is not in the label list'.format(
                            cname, xml_file))
                gt_bbox[i, :] = [x1, y1, x2, y2]
                gt_class[i, 0] = cname2cid[cname]
                gt_score[i, 0] = 1.
                difficult[i, 0] = _difficult
                i += 1
            else:
                logger.warning(
                    'Found an invalid bbox in annotations: xml_file: {}'
                    ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                        xml_file, x1, y1, x2, y2))
        gt_bbox = gt_bbox[:i, :]
        gt_class = gt_class[:i, :]
        gt_score = gt_score[:i, :]
        difficult = difficult[:i, :]

        voc_rec = {
            'im_file': img_file,
            'im_id': im_id,
            'h': im_h,
            'w': im_w
        } if 'image' in self.data_fields else {}

        gt_rec = {
            'gt_class': gt_class,
            'gt_score': gt_score,
            'gt_bbox': gt_bbox,
            'difficult': difficult
        }
        for k, v in gt_rec.items():
            if k in self.data_fields:
                voc_rec[k] = v

        # A sample whose boxes were all rejected carries no ground truth, so
        # it is treated like a sample without objects (an empty record).
        return voc_rec, i == 0

    def parse_dataset(self, ):
        """Parse `anno_path` into `self.roidbs` and `self.cname2cid`.

        Lines whose image or xml file is missing, or whose annotated size is
        illegal, are skipped with a warning. Empty records (no valid box)
        are kept only when `allow_empty` is True, subsampled according to
        `empty_ratio`. When `sample_num` > 0, parsing stops after that many
        records (empty or not) have been parsed.

        Raises:
            AssertionError: if no record could be parsed.
        """
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        # mapping category name to class id
        cname2cid = self._load_cname2cid()

        records = []
        empty_records = []
        ct = 0
        with open(anno_path, 'r') as fr:
            for line in fr:
                fields = line.strip().split()
                if not fields:
                    # tolerate blank lines, e.g. trailing newlines
                    continue
                if len(fields) < 2:
                    logger.warning(
                        'Illegal line in {}: {}, and it will be ignored'.format(
                            anno_path, line.strip()))
                    continue
                img_file, xml_file = [
                    os.path.join(image_dir, x) for x in fields[:2]
                ]
                if not os.path.exists(img_file):
                    logger.warning(
                        'Illegal image file: {}, and it will be ignored'.format(
                            img_file))
                    continue
                if not os.path.isfile(xml_file):
                    logger.warning(
                        'Illegal xml file: {}, and it will be ignored'.format(
                            xml_file))
                    continue

                parsed = self._parse_voc_xml(img_file, xml_file, ct,
                                             cname2cid)
                if parsed is None:
                    continue
                voc_rec, is_empty = parsed
                if is_empty:
                    empty_records.append(voc_rec)
                else:
                    records.append(voc_rec)

                ct += 1
                if self.sample_num > 0 and ct >= self.sample_num:
                    break
        assert ct > 0, 'not found any voc record in %s' % (self.anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))
        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records
        self.roidbs, self.cname2cid = records, cname2cid

    def get_label_list(self):
        """Return the label list file path, or None when none is configured."""
        if self.label_list is None:
            return None
        return os.path.join(self.dataset_dir, self.label_list)


def pascalvoc_label():
    """Return the default PascalVOC mapping from class name to class id.

    The 20 PascalVOC classes are numbered 0-19 in alphabetical order. A new
    dict is built on every call.
    """
    class_names = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
    )
    return {name: idx for idx, name in enumerate(class_names)}