reader.py 8.2 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import copy
import traceback
K
Kaipeng Deng 已提交
17
import six
G
Guanghua Yu 已提交
18 19 20 21 22 23
import sys
if sys.version_info >= (3, 0):
    import queue as Queue
else:
    import Queue
import numpy as np
Q
qingqing01 已提交
24

G
Guanghua Yu 已提交
25
from paddle.io import DataLoader
Q
qingqing01 已提交
26 27
from paddle.io import DistributedBatchSampler

G
Guanghua Yu 已提交
28
from ppdet.core.workspace import register, serializable, create
29 30
from . import transform
from .transform import operator, batch_operator
31

Q
qingqing01 已提交
32 33
from ppdet.utils.logger import setup_logger
logger = setup_logger('reader')
34 35


36
class Compose(object):
    """Build a pipeline of transform operators and apply them in order.

    Each entry of ``transforms`` is a mapping ``{op_name: op_kwargs}``.
    Every ``op_name`` is looked up on ``from_`` and instantiated with
    ``op_kwargs``.

    Args:
        transforms (list[dict] | None): operator configurations. ``None``
            is treated as an empty pipeline.
        fields (list[str] | None): when not None, every incoming sample is
            zipped with these names into a dict before the operators run.
            Also used as the output field names unless a YOLO target
            operator overrides them (see ``output_fields``).
        from_ (module): namespace on which operator classes are looked up.
        num_classes (int): value assigned to the ``num_classes`` attribute
            of every operator instance that has such an attribute.

    Attributes:
        transforms: the ``transforms`` argument, stored as given.
        transforms_cls (list): instantiated operators, in pipeline order.
        fields: the ``fields`` argument, stored as given.
        output_fields (list[str] | None): names used by ``__call__`` to
            pick values out of each processed sample.
    """

    def __init__(self, transforms, fields=None, from_=transform,
                 num_classes=81):
        self.transforms = transforms
        self.transforms_cls = []
        output_fields = None
        for op_config in (transforms or []):
            for op_name, op_kwargs in op_config.items():
                op_cls = getattr(from_, op_name)
                op = op_cls(**op_kwargs)
                # Set num_classes on the instance, not on the class:
                # writing to the class would leak the value into every
                # other pipeline that uses the same operator type.
                if hasattr(op, 'num_classes'):
                    op.num_classes = num_classes
                self.transforms_cls.append(op)

                # TODO: should be refined in the future
                # The YOLO target operators emit one 'target{i}' entry per
                # anchor mask, so the output field list is derived from
                # the operator configuration instead of ``fields``.
                if op_cls in (transform.Gt2YoloTargetOp,
                              transform.Gt2YoloTarget):
                    output_fields = ['image', 'gt_bbox']
                    output_fields.extend(
                        'target{}'.format(i)
                        for i in range(len(op_kwargs['anchor_masks'])))

        self.fields = fields
        self.output_fields = output_fields if output_fields else fields

    def __call__(self, data):
        """Run every operator on ``data`` and return the result.

        Args:
            data: iterable of samples. When ``self.fields`` is set each
                sample is a sequence aligned with ``self.fields``.

        Returns:
            The operators' output unchanged when ``self.output_fields`` is
            None. Otherwise a list with one entry per output field: a
            numpy array stacking that field across samples when there is
            more than one sample, or the un-stacked tuple of values when
            there is at most one sample.

        Raises:
            Exception: any exception raised by an operator is logged and
                re-raised.
        """
        if self.fields is not None:
            data = [dict(zip(self.fields, item)) for item in data]

        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning(
                    "fail to map op [{}] with error: {} and stack:\n{}".
                    format(f, e, str(stack_info)))
                # Bare raise keeps the original traceback intact.
                raise

        if self.output_fields is not None:
            rows = [[item[k] for k in self.output_fields] for item in data]
            batch_size = len(rows)
            # Transpose: one tuple per field instead of one row per sample.
            columns = list(zip(*rows))
            if batch_size > 1:
                data = [
                    np.array(column).astype(column[0].dtype)
                    for column in columns
                ]
            else:
                data = columns

        return data
95 96


