dataset.py 5.5 KB
Newer Older
G
Guanghua Yu 已提交
1 2 3 4 5 6 7 8 9 10 11 12
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
G
Guanghua Yu 已提交
17
from collections import OrderedDict
18 19 20 21
try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence
G
Guanghua Yu 已提交
22
from paddle.io import Dataset
23 24
from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path
W
wangguanzhong 已提交
25
import copy
26 27 28


@serializable
class DetDataset(Dataset):
    """Base detection dataset holding parsed records and applying transforms.

    Subclasses are expected to fill ``self.roidbs`` (a list of per-sample
    records) in ``parse_dataset``; ``set_out`` must be called to attach the
    sample transform and the output field names before items are fetched.

    Args:
        dataset_dir (str|None): dataset root directory; ``None`` becomes ''.
        image_dir (str|None): image location, interpreted by the subclass;
            ``None`` becomes ''.
        anno_path (str|None): annotation path relative to ``dataset_dir``.
        sample_num (int): stored for subclasses that limit the sample count.
        use_default_label: stored as-is for subclasses.
        **kwargs: accepted and ignored here.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 sample_num=-1,
                 use_default_label=None,
                 **kwargs):
        super(DetDataset, self).__init__()
        self.dataset_dir = dataset_dir if dataset_dir is not None else ''
        self.anno_path = anno_path
        self.image_dir = image_dir if image_dir is not None else ''
        self.sample_num = sample_num
        self.use_default_label = use_default_label
        self._epoch = 0
        # Mixed-sample augmentations stay disabled (-1) until set_kwargs()
        # configures them, so __getitem__ never hits a missing attribute.
        self.mixup_epoch = -1
        self.cutmix_epoch = -1
        self.mosaic_epoch = -1

    def __len__(self):
        """Return the number of parsed records."""
        return len(self.roidbs)

    def __getitem__(self, idx):
        """Transform record ``idx`` and return the values of ``self.fields``.

        An augmentation limit of 0 means "active every epoch"; a positive
        limit means "active while the current epoch is below it"; the
        default -1 never matches a non-negative epoch. While mixup or cutmix
        is active the record is paired with one randomly drawn record; else
        while mosaic is active it is grouped with three random records.
        Records are deep-copied so the stored ``self.roidbs`` is left intact.
        """
        roidb = copy.deepcopy(self.roidbs[idx])
        n = len(self.roidbs)
        if (self.mixup_epoch == 0 or self._epoch < self.mixup_epoch or
                self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch):
            partner = copy.deepcopy(self.roidbs[np.random.randint(n)])
            roidb = [roidb, partner]
        elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
            roidb = [roidb] + [
                copy.deepcopy(self.roidbs[np.random.randint(n)])
                for _ in range(3)
            ]

        # data augment
        roidb = self.transform(roidb)
        # Keep only the requested fields, in the requested order.
        out = OrderedDict((k, roidb[k]) for k in self.fields)
        return out.values()

    def set_kwargs(self, **kwargs):
        """Set the mixup/cutmix/mosaic epoch limits (missing keys -> -1)."""
        self.mixup_epoch = kwargs.get('mixup_epoch', -1)
        self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
        self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)

    def set_epoch(self, epoch_id):
        """Record the current epoch, compared against the epoch limits."""
        self._epoch = epoch_id

    def set_out(self, sample_transform, fields):
        """Attach the per-sample transform and the output field names."""
        self.transform = sample_transform
        self.fields = fields

    def parse_dataset(self, with_background=True):
        """Populate ``self.roidbs``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "Need to implement parse_dataset method of Dataset")

    def get_anno(self):
        """Return ``anno_path`` joined onto ``dataset_dir``, or None if unset."""
        if self.anno_path is None:
            return
        return os.path.join(self.dataset_dir, self.anno_path)

G
Guanghua Yu 已提交
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
    return f.lower().endswith(extensions)


def _make_dataset(dir):
    dir = os.path.expanduser(dir)
    if not os.path.isdir(d):
        raise ('{} should be a dir'.format(dir))
    images = []
    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in sorted(fnames):
            path = os.path.join(root, fname)
            if is_valid_file(path):
                images.append(path)
    return images


113 114
@register
@serializable
class ImageFolder(DetDataset):
    """Dataset built from image files and/or image directories.

    Each record only carries ``im_id`` and ``im_file``; no annotations are
    read. Images can be (re)assigned at runtime through ``set_images``.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 sample_num=-1,
                 use_default_label=None,
                 **kwargs):
        super(ImageFolder, self).__init__(dataset_dir, image_dir, anno_path,
                                          sample_num, use_default_label)
        self._imid2path = {}
        self.roidbs = None

    def parse_dataset(self, with_background=True):
        """Load image records once; ``with_background`` is unused here."""
        if not self.roidbs:
            self.roidbs = self._load_images()

    def _parse(self):
        """Expand ``self.image_dir`` into a flat list of image paths.

        ``self.image_dir`` may be a single path or a sequence of paths. A
        directory (resolved against ``dataset_dir``) contributes every image
        found under it; a file path is kept if it has an image extension.
        Anything else is silently skipped.
        """
        image_dir = self.image_dir
        # A str is itself a Sequence; without this check a single path would
        # be iterated character by character.
        if isinstance(image_dir, str) or not isinstance(image_dir, Sequence):
            image_dir = [image_dir]
        images = []
        for im_dir in image_dir:
            # Check the same (dataset_dir-joined) path that is then scanned.
            dir_path = os.path.join(self.dataset_dir, im_dir)
            if os.path.isdir(dir_path):
                images.extend(_make_dataset(dir_path))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        """Build one record per image and refresh the id -> path map.

        At most ``sample_num`` records are built when ``sample_num > 0``.
        Ids are consecutive integers starting at 0.
        """
        images = self._parse()
        # Drop ids from a previous call so they cannot point at stale paths.
        self._imid2path.clear()
        records = []
        for ct, image in enumerate(images):
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            assert image != '' and os.path.isfile(image), \
                    "Image {} not found".format(image)
            records.append({'im_id': np.array([ct]), 'im_file': image})
            self._imid2path[ct] = image
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        """Return the dict mapping image id to image path."""
        return self._imid2path

    def set_images(self, images):
        """Replace the image source with ``images`` and reload the records."""
        self.image_dir = images
        self.roidbs = self._load_images()