dataset.py 6.6 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

import os
import numpy as np
from collections import OrderedDict
try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence
from paddle.io import Dataset
from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path
import copy


@serializable
class DetDataset(Dataset):
    """
    Load detection dataset.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str|list): annotation file path; check_or_download_dataset
            also accepts a list of paths.
        data_fields (list): key name of data dictionary, at least have 'image'.
            None (the default) means ['image'].
        sample_num (int): number of samples to load, -1 means all.
        use_default_label (bool): whether to load default label list.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=None,
                 sample_num=-1,
                 use_default_label=None,
                 **kwargs):
        super(DetDataset, self).__init__()
        self.dataset_dir = dataset_dir if dataset_dir is not None else ''
        self.anno_path = anno_path
        self.image_dir = image_dir if image_dir is not None else ''
        # Build a fresh list per instance: a list literal used as the default
        # value would be one shared, mutable object across all instances.
        self.data_fields = data_fields if data_fields is not None else [
            'image'
        ]
        self.sample_num = sample_num
        self.use_default_label = use_default_label
        self._epoch = 0
        self._curr_iter = 0

    def __len__(self):
        return len(self.roidbs)

    def __getitem__(self, idx):
        """
        Build the sample at `idx` and pass it through `self.transform`.

        The record is deep-copied so transforms can modify it freely. While
        mixup or cutmix is active for the current epoch, the transform
        receives a list of 2 records (the indexed one plus one random record).
        While mosaic is active, it receives a list of 4 records. Otherwise it
        receives the single record. Every record gets a 'curr_iter' key
        holding the running iteration counter.

        `set_transform` must have been called before indexing.
        """
        roidb = copy.deepcopy(self.roidbs[idx])
        if self._mix_enabled('mixup_epoch') or self._mix_enabled(
                'cutmix_epoch'):
            # mixup and cutmix both pair the sample with one random record.
            roidb = [roidb] + self._sample_records(1)
        elif self._mix_enabled('mosaic_epoch'):
            roidb = [roidb] + self._sample_records(3)

        records = roidb if isinstance(roidb, Sequence) else [roidb]
        for record in records:
            record['curr_iter'] = self._curr_iter
        self._curr_iter += 1

        return self.transform(roidb)

    def _mix_enabled(self, epoch_attr):
        """
        Return whether the multi-sample mode limited by `epoch_attr` is on.

        A limit of 0 turns the mode on for every epoch. Any other limit turns
        it on while the current epoch is below it. If `set_kwargs` has not
        been called yet, the attribute is treated as -1 (the default
        `set_kwargs` uses).
        """
        limit = getattr(self, epoch_attr, -1)
        return limit == 0 or self._epoch < limit

    def _sample_records(self, count):
        """Return deep copies of `count` records picked with np.random.randint."""
        total = len(self.roidbs)
        return [
            copy.deepcopy(self.roidbs[np.random.randint(total)])
            for _ in range(count)
        ]

    def check_or_download_dataset(self):
        """
        Resolve `self.dataset_dir` through `get_dataset_path`.

        For a list of annotation paths, `get_dataset_path` is called once per
        path. Each call receives the `dataset_dir` returned by the previous
        one.
        """
        if isinstance(self.anno_path, list):
            for path in self.anno_path:
                self.dataset_dir = get_dataset_path(self.dataset_dir, path,
                                                    self.image_dir)
        else:
            self.dataset_dir = get_dataset_path(self.dataset_dir,
                                                self.anno_path, self.image_dir)

    def set_kwargs(self, **kwargs):
        """Read the mixup/cutmix/mosaic epoch limits (default -1 each)."""
        self.mixup_epoch = kwargs.get('mixup_epoch', -1)
        self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
        self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)

    def set_transform(self, transform):
        """Set the callable applied to every sample in __getitem__."""
        self.transform = transform

    def set_epoch(self, epoch_id):
        """Set the epoch compared against the mixup/cutmix/mosaic limits."""
        self._epoch = epoch_id

    def parse_dataset(self):
        """Populate `self.roidbs`; subclasses must implement this."""
        raise NotImplementedError(
            "Need to implement parse_dataset method of Dataset")

    def get_anno(self):
        """Return dataset_dir joined with anno_path, or None without one."""
        if self.anno_path is None:
            return
        return os.path.join(self.dataset_dir, self.anno_path)


def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
    return f.lower().endswith(extensions)


def _make_dataset(dir):
    dir = os.path.expanduser(dir)
125
    if not os.path.isdir(dir):
Q
qingqing01 已提交
126 127 128 129 130
        raise ('{} should be a dir'.format(dir))
    images = []
    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in sorted(fnames):
            path = os.path.join(root, fname)
131
            if _is_valid_file(path):
Q
qingqing01 已提交
132 133 134 135 136 137 138 139 140 141 142 143
                images.append(path)
    return images


@register
@serializable
class ImageFolder(DetDataset):
    """
    Dataset built from image files and/or image directories.

    Args:
        dataset_dir (str): root directory; relative image directories are
            looked up under it first.
        image_dir (str|list): an image file, an image directory, or a list
            of them.
        sample_num (int): maximum number of images to load; values <= 0
            mean all.
        use_default_label (bool): forwarded to DetDataset.
        keep_ori_im (bool): if True, each record gets 'keep_ori_im': 1.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 sample_num=-1,
                 use_default_label=None,
                 keep_ori_im=False,
                 **kwargs):
        super(ImageFolder, self).__init__(
            dataset_dir,
            image_dir,
            sample_num=sample_num,
            use_default_label=use_default_label)
        self.keep_ori_im = keep_ori_im
        self._imid2path = {}
        self.roidbs = None

    def check_or_download_dataset(self):
        # Images are local paths supplied by the caller; nothing to fetch.
        return

    def parse_dataset(self):
        """Load the image records unless they have already been loaded."""
        if not self.roidbs:
            self.roidbs = self._load_images()

    def _parse(self):
        """
        Expand `self.image_dir` into a flat list of image file paths.

        A directory entry is looked up under `dataset_dir` first, then as
        given, and is scanned recursively. A file entry is kept if it exists
        and has an image extension. Empty entries contribute nothing.
        """
        image_dir = self.image_dir
        # A str is itself a Sequence, so it must be wrapped explicitly;
        # otherwise it would be iterated character by character.
        if isinstance(image_dir, str) or not isinstance(image_dir, Sequence):
            image_dir = [image_dir]
        images = []
        for im_dir in image_dir:
            if im_dir == '':
                # join(dataset_dir, '') would name dataset_dir itself.
                continue
            joined_dir = os.path.join(self.dataset_dir, im_dir)
            if os.path.isdir(joined_dir):
                images.extend(_make_dataset(joined_dir))
            elif os.path.isdir(im_dir):
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        """
        Build one record per image, numbering im_id from 0.

        The im_id -> path mapping is rebuilt from scratch on every call.
        AssertionError is raised if an image path does not exist or if no
        record is produced.
        """
        images = self._parse()
        # im_id numbering restarts at 0 on each call, so start from an empty
        # mapping instead of leaving ids from an earlier image list behind.
        self._imid2path = {}
        records = []
        for im_id, image in enumerate(images):
            assert image != '' and os.path.isfile(image), \
                    "Image {} not found".format(image)
            if self.sample_num > 0 and im_id >= self.sample_num:
                break
            rec = {'im_id': np.array([im_id]), 'im_file': image}
            if self.keep_ori_im:
                rec.update({'keep_ori_im': 1})
            self._imid2path[im_id] = image
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        """Return the im_id -> image path mapping from the last load."""
        return self._imid2path

    def set_images(self, images):
        """Replace the image source and reload the records immediately."""
        self.image_dir = images
        self.roidbs = self._load_images()