reader.py 6.3 KB
Newer Older
1 2
import copy
import traceback
3
import logging
G
Guanghua Yu 已提交
4 5 6 7 8 9 10 11 12 13
import threading
import sys
if sys.version_info >= (3, 0):
    import queue as Queue
else:
    import Queue
import numpy as np
from paddle.io import DataLoader
from ppdet.core.workspace import register, serializable, create
from .sampler import DistributedBatchSampler
14 15
from . import transform
from .transform import operator, batch_operator
16 17 18 19

logger = logging.getLogger(__name__)


20
class Compose(object):
    """Chain configured data-transform operators into one callable.

    Args:
        transforms (list[dict] | None): operator configs, each a dict mapping
            an operator class name (looked up on ``from_``) to the keyword
            arguments used to instantiate it. ``None`` is treated as an
            empty list.
        fields (list[str] | None): when given, every input sample is a
            sequence that is zipped with these names into a dict before the
            operators run; also the default output field order.
        from_ (module): namespace the operator classes are resolved from.
        num_classes (int): written onto every operator instance whose class
            defines a ``num_classes`` attribute.
    """

    def __init__(self, transforms, fields=None, from_=transform,
                 num_classes=81):
        self.transforms = transforms if transforms is not None else []
        self.transforms_cls = []
        output_fields = None
        for t in self.transforms:
            for k, v in t.items():
                op_cls = getattr(from_, k)
                op = op_cls(**v)
                # Configure this instance only. Assigning to the class would
                # mutate shared state: every other Compose built from the
                # same operator class would silently pick up whichever
                # num_classes was configured last.
                if hasattr(op_cls, 'num_classes'):
                    op.num_classes = num_classes
                self.transforms_cls.append(op)

                # YOLO target generation adds one 'target{i}' output per
                # anchor mask, so the emitted fields are fixed here instead
                # of being taken from `fields`.
                if op_cls in [
                        transform.Gt2YoloTargetOp, transform.Gt2YoloTarget
                ]:
                    output_fields = ['image', 'gt_bbox']
                    output_fields.extend([
                        'target{}'.format(i)
                        for i in range(len(v['anchor_masks']))
                    ])

        self.fields = fields
        # Field order used to flatten each processed sample in __call__.
        self.output_fields = output_fields if output_fields else fields

    def __call__(self, data):
        """Run every operator over ``data`` and optionally flatten the result.

        Args:
            data: a list of samples (each element is treated as one sample).
                When ``self.fields`` is set, each sample is a sequence that
                is first converted to a dict keyed by those field names.

        Returns:
            The operators' output unchanged when ``self.output_fields`` is
            None. Otherwise a list with one entry per output field: a numpy
            array stacking that field across samples (cast to the first
            sample's dtype) when there is more than one sample, a 1-tuple
            when there is exactly one, or an empty list for an empty batch.

        Raises:
            Whatever exception an operator raises, after logging it.
        """
        if self.fields is not None:
            data = [dict(zip(self.fields, item)) for item in data]

        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning(
                    "fail to map op [{}] with error: {} and stack:\n{}".format(
                        f, e, str(stack_info)))
                # Bare raise re-raises with the original traceback intact.
                raise

        if self.output_fields is not None:
            # One row per sample, ordered by output_fields ...
            rows = [[item[k] for k in self.output_fields] for item in data]
            batch_size = len(rows)
            # ... transposed into one column per field.
            columns = list(zip(*rows))
            if batch_size > 1:
                data = [
                    np.array(col).astype(col[0].dtype) for col in columns
                ]
            else:
                data = columns

        return data
78 79


