reader.py 7.4 KB
Newer Older
1 2
import copy
import traceback
3
import logging
G
Guanghua Yu 已提交
4
import threading
K
Kaipeng Deng 已提交
5
import six
G
Guanghua Yu 已提交
6 7 8 9 10 11 12 13 14
import sys
if sys.version_info >= (3, 0):
    import queue as Queue
else:
    import Queue
import numpy as np
from paddle.io import DataLoader
from ppdet.core.workspace import register, serializable, create
from .sampler import DistributedBatchSampler
15 16
from . import transform
from .transform import operator, batch_operator
17 18 19 20

logger = logging.getLogger(__name__)


21
class Compose(object):
    """Chain of data operators built from a config description.

    Each entry of ``transforms`` is a mapping ``{OpName: op_kwargs}``. The
    operator class is looked up by name on ``from_`` and instantiated with
    ``op_kwargs``. Calling the ``Compose`` runs every operator in order.

    Args:
        transforms (list[dict] | None): operator descriptions. ``None`` is
            treated as an empty list (no operators). An entry whose kwargs are
            ``None`` is instantiated with no arguments.
        fields (list[str] | None): when given, each incoming item is a
            sequence that is zipped with these names into a dict before the
            operators run.
        from_ (module): namespace in which operator classes are looked up.
        num_classes (int): written to the ``num_classes`` attribute of any
            operator class that defines one.
    """

    def __init__(self, transforms, fields=None, from_=transform,
                 num_classes=81):
        self.transforms = transforms
        self.transforms_cls = []
        output_fields = None
        for t in transforms or []:
            for k, v in t.items():
                op_kwargs = v or {}
                op_cls = getattr(from_, k)
                self.transforms_cls.append(op_cls(**op_kwargs))
                if hasattr(op_cls, 'num_classes'):
                    # NOTE(review): this sets the attribute on the class, not
                    # on the instance just created, so it affects every
                    # instance of op_cls (including those built by other
                    # Compose objects) -- confirm this is intended.
                    op_cls.num_classes = num_classes

                # TODO: should be refined in the future
                # YOLO target ops add one 'target{i}' output per anchor mask,
                # so the output layout is fixed here instead of using fields.
                if op_cls in [
                        transform.Gt2YoloTargetOp, transform.Gt2YoloTarget
                ]:
                    output_fields = ['image', 'gt_bbox']
                    output_fields.extend([
                        'target{}'.format(i)
                        for i in range(len(op_kwargs['anchor_masks']))
                    ])

        self.fields = fields
        self.output_fields = output_fields if output_fields else fields

    def __call__(self, data):
        """Apply every operator to ``data`` in order.

        Args:
            data: the input handed to the first operator. When ``fields`` is
                set it must be an iterable of sequences aligned with
                ``fields``; each one is converted to a dict first.

        Returns:
            The output of the last operator. When ``output_fields`` is set,
            that output is treated as an iterable of dicts and regrouped
            column-wise, giving one entry per output field. With more than one
            item each column is stacked into an ``np.ndarray`` cast to the
            first element's dtype. Otherwise the columns are left as tuples.

        Raises:
            Any exception raised by an operator is logged together with its
            stack trace and then re-raised unchanged.
        """
        if self.fields is not None:
            data = [dict(zip(self.fields, item)) for item in data]

        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning("fail to map op [{}] with error: {} and stack:\n{}".
                               format(f, e, str(stack_info)))
                # Bare raise keeps the original traceback intact.
                raise

        if self.output_fields is not None:
            rows = [[item[k] for k in self.output_fields] for item in data]
            batch_size = len(rows)
            columns = list(zip(*rows))
            if batch_size > 1:
                data = [np.array(col).astype(col[0].dtype) for col in columns]
            else:
                data = columns

        return data
80 81


