# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np

import xml.etree.ElementTree as ET

from ppdet.core.workspace import register, serializable

from .dataset import DetDataset

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)


@register
@serializable
class VOCDataSet(DetDataset):
    """
    Load dataset with PascalVOC format.

    Notes:
    `anno_path` must contain, one sample per line, an image file path and
    an xml annotation file path separated by whitespace. Both paths are
    resolved relative to `dataset_dir`/`image_dir`. Blank lines are ignored.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images, relative to `dataset_dir`.
        anno_path (str): voc annotation list file path, relative to
            `dataset_dir`.
        data_fields (list): fields to keep in each record. 'image' keeps
            'im_file', 'im_id', 'h' and 'w'; each of 'gt_class', 'gt_score',
            'gt_bbox' and 'difficult' keeps that field. Defaults to ['image'].
        sample_num (int): number of samples to load, -1 means all.
        label_list (str): optional file, relative to `dataset_dir`, with one
            category name per line. If not set, the 20 PascalVOC categories
            returned by `pascalvoc_label` are used.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 label_list=None):
        # The default list object is shared by every call; copy it so that
        # one instance mutating its data_fields cannot leak into another.
        # None falls back to the default ['image'].
        data_fields = ['image'] if data_fields is None else list(data_fields)
        super(VOCDataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            data_fields=data_fields,
            sample_num=sample_num)
        self.label_list = label_list

    def parse_dataset(self, with_background=True):
        """
        Parse the annotation list file and every xml file it references.

        Category name to class id mapping:
            with_background=True:  background:0, first_class:1, ...
            with_background=False: first_class:0, second_class:1, ...

        Lines whose image or xml file is missing, and xml files whose
        width/height is not positive, are skipped with a warning. xml files
        without any <object> are counted but not added to the records.

        Sets `self.roidbs` to the list of record dicts and `self.cname2cid`
        to the category-name -> class-id mapping.

        Raises:
            ValueError: if `label_list` is set but the file does not exist.
            AssertionError: (via assert) if no record was loaded.
        """
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        cname2cid = self._load_cname2cid(with_background)

        records = []
        # number of xml files parsed so far; also the fallback image id
        ct = 0
        with open(anno_path, 'r') as fr:
            for line in fr:
                fields = line.strip().split()
                if not fields:
                    # tolerate blank lines (e.g. a trailing empty line)
                    continue
                img_file, xml_file = [
                    os.path.join(image_dir, x) for x in fields[:2]
                ]
                if not os.path.exists(img_file):
                    logger.warning(
                        'Illegal image file: {}, and it will be ignored'.format(
                            img_file))
                    continue
                if not os.path.isfile(xml_file):
                    logger.warning('Illegal xml file: {}, and it will be ignored'.
                                   format(xml_file))
                    continue

                tree = ET.parse(xml_file)
                id_node = tree.find('id')
                # prefer an explicit <id> in the xml, else the running count
                if id_node is None:
                    im_id = np.array([ct])
                else:
                    im_id = np.array([int(id_node.text)])

                objs = tree.findall('object')
                size_node = tree.find('size')
                im_w = float(size_node.find('width').text)
                im_h = float(size_node.find('height').text)
                # A zero-sized image would clip every box to an empty one,
                # so treat it like a negative size and skip the file.
                if im_w <= 0 or im_h <= 0:
                    logger.warning(
                        'Illegal width: {} or height: {} in annotation, '
                        'and {} will be ignored'.format(im_w, im_h, xml_file))
                    continue

                gt_bbox, gt_class, gt_score, difficult = self._parse_objects(
                    objs, im_w, im_h, cname2cid, xml_file)

                voc_rec = {
                    'im_file': img_file,
                    'im_id': im_id,
                    'h': im_h,
                    'w': im_w
                } if 'image' in self.data_fields else {}

                gt_rec = {
                    'gt_class': gt_class,
                    'gt_score': gt_score,
                    'gt_bbox': gt_bbox,
                    'difficult': difficult
                }
                # keep only the ground-truth fields requested in data_fields
                voc_rec.update((k, v) for k, v in gt_rec.items()
                               if k in self.data_fields)

                if len(objs) != 0:
                    records.append(voc_rec)

                ct += 1
                if self.sample_num > 0 and ct >= self.sample_num:
                    break
        assert len(records) > 0, 'not found any voc record in %s' % (
            self.anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))
        self.roidbs, self.cname2cid = records, cname2cid

    def _load_cname2cid(self, with_background):
        """Build the category-name -> class-id dict (label_list or default).

        With a label_list file, each line (stripped) is one category name and
        ids are assigned in file order, starting at 1 if `with_background`
        else 0. Without one, `pascalvoc_label(with_background)` is returned.

        Raises:
            ValueError: if `label_list` is set but the file does not exist.
        """
        if not self.label_list:
            return pascalvoc_label(with_background)

        label_path = os.path.join(self.dataset_dir, self.label_list)
        if not os.path.exists(label_path):
            raise ValueError("label_list {} does not exists".format(
                label_path))
        cname2cid = {}
        with open(label_path, 'r') as fr:
            for label_id, line in enumerate(fr, start=int(with_background)):
                cname2cid[line.strip()] = label_id
        return cname2cid

    @staticmethod
    def _parse_objects(objs, im_w, im_h, cname2cid, xml_file):
        """Convert <object> xml nodes into ground-truth numpy arrays.

        Box corners are clipped to [0, im_w - 1] x [0, im_h - 1]. Boxes that
        are empty after clipping are dropped with a warning.

        Returns:
            tuple: (gt_bbox float32 [[x1, y1, x2, y2], ...],
                    gt_class int32 [[cid], ...],
                    gt_score float32 [[1.], ...],
                    difficult int32 [[flag], ...]), one row per kept box.

        Raises:
            KeyError: if a kept box's category name is not in `cname2cid`.
        """
        gt_bbox = []
        gt_class = []
        gt_score = []
        difficult = []
        for obj in objs:
            cname = obj.find('name').text
            # <difficult> is optional in some VOC-style exports; default 0.
            difficult_node = obj.find('difficult')
            _difficult = (int(difficult_node.text)
                          if difficult_node is not None else 0)
            bndbox = obj.find('bndbox')
            x1 = float(bndbox.find('xmin').text)
            y1 = float(bndbox.find('ymin').text)
            x2 = float(bndbox.find('xmax').text)
            y2 = float(bndbox.find('ymax').text)
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(im_w - 1, x2)
            y2 = min(im_h - 1, y2)
            if x2 > x1 and y2 > y1:
                gt_bbox.append([x1, y1, x2, y2])
                gt_class.append([cname2cid[cname]])
                gt_score.append([1.])
                difficult.append([_difficult])
            else:
                logger.warning(
                    'Found an invalid bbox in annotations: xml_file: {}'
                    ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                        xml_file, x1, y1, x2, y2))
        gt_bbox = np.array(gt_bbox).astype('float32')
        gt_class = np.array(gt_class).astype('int32')
        gt_score = np.array(gt_score).astype('float32')
        difficult = np.array(difficult).astype('int32')
        return gt_bbox, gt_class, gt_score, difficult

    def get_label_list(self):
        """Return `label_list` joined with `dataset_dir`.

        `label_list` must be set: os.path.join raises TypeError on None.
        """
        return os.path.join(self.dataset_dir, self.label_list)


def pascalvoc_label(with_background=True):
    """Return the default PascalVOC category-name -> class-id mapping.

    Args:
        with_background (bool): if True, ids start at 1 (0 is left for the
            background class); otherwise ids start at 0.

    Returns:
        dict: the 20 PascalVOC category names, in alphabetical order, mapped
        to consecutive integer ids.
    """
    voc_classes = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
    )
    first_id = 1 if with_background else 0
    return {name: cid for cid, name in enumerate(voc_classes, first_id)}