reader.py 6.9 KB
Newer Older
1 2
import copy
import traceback
3
import logging
G
Guanghua Yu 已提交
4 5 6 7 8 9 10 11 12 13
import threading
import sys
if sys.version_info >= (3, 0):
    import queue as Queue
else:
    import Queue
import numpy as np
from paddle.io import DataLoader
from ppdet.core.workspace import register, serializable, create
from .sampler import DistributedBatchSampler
14 15
from . import transform
from .transform import operator, batch_operator
16 17 18 19

logger = logging.getLogger(__name__)


20
class Compose(object):
    """Chain transform operators and optionally collate their output.

    Each entry of ``transforms`` is a mapping ``{op_name: op_kwargs}``. The
    operator class is looked up by name on ``from_`` and instantiated with
    ``op_kwargs``; the instances are applied in order by ``__call__``.

    Args:
        transforms (list|None): list of ``{op_name: op_kwargs}`` mappings.
            ``None`` is treated as an empty list.
        fields (list|None): field names. When given, every incoming item is
            zipped with these names into a dict before the operators run,
            and (unless a YOLO target op overrides it) the same names select
            the values that are collated on output.
        from_ (module): namespace the operator classes are fetched from.
        num_classes (int): value written to ``num_classes`` of operator
            classes that expose such an attribute.
    """

    def __init__(self, transforms, fields=None, from_=transform,
                 num_classes=81):
        # A reader may be configured without transforms; iterate over an
        # empty list instead of failing on ``None``.
        self.transforms = transforms if transforms is not None else []
        self.transforms_cls = []
        output_fields = None
        for t in self.transforms:
            for k, v in t.items():
                op_cls = getattr(from_, k)
                self.transforms_cls.append(op_cls(**v))
                # NOTE(review): this writes to the operator *class*, so the
                # value is shared by every instance of that operator type -
                # confirm this is intended before changing it.
                if hasattr(op_cls, 'num_classes'):
                    op_cls.num_classes = num_classes

                # TODO: should be refined in the future
                # YOLO target ops emit one 'target{i}' entry per anchor
                # mask, so the collated output fields are derived from the
                # op config rather than taken from ``fields``.
                if op_cls in [
                        transform.Gt2YoloTargetOp, transform.Gt2YoloTarget
                ]:
                    output_fields = ['image', 'gt_bbox']
                    output_fields.extend([
                        'target{}'.format(i)
                        for i in range(len(v['anchor_masks']))
                    ])

        self.fields = fields
        self.output_fields = output_fields if output_fields else fields

    def __call__(self, data):
        """Apply all operators to ``data`` and collate the result.

        Args:
            data: the sample(s) to transform. When ``self.fields`` is set,
                ``data`` is iterated and each item is zipped with the field
                names into a dict first.

        Returns:
            The transformed data. When ``self.output_fields`` is set, the
            values of those fields are gathered per item and transposed to
            one sequence per field; for more than one item each sequence is
            converted to a numpy array cast to the dtype of its first
            element, otherwise the transposed tuples are returned as is.

        Raises:
            Exception: whatever an operator raises; it is logged and then
                re-raised unchanged.
        """
        if self.fields is not None:
            data = [dict(zip(self.fields, item)) for item in data]

        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning(
                    "fail to map op [{}] with error: {} and stack:\n{}".format(
                        f, e, str(stack_info)))
                # Bare ``raise`` keeps the original traceback intact.
                raise

        if self.output_fields is not None:
            data_new = [[item[k] for k in self.output_fields]
                        for item in data]
            batch_size = len(data_new)
            # Transpose: list of per-item rows -> one tuple per field.
            data_new = list(zip(*data_new))
            if batch_size > 1:
                data = [
                    np.array(item).astype(item[0].dtype) for item in data_new
                ]
            else:
                data = data_new

        return data
79 80