G
Guanghua Yu 已提交
82 83
class BaseDataLoader(object):
    """Base reader that builds a ``paddle.io.DataLoader`` over a dataset.

    Construction only records the configuration and builds the sample/batch
    ``Compose`` pipelines. The loader itself is created when the instance is
    called with a dataset. After that, the reader acts as its own iterator
    and yields one dict per batch, keyed by the configured field names.
    """
    # NOTE(review): ppdet's config schema appears to read ``__shared__``
    # (plural). Confirm whether ``__share__`` is a typo; if so, num_classes
    # is not being injected from the global config.
    __share__ = ['num_classes']

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        """Record reader configuration and build the transform pipelines.

        Args:
            inputs_def (dict | None): if given, ``inputs_def['fields']`` is the
                list of field names used to key each yielded batch.
            sample_transforms (list | None): per-sample operator config,
                wrapped in a ``Compose``.
            batch_transforms (list | None): batch-level operator config. When
                non-empty it is wrapped in a ``Compose`` and used as the
                DataLoader's ``collate_fn``.
            batch_size (int): batch size for the default batch sampler.
            shuffle (bool): whether the default batch sampler shuffles.
            drop_last (bool): forwarded to the default batch sampler.
            drop_empty (bool): accepted but not used anywhere in this class.
            num_classes (int): forwarded to both ``Compose`` pipelines.
            with_background (bool): passed to ``dataset.parse_dataset``.
            **kwargs: forwarded to ``dataset.set_kwargs`` when called.
        """
        # out fields
        self._fields = inputs_def['fields'] if inputs_def else None
        # sample transform
        self._sample_transforms = Compose(
            sample_transforms, num_classes=num_classes)

        # batch transform; the field list is deep-copied so this pipeline
        # does not share the list object with self._fields
        self._batch_transforms = None
        if batch_transforms:
            self._batch_transforms = Compose(batch_transforms,
                                             copy.deepcopy(self._fields),
                                             transform, num_classes)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.with_background = with_background
        self.kwargs = kwargs

    def __call__(self,
                 dataset,
                 worker_num,
                 device=None,
                 batch_sampler=None,
                 return_list=False,
                 use_prefetch=True):
        """Bind ``dataset`` and build the underlying DataLoader iterator.

        Args:
            dataset: dataset object. It must provide ``parse_dataset``,
                ``set_out`` and ``set_kwargs``.
            worker_num (int): passed as ``num_workers`` to the DataLoader.
            device: passed as ``places`` to the DataLoader.
            batch_sampler: custom batch sampler. When ``None``, a
                ``DistributedBatchSampler`` is built from this reader's
                batch_size / shuffle / drop_last settings.
            return_list (bool): passed through to the DataLoader.
            use_prefetch (bool): passed as ``use_buffer_reader``.

        Returns:
            self, so the call can be iterated directly.
        """
        self.dataset = dataset
        self.dataset.parse_dataset(self.with_background)
        # get data
        self.dataset.set_out(self._sample_transforms,
                             copy.deepcopy(self._fields))
        # set kwargs
        self.dataset.set_kwargs(**self.kwargs)
        # batch sampler
        if batch_sampler is None:
            self._batch_sampler = DistributedBatchSampler(
                self.dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler = batch_sampler

        self.loader = DataLoader(
            dataset=self.dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            places=device,
            return_list=return_list,
            use_buffer_reader=use_prefetch,
            use_shared_memory=False)
        # A single iterator is kept. Once it is exhausted, iterating the
        # reader again yields nothing until __call__ is invoked again.
        self.loader = iter(self.loader)

        return self

    def __len__(self):
        # Number of batches. Only valid after __call__ has set the sampler.
        return len(self._batch_sampler)

    def __iter__(self):
        return self

    def __next__(self):
        # Pack {field_name: field_data} here until paddle.io.DataLoader
        # supports dict-structured batches natively. StopIteration from the
        # underlying iterator propagates unchanged and ends iteration.
        data = next(self.loader)
        return dict(zip(self._fields, data))

    def next(self):
        # python2 compatibility
        return self.__next__()
171 172


G
Guanghua Yu 已提交
173 174 175 176 177 178 179 180
@register
class TrainReader(BaseDataLoader):
    """Reader preset for training.

    By default it shuffles samples and sets ``drop_last``. Every argument is
    forwarded unchanged to ``BaseDataLoader``.
    """

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TrainReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background,
            **kwargs)
190

K
Kaipeng Deng 已提交
191

G
Guanghua Yu 已提交
192 193 194 195 196 197 198 199
@register
class EvalReader(BaseDataLoader):
    """Reader preset for evaluation.

    It does not shuffle by default. Every argument is forwarded unchanged to
    ``BaseDataLoader``.

    NOTE(review): ``drop_last`` defaults to True here and is handed to the
    batch sampler; with batch_size > 1 this presumably drops a trailing
    partial batch during evaluation -- confirm that is intended.
    """

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(EvalReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background,
            **kwargs)
209

210

G
Guanghua Yu 已提交
211 212 213 214 215 216 217 218 219 220 221
@register
class TestReader(BaseDataLoader):
    """Reader preset for inference.

    By default it neither shuffles nor sets ``drop_last``. Every argument is
    forwarded unchanged to ``BaseDataLoader``.
    """

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TestReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background,
            **kwargs)