dataset.py 4.5 KB
Newer Older
G
Guanghua Yu 已提交
1 2 3 4 5 6 7 8 9 10 11 12
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
G
Guanghua Yu 已提交
17
from collections import OrderedDict
18 19 20 21
try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence
G
Guanghua Yu 已提交
22
from paddle.io import Dataset
23 24
from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path
W
wangguanzhong 已提交
25
import copy
26 27 28


@serializable
class DetDataset(Dataset):
    """Base map-style detection dataset.

    Subclasses are expected to fill ``self.roidbs`` (a list of per-sample
    record dicts) by overriding ``parse_dataset``. ``set_out`` must be called
    to install the sample transform and output field names before items are
    fetched with ``__getitem__``.

    Args:
        dataset_dir (str): dataset root directory; ``None`` becomes ''.
        image_dir (str): image directory; ``None`` becomes ''.
        anno_path (str): annotation file path, joined onto ``dataset_dir``
            by ``get_anno``.
        sample_num (int): stored for subclasses (default -1).
        use_default_label: stored as-is for subclasses.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 sample_num=-1,
                 use_default_label=None,
                 **kwargs):
        super(DetDataset, self).__init__()
        self.dataset_dir = dataset_dir if dataset_dir is not None else ''
        self.anno_path = anno_path
        self.image_dir = image_dir if image_dir is not None else ''
        self.sample_num = sample_num
        self.use_default_label = use_default_label

    def __len__(self):
        """Return the number of records in ``self.roidbs``."""
        return len(self.roidbs)

    def __getitem__(self, idx):
        """Return the transformed sample at ``idx``.

        The result is a values view of the transformed record's entries for
        ``self.fields``, in ``self.fields`` order.
        """
        # Work on a deep copy so a transform that modifies the record in
        # place cannot alter the stored self.roidbs entry.
        roidb = copy.deepcopy(self.roidbs[idx])
        # data augmentation / preprocessing
        roidb = self.transform(roidb)
        # keep only the requested fields, in the requested order
        out = OrderedDict()
        for k in self.fields:
            out[k] = roidb[k]
        return out.values()

    def set_out(self, sample_transform, fields):
        """Install the per-sample transform and the output field names.

        Args:
            sample_transform (callable): called with a record dict; its
                return value is indexed by each name in ``fields``.
            fields (iterable of str): record keys emitted by ``__getitem__``.
        """
        self.transform = sample_transform
        self.fields = fields

    def parse_dataset(self, with_background=True):
        """Populate ``self.roidbs``; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplemented is a non-callable singleton; the exception class
        # is NotImplementedError.
        raise NotImplementedError(
            "Need to implement parse_dataset method of Dataset")

    def get_anno(self):
        """Return ``anno_path`` joined onto ``dataset_dir``, or None if unset."""
        if self.anno_path is None:
            return None
        return os.path.join(self.dataset_dir, self.anno_path)


G
Guanghua Yu 已提交
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
    """Tell whether path ``f`` has an accepted image extension.

    The comparison is case-insensitive on ``f``. ``extensions`` is passed
    straight to ``str.endswith``, so it may be one suffix string or a tuple
    of lowercase suffixes.
    """
    lowered_name = f.lower()
    return lowered_name.endswith(extensions)


def _make_dataset(dir):
    """Recursively collect image file paths under a directory.

    Args:
        dir (str): root directory to scan; a leading ``~`` is expanded.
            Symbolic links to directories are followed.

    Returns:
        list[str]: paths accepted by ``_is_valid_file``. Walk entries are
        sorted (by directory path), and file names are sorted within each
        directory.

    Raises:
        ValueError: if ``dir`` is not an existing directory.
    """
    root_dir = os.path.expanduser(dir)
    if not os.path.isdir(root_dir):
        # Raising a bare string is itself a TypeError; raise a real exception.
        raise ValueError('{} should be a dir'.format(root_dir))
    images = []
    for root, _, fnames in sorted(os.walk(root_dir, followlinks=True)):
        for fname in sorted(fnames):
            path = os.path.join(root, fname)
            if _is_valid_file(path):
                images.append(path)
    return images


89 90
@register
@serializable
class ImageFolder(DetDataset):
    """Annotation-free dataset built from image files and/or directories.

    ``image_dir`` may be a single path or a sequence of paths. Each entry
    may be a directory, which is scanned recursively for image files, or a
    single image file. Each record has the form
    ``{'im_id': np.array([id]), 'im_file': path}``, with ids assigned
    0, 1, 2, ... in discovery order.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 sample_num=-1,
                 use_default_label=None,
                 **kwargs):
        super(ImageFolder, self).__init__(dataset_dir, image_dir, anno_path,
                                          sample_num, use_default_label)
        # integer image id -> image path, rebuilt by _load_images
        self._imid2path = {}
        # built lazily by parse_dataset, or eagerly by set_images
        self.roidbs = None

    def parse_dataset(self, with_background=True):
        """Load records from ``self.image_dir`` unless already loaded.

        ``with_background`` is accepted for interface compatibility; it is
        not used here.
        """
        if not self.roidbs:
            self.roidbs = self._load_images()

    def _parse(self):
        """Expand ``self.image_dir`` into a flat list of image file paths.

        Directory entries are resolved against ``dataset_dir`` first. That
        join is a no-op when ``dataset_dir`` is '' or the entry is absolute.
        If the joined path is not a directory, the entry is tried as given.
        File entries are used as given and must have an image extension.
        Entries that match neither case are skipped.
        """
        image_dir = self.image_dir
        # A str is itself a Sequence; wrap it so a single path is not
        # iterated character by character.
        if isinstance(image_dir, str) or not isinstance(image_dir, Sequence):
            image_dir = [image_dir]
        images = []
        for im_dir in image_dir:
            joined = os.path.join(self.dataset_dir, im_dir)
            if os.path.isdir(joined):
                images.extend(_make_dataset(joined))
            elif os.path.isdir(im_dir):
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        """Build the record list and the id->path map from ``_parse``.

        At most ``sample_num`` records are kept when ``sample_num > 0``.

        Raises:
            AssertionError: if a path is empty or not a file, or if no
                image is found at all.
        """
        images = self._parse()
        # Start a fresh mapping so ids from a previous load do not linger.
        self._imid2path = {}
        ct = 0
        records = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                    "Image {} not found".format(image)
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            rec = {'im_id': np.array([ct]), 'im_file': image}
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        """Return the image-id -> image-path dict from the latest load."""
        return self._imid2path

    def set_images(self, images):
        """Replace the image source with ``images`` and reload records now."""
        self.image_dir = images
        self.roidbs = self._load_images()