G
Guanghua Yu 已提交
97 98
class BaseDataLoader(object):
    """Iterable wrapper that wires a dataset into ``paddle.io.DataLoader``.

    The constructor only records configuration and builds the transform
    pipelines; calling the instance with a dataset creates the actual
    loader. Iterating the instance yields one ``{field_name: value}`` dict
    per batch.

    Args:
        inputs_def (dict | None): when truthy, its ``'fields'`` entry gives
            the field names of a sample.
        sample_transforms (list[dict] | None): per-sample operator
            configurations; ``None`` means no sample transforms.
        batch_transforms (list[dict] | None): per-batch operator
            configurations, used as the loader's ``collate_fn``.
        batch_size (int): batch size given to the default batch sampler.
        shuffle (bool): shuffle flag given to the default batch sampler.
        drop_last (bool): drop_last flag given to the default batch sampler.
        drop_empty (bool): accepted for interface compatibility; not used
            by this class.
        num_classes (int): forwarded to the transform pipelines.
        with_background (bool): forwarded to ``dataset.parse_dataset``.
        **kwargs: forwarded to ``dataset.set_kwargs``.
    """
    # NOTE(review): spelled ``__share__`` here; confirm this is the
    # attribute name ppdet.core.workspace looks for (vs ``__shared__``).
    __share__ = ['num_classes']

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        # out fields
        self._fields = inputs_def['fields'] if inputs_def else None
        # sample transform
        # Compose iterates its first argument, so the None default must be
        # replaced by an empty pipeline before it is handed over.
        self._sample_transforms = Compose(
            sample_transforms or [], num_classes=num_classes)

        # batch transform
        self._batch_transforms = None
        if batch_transforms:
            self._batch_transforms = Compose(batch_transforms,
                                             copy.deepcopy(self._fields),
                                             transform, num_classes)
            self.output_fields = self._batch_transforms.output_fields
        else:
            self.output_fields = self._fields

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.with_background = with_background
        self.kwargs = kwargs

    def __call__(self,
                 dataset,
                 worker_num,
                 batch_sampler=None,
                 return_list=False,
                 use_prefetch=True):
        """Attach ``dataset`` and build the underlying DataLoader.

        Args:
            dataset: dataset object providing ``parse_dataset``, ``set_out``
                and ``set_kwargs``.
            worker_num (int): passed to DataLoader as ``num_workers``.
            batch_sampler: sampler to use; when None a
                ``DistributedBatchSampler`` is built from this loader's
                ``batch_size``, ``shuffle`` and ``drop_last``.
            return_list (bool): passed through to DataLoader.
            use_prefetch (bool): passed to DataLoader as
                ``use_buffer_reader``.

        Returns:
            BaseDataLoader: ``self``, ready to be iterated.
        """
        self.dataset = dataset
        self.dataset.parse_dataset(self.with_background)
        # get data
        self.dataset.set_out(self._sample_transforms,
                             copy.deepcopy(self._fields))
        # set kwargs
        self.dataset.set_kwargs(**self.kwargs)
        # batch sampler
        if batch_sampler is None:
            self._batch_sampler = DistributedBatchSampler(
                self.dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler = batch_sampler

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            return_list=return_list,
            use_buffer_reader=use_prefetch,
            use_shared_memory=False)
        self.loader = iter(self.dataloader)

        return self

    def __len__(self):
        """Return the length of the batch sampler set up by ``__call__``."""
        return len(self._batch_sampler)

    def __iter__(self):
        """Return ``self``; batches are produced by ``__next__``."""
        return self

    def __next__(self):
        """Return the next batch as a ``{field_name: field_data}`` dict.

        Raises:
            StopIteration: when the underlying loader is exhausted. The
                internal iterator is re-created first, so the instance can
                be iterated again afterwards.
        """
        # pack {field_name: field_data} here
        # looking forward to support dictionary
        # data structure in paddle.io.DataLoader
        try:
            data = next(self.loader)
        except StopIteration:
            # Re-arm the iterator for the next pass, then let the
            # original StopIteration propagate unchanged.
            self.loader = iter(self.dataloader)
            raise
        return dict(zip(self.output_fields, data))

    def next(self):
        """Python 2 spelling of ``__next__``."""
        # python2 compatibility
        return self.__next__()
188 189


G
Guanghua Yu 已提交
190 191 192 193 194 195 196 197
@register
class TrainReader(BaseDataLoader):
    """Data loader with training defaults.

    Differs from ``BaseDataLoader`` only in its default values: ``shuffle``
    and ``drop_last`` both default to True. All arguments are forwarded
    unchanged to the base constructor.
    """

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TrainReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background,
            **kwargs)
207

K
Kaipeng Deng 已提交
208

G
Guanghua Yu 已提交
209 210 211 212 213 214 215 216
@register
class EvalReader(BaseDataLoader):
    """Data loader with evaluation defaults.

    Differs from ``BaseDataLoader`` only in its default values: ``shuffle``
    defaults to False and ``drop_last`` defaults to True. All arguments
    are forwarded unchanged to the base constructor.
    """

    # NOTE(review): drop_last defaults to True here; presumably a trailing
    # partial batch is then skipped during evaluation - confirm intended.
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(EvalReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background,
            **kwargs)
226

227

G
Guanghua Yu 已提交
228 229 230 231 232 233 234 235 236 237 238
@register
class TestReader(BaseDataLoader):
    """Data loader with inference defaults.

    Uses the same default values as ``BaseDataLoader`` (``shuffle`` and
    ``drop_last`` both False). All arguments are forwarded unchanged to
    the base constructor.
    """

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TestReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background,
            **kwargs)