# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np

import xml.etree.ElementTree as ET

from ppdet.core.workspace import register, serializable

from .dataset import DetDataset
import logging
logger = logging.getLogger(__name__)


@register
@serializable
class VOCDataSet(DetDataset):
    """
    Load dataset with PascalVOC format.

    Notes:
    Each non-empty line of `anno_path` must contain an image file path and
    an xml annotation file path (whitespace separated); both are joined to
    `dataset_dir`/`image_dir`.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): voc annotation file path.
        sample_num (int): number of samples to load, -1 means all.
        label_list (str|None): path (relative to `dataset_dir`) of a text
            file with one category name per line. If not set, the default
            20-class PascalVOC mapping is used.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 sample_num=-1,
                 label_list=None):
        super(VOCDataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            sample_num=sample_num)
        self.label_list = label_list

    def parse_dataset(self, with_background=True):
        """
        Parse the annotation list and every referenced xml file.

        Sets `self.roidbs` to a list of per-image record dicts (only images
        with at least one `<object>` are kept) and `self.cname2cid` to the
        category-name -> class-id mapping.

        Args:
            with_background (bool): if True class ids start at 1 (0 is
                left for background); otherwise they start at 0.

        Raises:
            ValueError: if `label_list` is set but the file does not exist.
            AssertionError: if no record could be loaded.
        """
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        # mapping category name to class id
        # if with_background is True:
        #   background:0, first_class:1, second_class:2, ...
        # if with_background is False:
        #   first_class:0, second_class:1, ...
        records = []
        ct = 0
        cname2cid = {}
        if self.label_list:
            label_path = os.path.join(self.dataset_dir, self.label_list)
            if not os.path.exists(label_path):
                raise ValueError("label_list {} does not exists".format(
                    label_path))
            with open(label_path, 'r') as fr:
                label_id = int(with_background)
                for line in fr:
                    cname2cid[line.strip()] = label_id
                    label_id += 1
        else:
            cname2cid = pascalvoc_label(with_background)

        with open(anno_path, 'r') as fr:
            for line in fr:
                fields = line.strip().split()
                # Blank lines (e.g. a trailing empty line) would otherwise
                # fail the two-value unpacking below.
                if not fields:
                    continue
                img_file, xml_file = [
                    os.path.join(image_dir, x) for x in fields[:2]
                ]
                if not os.path.exists(img_file):
                    logger.warning(
                        'Illegal image file: {}, and it will be ignored'.format(
                            img_file))
                    continue
                if not os.path.isfile(xml_file):
                    logger.warning(
                        'Illegal xml file: {}, and it will be ignored'.format(
                            xml_file))
                    continue
                tree = ET.parse(xml_file)
                # Use the explicit <id> when present, else a running index.
                id_node = tree.find('id')
                if id_node is None:
                    im_id = np.array([ct])
                else:
                    im_id = np.array([int(id_node.text)])

                objs = tree.findall('object')
                im_w = float(tree.find('size').find('width').text)
                im_h = float(tree.find('size').find('height').text)
                if im_w < 0 or im_h < 0:
                    logger.warning(
                        'Illegal width: {} or height: {} in annotation, '
                        'and {} will be ignored'.format(im_w, im_h, xml_file))
                    continue
                gt_bbox = []
                gt_class = []
                gt_score = []
                is_crowd = []
                difficult = []
                for obj in objs:
                    cname = obj.find('name').text
                    # Some user datasets omit <difficult>; treat it as 0.
                    difficult_node = obj.find('difficult')
                    _difficult = (int(difficult_node.text)
                                  if difficult_node is not None else 0)
                    bndbox = obj.find('bndbox')
                    x1 = float(bndbox.find('xmin').text)
                    y1 = float(bndbox.find('ymin').text)
                    x2 = float(bndbox.find('xmax').text)
                    y2 = float(bndbox.find('ymax').text)
                    # Clip the box to the image extent.
                    x1 = max(0, x1)
                    y1 = max(0, y1)
                    x2 = min(im_w - 1, x2)
                    y2 = min(im_h - 1, y2)
                    if x2 > x1 and y2 > y1:
                        gt_bbox.append([x1, y1, x2, y2])
                        gt_class.append([cname2cid[cname]])
                        gt_score.append([1.])
                        is_crowd.append([0])
                        difficult.append([_difficult])
                    else:
                        logger.warning(
                            'Found an invalid bbox in annotations: xml_file: {}'
                            ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                xml_file, x1, y1, x2, y2))
                # reshape keeps the 2-D layout even when every box of an
                # image was rejected: (0, 4) / (0, 1) instead of (0,).
                gt_bbox = np.array(gt_bbox, dtype='float32').reshape(-1, 4)
                gt_class = np.array(gt_class, dtype='int32').reshape(-1, 1)
                gt_score = np.array(gt_score, dtype='float32').reshape(-1, 1)
                is_crowd = np.array(is_crowd, dtype='int32').reshape(-1, 1)
                difficult = np.array(difficult, dtype='int32').reshape(-1, 1)
                voc_rec = {
                    'im_file': img_file,
                    'im_id': im_id,
                    'h': im_h,
                    'w': im_w,
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_score': gt_score,
                    'gt_bbox': gt_bbox,
                    'difficult': difficult
                }
                if objs:
                    records.append(voc_rec)

                ct += 1
                if self.sample_num > 0 and ct >= self.sample_num:
                    break
        assert len(records) > 0, 'not found any voc record in %s' % (
            self.anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))
        self.roidbs, self.cname2cid = records, cname2cid

    def get_label_list(self):
        """
        Return the label-list file path under `dataset_dir`, or None when
        `label_list` is not set (the default PascalVOC mapping is in use).
        """
        if not self.label_list:
            return None
        return os.path.join(self.dataset_dir, self.label_list)


def pascalvoc_label(with_background=True):
    """
    Build the default PascalVOC category-name -> class-id mapping.

    Args:
        with_background (bool): if True ids run 1..20 (0 is left for the
            background class); otherwise ids run 0..19.

    Returns:
        dict: a fresh dict keyed by the 20 VOC category names, in the
        canonical VOC order.
    """
    voc_classes = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
    )
    first_id = 1 if with_background else 0
    return {
        name: class_id
        for class_id, name in enumerate(voc_classes, start=first_id)
    }