G
Guanghua Yu 已提交
81 82
class BaseDataLoader(object):
    """Configure sample/batch transforms and build a ``DataLoader``.

    The constructor stores the loader configuration; calling the instance
    with a dataset wires the transforms into the dataset and returns a
    ``paddle.io.DataLoader`` together with the number of batches.

    Args:
        inputs_def (dict|None): when truthy, its ``'fields'`` entry is used
            as the list of field names.
        sample_transforms (list): per-sample transform configs, passed to
            ``Compose``.
        batch_transforms (list|None): per-batch transform configs; when
            truthy a ``Compose`` is built and used as ``collate_fn``.
        batch_size (int): batch size for the default batch sampler.
        shuffle (bool): shuffle flag for the default batch sampler.
        drop_last (bool): drop-last flag for the default batch sampler.
        drop_empty (bool): accepted for interface compatibility; not used
            by the code in this class.
        num_classes (int): forwarded to both ``Compose`` objects.
        with_background (bool): forwarded to ``dataset.parse_dataset``.
        **kwargs: forwarded to ``dataset.set_kwargs``.
    """
    __share__ = ['num_classes']

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        # Output field names (None when no inputs definition is given).
        self._fields = inputs_def['fields'] if inputs_def else None

        # Per-sample transform pipeline.
        self._sample_transforms = Compose(
            sample_transforms, num_classes=num_classes)

        # Per-batch transform pipeline, only built when configured.
        if batch_transforms:
            field_names = copy.deepcopy(self._fields)
            self._batch_transforms = Compose(batch_transforms, field_names,
                                             transform, num_classes)
        else:
            self._batch_transforms = None

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.with_background = with_background
        self.kwargs = kwargs

    def __call__(self,
                 dataset,
                 worker_num,
                 device=None,
                 batch_sampler=None,
                 return_list=False,
                 use_prefetch=True):
        """Prepare ``dataset`` and build the data loader for it.

        Args:
            dataset: dataset object providing ``parse_dataset``, ``set_out``
                and ``set_kwargs``.
            worker_num (int): passed to ``DataLoader`` as ``num_workers``.
            device: passed to ``DataLoader`` as ``places``.
            batch_sampler: sampler to use; when ``None`` a
                ``DistributedBatchSampler`` is created from the stored
                ``batch_size``/``shuffle``/``drop_last``.
            return_list (bool): passed through to ``DataLoader``.
            use_prefetch (bool): passed to ``DataLoader`` as
                ``use_buffer_reader``.

        Returns:
            tuple: ``(loader, num_batches)`` where ``num_batches`` is
            ``len()`` of the batch sampler in use.
        """
        self._dataset = dataset
        self._dataset.parse_dataset(self.with_background)

        # Hand the sample pipeline and a private copy of the field names
        # to the dataset, then forward the extra keyword settings.
        out_fields = copy.deepcopy(self._fields)
        self._dataset.set_out(self._sample_transforms, out_fields)
        self._dataset.set_kwargs(**self.kwargs)

        # Use the caller's sampler when supplied, else build the default.
        if batch_sampler is not None:
            self._batch_sampler = batch_sampler
        else:
            self._batch_sampler = DistributedBatchSampler(
                self._dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last)

        loader = DataLoader(
            dataset=self._dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            places=device,
            return_list=return_list,
            use_buffer_reader=use_prefetch,
            use_shared_memory=False)

        num_batches = len(self._batch_sampler)
        return loader, num_batches
149 150


G
Guanghua Yu 已提交
151 152 153 154 155 156 157 158
@register
class TrainReader(BaseDataLoader):
    """``BaseDataLoader`` with training defaults.

    Differs from the base class only in its defaults: ``shuffle=True`` and
    ``drop_last=True``. All arguments are forwarded unchanged to
    ``BaseDataLoader.__init__``.
    """

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(TrainReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background,
            **kwargs)
168

K
Kaipeng Deng 已提交
169

G
Guanghua Yu 已提交
170 171 172 173 174 175 176 177
@register
class EvalReader(BaseDataLoader):
    """``BaseDataLoader`` with evaluation defaults.

    Defaults to ``shuffle=False`` and ``drop_last=True``. All arguments are
    forwarded unchanged to ``BaseDataLoader.__init__``.
    """

    # NOTE(review): drop_last=True discards a trailing partial batch during
    # evaluation - confirm this default is intended.
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True,
                 **kwargs):
        super(EvalReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background,
            **kwargs)
187

188

G
Guanghua Yu 已提交
189 190 191 192 193 194 195 196 197 198 199
@register
class TestReader(BaseDataLoader):
    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
200 201 202 203 204 205
                 with_background=True,
                 **kwargs):
        super(TestReader, self).__init__(inputs_def, sample_transforms,
                                         batch_transforms, batch_size, shuffle,
                                         drop_last, drop_empty, num_classes,
                                         with_background, **kwargs)