G
Guanghua Yu 已提交
80 81
class BaseDataLoader(object):
    """Build a ``paddle.io.DataLoader`` from configured sample/batch transforms.

    Subclasses only change the constructor defaults.
    """
    # NOTE(review): ppdet's config system conventionally reads `__shared__`;
    # confirm `__share__` is actually consumed, otherwise num_classes is not
    # injected from the global config.
    __share__ = ['num_classes']

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True):
        """Store loader settings and build the transform pipelines.

        Args:
            inputs_def (dict | None): when truthy, its ``'fields'`` entry
                names the per-sample fields.
            sample_transforms (list[dict] | None): per-sample operator
                configs; ``None`` means no per-sample transforms.
            batch_transforms (list[dict] | None): batch operator configs;
                when empty/None, no collate function is installed.
            batch_size (int): batch size for the default batch sampler.
            shuffle (bool): shuffle flag for the default batch sampler.
            drop_last (bool): drop_last flag for the default batch sampler.
            drop_empty (bool): stored as ``self.drop_empty``; nothing in this
                class reads it.
            num_classes (int): forwarded to both transform pipelines.
            with_background (bool): passed to ``dataset.parse_dataset``.
        """
        # Names of the fields in each sample; None when no inputs_def given.
        self._fields = inputs_def['fields'] if inputs_def else None

        # Per-sample transforms. Compose iterates its transform list, so the
        # `None` default must become an empty pipeline instead of crashing.
        self._sample_transforms = Compose(
            sample_transforms or [], num_classes=num_classes)

        # Batch transforms, installed as the DataLoader's collate_fn.
        self._batch_transforms = None
        if batch_transforms:
            self._batch_transforms = Compose(batch_transforms,
                                             copy.deepcopy(self._fields),
                                             transform, num_classes)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.drop_empty = drop_empty
        self.with_background = with_background

    def __call__(self,
                 dataset,
                 worker_num,
                 device=None,
                 batch_sampler=None,
                 return_list=False,
                 use_prefetch=True):
        """Wire ``dataset`` into a DataLoader.

        Args:
            dataset: dataset object; must provide ``parse_dataset`` and
                ``set_out`` (both are called here).
            worker_num (int): forwarded as ``num_workers``.
            device: forwarded as ``places``.
            batch_sampler: custom batch sampler; when None, a
                DistributedBatchSampler is built from batch_size, shuffle
                and drop_last.
            return_list (bool): forwarded to the DataLoader.
            use_prefetch (bool): forwarded as ``use_buffer_reader``.

        Returns:
            tuple: ``(loader, len(batch_sampler))``.
        """
        self._dataset = dataset
        self._dataset.parse_dataset(self.with_background)
        # Attach the per-sample pipeline and field names to the dataset.
        self._dataset.set_out(self._sample_transforms,
                              copy.deepcopy(self._fields))

        if batch_sampler is None:
            self._batch_sampler = DistributedBatchSampler(
                self._dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last)
        else:
            self._batch_sampler = batch_sampler

        loader = DataLoader(
            dataset=self._dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            places=device,
            return_list=return_list,
            use_buffer_reader=use_prefetch,
            use_shared_memory=False)

        return loader, len(self._batch_sampler)
144 145


G
Guanghua Yu 已提交
146 147 148 149 150 151 152 153
@register
class TrainReader(BaseDataLoader):
    """Data loader configured for training.

    Identical to BaseDataLoader except for its defaults:
    ``shuffle=True`` and ``drop_last=True``.
    """

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True):
        # Forward by keyword so the call cannot silently drift out of sync
        # with the base-class parameter order.
        super(TrainReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background)
161

K
Kaipeng Deng 已提交
162

G
Guanghua Yu 已提交
163 164 165 166 167 168 169 170
@register
class EvalReader(BaseDataLoader):
    """Data loader configured for evaluation.

    Identical to BaseDataLoader except for its defaults:
    ``shuffle=False`` and ``drop_last=True``.

    NOTE(review): with ``drop_last=True`` the default batch sampler
    presumably skips the final partial batch, so samples may be left out of
    evaluation when the dataset size is not a multiple of ``batch_size``
    — confirm this is intended.
    """

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True):
        # Forward by keyword so the call cannot silently drift out of sync
        # with the base-class parameter order.
        super(EvalReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background)
178

179

G
Guanghua Yu 已提交
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
@register
class TestReader(BaseDataLoader):
    """Data loader configured for inference.

    Uses the same defaults as BaseDataLoader:
    ``shuffle=False`` and ``drop_last=False``.
    """

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True):
        # Forward by keyword so the call cannot silently drift out of sync
        # with the base-class parameter order.
        super(TestReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background)