coco.py 10.2 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

import os
import numpy as np
from ppdet.core.workspace import register, serializable
from .dataset import DetDataset

from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)


@register
@serializable
class COCODataSet(DetDataset):
    """
    Load dataset with COCO format.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): coco annotation file path.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        load_crowd (bool): whether to load crowded ground-truth.
            False as default
        allow_empty (bool): whether to load empty entry. False as default
        empty_ratio (float): the ratio of empty record number to total
            records; if empty_ratio is out of [0., 1.), the empty records
            are not sampled and all of them are used. 1. as default
        repeat (int): repeat times for dataset, use in benchmark.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 load_crowd=False,
                 allow_empty=False,
                 empty_ratio=1.,
                 repeat=1):
        # NOTE: the list default for data_fields is kept because ppdet's
        # register/serializable read signature defaults; it is never
        # mutated in this class.
        super(COCODataSet, self).__init__(
            dataset_dir,
            image_dir,
            anno_path,
            data_fields,
            sample_num,
            repeat=repeat)
        self.load_image_only = False
        self.load_semantic = False
        self.load_crowd = load_crowd
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        """Randomly subsample empty records according to ``empty_ratio``.

        Args:
            records (list): the empty (box-less) records.
            num (int): number of non-empty records.

        Returns:
            list: ``records`` unchanged when ``empty_ratio`` is outside
            [0., 1.); otherwise a random sample of at most
            ``num * empty_ratio / (1 - empty_ratio)`` of them.
        """
        # if empty_ratio is out of [0. ,1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
        import random
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    @staticmethod
    def _segm_is_empty(segm):
        """Return True if a segmentation contains no coordinate values.

        For regular nested lists this equals ``np.array(segm).size == 0``,
        but it also works for ragged polygon lists (polygons with different
        numbers of points), which recent NumPy refuses to turn into an
        array. Non-list values (e.g. an RLE dict) are treated as non-empty,
        as ``np.array(dict).size`` is 1.
        """
        if isinstance(segm, (list, tuple)):
            return all(COCODataSet._segm_is_empty(part) for part in segm)
        return False

    def parse_dataset(self):
        """Parse the COCO annotation file and store the records in
        ``self.roidbs``.

        Each record is a dict. When 'image' is in ``data_fields`` it holds
        'im_file', 'im_id', 'h' and 'w'. Unless the annotation file has no
        'annotations' section, the ground-truth entries 'is_crowd',
        'gt_class', 'gt_bbox', 'gt_rbox' (rotated boxes only) and 'gt_poly'
        are added for the keys listed in ``data_fields``.

        Images are skipped when their file does not exist, when their width
        or height is negative, or (unless ``allow_empty``) when no valid box
        remains. With ``allow_empty``, box-less records are subsampled by
        ``_sample_empty`` and appended after the non-empty records.

        Raises:
            AssertionError: if the annotation path is not a ``.json`` file
                or no record was loaded.
        """
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        assert anno_path.endswith('.json'), \
            'invalid coco annotation file: ' + anno_path
        from pycocotools.coco import COCO
        coco = COCO(anno_path)
        img_ids = coco.getImgIds()
        img_ids.sort()
        cat_ids = coco.getCatIds()
        records = []
        empty_records = []
        ct = 0

        # Map raw COCO category ids to contiguous class ids 0..K-1.
        self.catid2clsid = {catid: i for i, catid in enumerate(cat_ids)}
        self.cname2cid = {
            coco.loadCats(catid)[0]['name']: clsid
            for catid, clsid in self.catid2clsid.items()
        }

        if 'annotations' not in coco.dataset:
            self.load_image_only = True
            logger.warning('Annotation file: {} does not contains ground truth '
                           'and load image information only.'.format(anno_path))

        for img_id in img_ids:
            img_anno = coco.loadImgs([img_id])[0]
            im_fname = img_anno['file_name']
            im_w = float(img_anno['width'])
            im_h = float(img_anno['height'])

            im_path = os.path.join(image_dir,
                                   im_fname) if image_dir else im_fname
            is_empty = False
            if not os.path.exists(im_path):
                logger.warning('Illegal image file: {}, and it will be '
                               'ignored'.format(im_path))
                continue

            if im_w < 0 or im_h < 0:
                logger.warning('Illegal width: {} or height: {} in annotation, '
                               'and im_id: {} will be ignored'.format(
                                   im_w, im_h, img_id))
                continue

            coco_rec = {
                'im_file': im_path,
                'im_id': np.array([img_id]),
                'h': im_h,
                'w': im_w,
            } if 'image' in self.data_fields else {}

            if not self.load_image_only:
                ins_anno_ids = coco.getAnnIds(
                    imgIds=[img_id], iscrowd=None if self.load_crowd else False)
                instances = coco.loadAnns(ins_anno_ids)

                bboxes = []
                # NOTE(review): this flag ends up reflecting the last instance
                # that passes the bbox checks below; images that mix 4- and
                # 5-value boxes are not supported.
                is_rbox_anno = False
                for inst in instances:
                    # check gt bbox
                    if inst.get('ignore', False):
                        continue
                    if 'bbox' not in inst:
                        continue
                    if not any(np.array(inst['bbox'])):
                        continue

                    # 5-value bboxes are rotated boxes: xc, yc, w, h, angle
                    is_rbox_anno = len(inst['bbox']) == 5
                    if is_rbox_anno:
                        xc, yc, box_w, box_h, angle = inst['bbox']
                        x1 = xc - box_w / 2.0
                        y1 = yc - box_h / 2.0
                        x2 = x1 + box_w
                        y2 = y1 + box_h
                    else:
                        x1, y1, box_w, box_h = inst['bbox']
                        x2 = x1 + box_w
                        y2 = y1 + box_h
                    eps = 1e-5
                    if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
                        inst['clean_bbox'] = [
                            round(float(x), 3) for x in [x1, y1, x2, y2]
                        ]
                        if is_rbox_anno:
                            inst['clean_rbox'] = [xc, yc, box_w, box_h, angle]
                        bboxes.append(inst)
                    else:
                        logger.warning(
                            'Found an invalid bbox in annotations: im_id: {}, '
                            'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                img_id, float(inst['area']), x1, y1, x2, y2))

                num_bbox = len(bboxes)
                if num_bbox <= 0 and not self.allow_empty:
                    continue
                elif num_bbox <= 0:
                    is_empty = True

                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                if is_rbox_anno:
                    gt_rbox = np.zeros((num_bbox, 5), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
                gt_poly = [None] * num_bbox
                # Instances whose polygon turns out to be empty are removed
                # from every gt array after the loop, so all arrays stay
                # aligned (never mutate `bboxes` while enumerating it).
                keep = np.ones(num_bbox, dtype=bool)

                has_segmentation = False
                for i, box in enumerate(bboxes):
                    catid = box['category_id']
                    gt_class[i][0] = self.catid2clsid[catid]
                    gt_bbox[i, :] = box['clean_bbox']
                    # xc, yc, w, h, theta
                    if is_rbox_anno:
                        gt_rbox[i, :] = box['clean_rbox']
                    is_crowd[i][0] = box['iscrowd']
                    # check RLE format
                    if 'segmentation' in box and box['iscrowd'] == 1:
                        gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
                    elif 'segmentation' in box and box['segmentation']:
                        if self._segm_is_empty(
                                box['segmentation']) and not self.allow_empty:
                            keep[i] = False
                        else:
                            gt_poly[i] = box['segmentation']
                        has_segmentation = True

                if not keep.all():
                    gt_bbox = gt_bbox[keep]
                    gt_class = gt_class[keep]
                    is_crowd = is_crowd[keep]
                    if is_rbox_anno:
                        gt_rbox = gt_rbox[keep]
                    gt_poly = [p for p, k in zip(gt_poly, keep) if k]

                if has_segmentation and not any(
                        gt_poly) and not self.allow_empty:
                    continue

                gt_rec = {
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_bbox': gt_bbox,
                }
                if is_rbox_anno:
                    gt_rec['gt_rbox'] = gt_rbox
                gt_rec['gt_poly'] = gt_poly

                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        coco_rec[k] = v

                # TODO: remove load_semantic
                if self.load_semantic and 'semantic' in self.data_fields:
                    # splitext handles extensions of any length
                    # (e.g. '.jpeg'), not only three-letter ones.
                    seg_path = os.path.join(
                        self.dataset_dir, 'stuffthingmaps', 'train2017',
                        os.path.splitext(im_fname)[0] + '.png')
                    coco_rec.update({'semantic': seg_path})

            logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
                im_path, img_id, im_h, im_w))
            if is_empty:
                empty_records.append(coco_rec)
            else:
                records.append(coco_rec)
            ct += 1
            if self.sample_num > 0 and ct >= self.sample_num:
                break
        assert ct > 0, 'not found any coco record in %s' % (anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))
        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records
        self.roidbs = records