import copy
import traceback
import logging
import threading
import sys
if sys.version_info >= (3, 0):
    import queue as Queue
else:
    import Queue
import numpy as np
from paddle.io import DataLoader
from ppdet.core.workspace import register, serializable, create
from .sampler import DistributedBatchSampler
from .transform import operators
from .transform import batch_operators

logger = logging.getLogger(__name__)


class Compose(object):
    """Compose a sequence of transform ops into one callable.

    Each entry of ``transforms`` is a single-key dict mapping an op class
    name (resolved on the ``from_`` namespace) to its constructor kwargs,
    e.g. ``[{'DecodeImage': {'to_rgb': True}}]``.

    Args:
        transforms (list[dict]): op-name -> kwargs specs, applied in order.
        fields (list[str]|None): when given, ``__call__`` receives a batch
            of per-sample tuples; they are zipped with ``fields`` into
            dicts before the ops run and unpacked back afterwards.
        from_ (module): namespace the op classes are looked up on.
        num_classes (int): propagated to each op instance that declares a
            ``num_classes`` attribute.
    """

    def __init__(self, transforms, fields=None, from_=operators,
                 num_classes=81):
        self.transforms = transforms
        self.transforms_cls = []
        for t in self.transforms:
            for k, v in t.items():
                op_cls = getattr(from_, k)
                op = op_cls(**v)
                # Set num_classes on the *instance*, not the class:
                # assigning to op_cls would mutate shared class state and
                # leak into every other Compose using the same op with a
                # different num_classes.
                if hasattr(op_cls, 'num_classes'):
                    op.num_classes = num_classes
                self.transforms_cls.append(op)

        self.fields = fields

    def __call__(self, data):
        """Run every transform op on ``data`` and return the result.

        Raises:
            Exception: re-raises whatever a transform op raised, after
                logging the failing op and its stack trace.
        """
        if self.fields is not None:
            # tuples -> dicts so ops can address values by field name
            data = [dict(zip(self.fields, item)) for item in data]

        for f in self.transforms_cls:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.warning(
                    "fail to map op [{}] with error: {} and stack:\n{}".
                    format(f, e, str(stack_info)))
                # bare raise keeps the original traceback intact
                raise

        if self.fields is not None:
            # dicts -> per-field columns: each entry of data_new is the
            # tuple of one field's values across the whole batch
            data_new = [[item[k] for k in self.fields] for item in data]
            batch_size = len(data_new)
            data_new = list(zip(*data_new))
            if batch_size > 1:
                data = [
                    np.array(item).astype(item[0].dtype) for item in data_new
                ]
            else:
                data = data_new

        return data


class BaseDataLoader(object):
    """Common reader: wires a ppdet dataset, transforms and a distributed
    batch sampler into a ``paddle.io.DataLoader``.

    Args:
        inputs_def (dict|None): reader input definition; its optional
            ``'fields'`` entry names the per-sample outputs.
        sample_transforms (list[dict]|None): per-sample op specs (see Compose).
        batch_transforms (list[dict]|None): per-batch op specs, used as the
            DataLoader ``collate_fn``.
        batch_size (int): samples per batch.
        shuffle (bool): shuffle samples each epoch.
        drop_last (bool): drop the final partial batch.
        drop_empty (bool): kept for config compatibility; stored but not
            consumed here.
        num_classes (int): forwarded to transform ops (shared via __share__).
        with_background (bool): forwarded to ``dataset.parse_dataset``.
    """
    __share__ = ['num_classes']

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True):
        # output fields; .get avoids a KeyError when inputs_def is given
        # without a 'fields' entry
        self._fields = copy.deepcopy(inputs_def.get(
            'fields')) if inputs_def else None
        # per-sample transforms
        self._sample_transforms = Compose(
            sample_transforms, num_classes=num_classes)

        # per-batch transforms (used as the DataLoader collate_fn)
        self._batch_transforms = None
        if batch_transforms:
            self._batch_transforms = Compose(batch_transforms, self._fields,
                                             batch_operators, num_classes)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        # previously accepted but silently discarded; keep it reachable
        # for subclasses/config introspection
        self.drop_empty = drop_empty
        self.with_background = with_background

    def __call__(self,
                 dataset,
                 worker_num,
                 device,
                 return_list=False,
                 use_prefetch=True):
        """Build the DataLoader over ``dataset``.

        Args:
            dataset: ppdet dataset; ``parse_dataset`` and ``set_out`` are
                invoked on it.
            worker_num (int): number of DataLoader workers.
            device: place(s) batches are fed to.
            return_list (bool): DataLoader ``return_list`` flag.
            use_prefetch (bool): enable paddle's buffered reader.

        Returns:
            tuple: ``(loader, steps_per_epoch)`` where steps_per_epoch is
            ``len(batch_sampler)``.
        """
        self._dataset = dataset
        self._dataset.parse_dataset(self.with_background)
        # attach sample transforms and output fields to the dataset
        self._dataset.set_out(self._sample_transforms, self._fields)
        # NOTE(review): DistributedBatchSampler presumably shards batches
        # across trainers — confirm in .sampler
        self._batch_sampler = DistributedBatchSampler(
            self._dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=self.drop_last)

        loader = DataLoader(
            dataset=self._dataset,
            batch_sampler=self._batch_sampler,
            collate_fn=self._batch_transforms,
            num_workers=worker_num,
            places=device,
            return_list=return_list,
            use_buffer_reader=use_prefetch,
            use_shared_memory=False)

        return loader, len(self._batch_sampler)


@register
class TrainReader(BaseDataLoader):
    """Training reader: shuffles samples and drops the last partial batch
    by default; everything else is inherited from BaseDataLoader."""

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=True,
                 drop_last=True,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True):
        # Forward everything by keyword for readability.
        super(TrainReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background)

@register
class EvalReader(BaseDataLoader):
    """Evaluation reader: deterministic order (no shuffle) and keeps the
    last partial batch; everything else is inherited from BaseDataLoader."""

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True):
        # Forward everything by keyword for readability.
        super(EvalReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background)

@register
class TestReader(BaseDataLoader):
    """Inference reader: deterministic order (no shuffle) and keeps the
    last partial batch; everything else is inherited from BaseDataLoader."""

    def __init__(self,
                 inputs_def=None,
                 sample_transforms=None,
                 batch_transforms=None,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 drop_empty=True,
                 num_classes=81,
                 with_background=True):
        # Forward everything by keyword for readability.
        super(TestReader, self).__init__(
            inputs_def=inputs_def,
            sample_transforms=sample_transforms,
            batch_transforms=batch_transforms,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            drop_empty=drop_empty,
            num_classes=num_classes,
            with_background=with_background)