import copy
import traceback
import logging
import threading
import six
import sys
if sys.version_info >= (3, 0):
    import queue as Queue
else:
    import Queue
import numpy as np
from paddle.io import DataLoader
from ppdet.core.workspace import register, serializable, create
from .sampler import DistributedBatchSampler
from . import transform
from .transform import operator, batch_operator

logger = logging.getLogger(__name__)


class Compose(object):
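    """Compose transform ops from a list of config-style dicts.

    Each entry of ``transforms`` is a single-key dict mapping an op name
    looked up on ``from_`` (the ``ppdet.data.transform`` module by default)
    to its keyword arguments. Ops exposing a ``num_classes`` attribute are
    patched with the reader's ``num_classes``.
    """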
    def __init__(self, transforms, fields=None, from_=transform,
                 num_classes=81):
        self.transforms = transforms
        self.transforms_cls = []
        output_fields = None
        for t in self.transforms:
            for k, v in t.items():
                op_cls = getattr(from_, k)
                self.transforms_cls.append(op_cls(**v))
                if hasattr(op_cls, 'num_classes'):
                    op_cls.num_classes = num_classes

                # TODO: should be refined in the future
                if op_cls in [
                        transform.Gt2YoloTargetOp, transform.Gt2YoloTarget
                ]:
                    output_fields = ['image', 'gt_bbox']
                    output_fields.extend([
                        'target{}'.format(i)
                        for i in range(len(v['anchor_masks']))
                    ])

        self.fields = fields
        self.output_fields = output_fields if output_fields else fields

    def __call__(self, data):
        if self.fields is not None:
            data_new = []
            for item in data:
                data_new.append(dict(zip(self.fields, item)))
            data = data_new

        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning(
                    "failed to map op [{}] with error: {} and stack:\n{}".
                    format(f, e, str(stack_info)))
                raise e

        if self.output_fields is not None:
            data_new = []
            for item in data:
                batch = []
                for k in self.output_fields:
                    batch.append(item[k])
                data_new.append(batch)
            batch_size = len(data_new)
            data_new = list(zip(*data_new))
            if batch_size > 1:
                data = [
                    np.array(item).astype(item[0].dtype) for item in data_new
                ]
            else:
                data = data_new

        return data
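
# A minimal usage sketch of Compose (illustrative only, not part of the
# original pipeline). The op names and arguments below are assumptions and
# may not match the installed ppdet.data.transform module exactly.
def _compose_usage_example(samples):
    # samples: list of per-sample dicts keyed by field name,
    # e.g. [{'image': ..., 'gt_bbox': ..., 'gt_class': ...}, ...]
    sample_transforms = [
        {'DecodeOp': {}},                 # decode image file/bytes to ndarray
        {'RandomFlipOp': {'prob': 0.5}},  # random horizontal flip
    ]
    compose = Compose(sample_transforms, num_classes=81)
    # each op is applied in order; samples stay keyed by field name since
    # fields/output_fields were not given
    return compose(samples)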


class BaseDataLoader(object):
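    """Build a ``paddle.io.DataLoader`` around a ppdet dataset.

    Sample transforms run per sample inside the dataset (via ``set_out``),
    while batch transforms run as the DataLoader ``collate_fn``. Iterating
    the returned loader yields dicts keyed by ``output_fields``.
    """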
    __share__ = ['num_classes']

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        # output fields
        self._fields = inputs_def['fields'] if inputs_def else None
        # sample transform
        self._sample_transforms = Compose(
            sample_transforms, num_classes=num_classes)

        # batch transform
        self._batch_transforms = None
        if batch_transforms:
            self._batch_transforms = Compose(batch_transforms,
                                             copy.deepcopy(self._fields),
                                             transform, num_classes)
            self.output_fields = self._batch_transforms.output_fields
        else:
            self.output_fields = self._fields

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.with_background = with_background
        self.kwargs = kwargs

    def __call__(self,
                 dataset,
                 worker_num,
                 device=None,
                 batch_sampler=None,
                 return_list=False,
                 use_prefetch=True):
        self.dataset = dataset
        self.dataset.parse_dataset(self.with_background)
        # set sample transforms and output fields on the dataset
        self.dataset.set_out(self._sample_transforms,
                             copy.deepcopy(self._fields))
        # set kwargs
        self.dataset.set_kwargs(**self.kwargs)
        # batch sampler
        if batch_sampler is None:
            self._batch_sampler = DistributedBatchSampler(
                self.dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler = batch_sampler

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            places=device,
            return_list=return_list,
            use_buffer_reader=use_prefetch,
            use_shared_memory=False)
        self.loader = iter(self.dataloader)

        return self

    def __len__(self):
        return len(self._batch_sampler)

    def __iter__(self):
        return self

    def __next__(self):
        # pack data as {field_name: field_data} here,
        # until paddle.io.DataLoader supports returning
        # dict-structured batches directly
        try:
            data = next(self.loader)
            return {k: v for k, v in zip(self.output_fields, data)}
        except StopIteration:
            self.loader = iter(self.dataloader)
            six.reraise(*sys.exc_info())

    def next(self):
        # python2 compatibility
        return self.__next__()


@register
class TrainReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TrainReader, self).__init__(inputs_def, sample_transforms,
                                          batch_transforms, batch_size, shuffle,
                                          drop_last, drop_empty, num_classes,
                                          with_background, **kwargs)


@register
class EvalReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(EvalReader, self).__init__(inputs_def, sample_transforms,
                                         batch_transforms, batch_size, shuffle,
                                         drop_last, drop_empty, num_classes,
                                         with_background, **kwargs)


@register
class TestReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TestReader, self).__init__(inputs_def, sample_transforms,
                                         batch_transforms, batch_size, shuffle,
                                         drop_last, drop_empty, num_classes,
                                         with_background, **kwargs)
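

# A minimal end-to-end usage sketch (illustrative only, not part of the
# original reader). The dataset object, transform names and field list are
# assumptions for illustration; in practice they come from the ppdet YAML
# configs via ppdet.core.workspace.create, and the dataset must implement
# parse_dataset/set_out/set_kwargs as expected by BaseDataLoader.
def _train_reader_usage_example(train_dataset):
    inputs_def = {'fields': ['image', 'gt_bbox', 'gt_class', 'gt_score']}
    reader = TrainReader(
        inputs_def=inputs_def,
        sample_transforms=[{'DecodeOp': {}}, {'RandomFlipOp': {'prob': 0.5}}],
        batch_size=2,
        shuffle=True)
    # calling the reader builds the underlying paddle.io.DataLoader and
    # returns an iterator of {field_name: field_data} dicts
    loader = reader(train_dataset, worker_num=0)
    return next(loader)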