# datacargo.py (3.5 KB) -- last committed by lifuchen.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import six
16
from .sampler import SequentialSampler, RandomSampler, BatchSampler
17

18

19
class DataCargo(object):
    """Iterable that serves batches of examples taken from a dataset.

    Every iteration creates a ``DataIterator`` which pulls a group of indices
    from the batch sampler, looks the examples up in ``dataset`` and merges
    them with ``batch_fn``.

    Args:
        dataset: collection of examples, indexed with the indices the
            sampler yields. When ``batch_fn`` is omitted it must expose a
            ``_batch_examples`` attribute.
        batch_fn: callable turning a list of examples into a batch. Defaults
            to ``dataset._batch_examples``.
        batch_size: passed to ``BatchSampler``. Must not be None unless
            ``batch_sampler`` is given.
        sampler: index sampler wrapped by ``BatchSampler``. When None, a
            ``RandomSampler(dataset)`` (``shuffle`` true) or a
            ``SequentialSampler(dataset)`` (``shuffle`` false) is created.
        shuffle: selects the default sampler, see ``sampler``.
        batch_sampler: custom sampler yielding one group of indices per
            batch. Mutually exclusive with ``batch_size``, ``shuffle``,
            ``sampler`` and ``drop_last``.
        drop_last: passed to ``BatchSampler`` (presumably drops a trailing
            incomplete batch -- confirm against ``BatchSampler``).

    Raises:
        ValueError: if ``batch_sampler`` is combined with a non-default
            ``batch_size``, ``shuffle``, ``sampler`` or ``drop_last``, or if
            both ``batch_sampler`` and ``batch_size`` are None.
    """

    def __init__(self,
                 dataset,
                 batch_fn=None,
                 batch_size=1,
                 sampler=None,
                 shuffle=False,
                 batch_sampler=None,
                 drop_last=False):
        self.dataset = dataset

        # Compare against None explicitly: a user supplied callable may be
        # falsy (e.g. an object defining __len__) and must still be honoured.
        if batch_fn is None:
            batch_fn = dataset._batch_examples
        self.batch_fn = batch_fn

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            # The custom batch sampler owns batching, so the stored values
            # are normalised to "not applicable".
            batch_size = None
            drop_last = False
        elif batch_size is None:
            raise ValueError(
                'batch sampler is none. then batch size must not be none.')
        else:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        """Return a fresh ``DataIterator`` over this cargo."""
        return DataIterator(self)

    def __call__(self):
        """Return a fresh ``DataIterator``; same as ``iter(self)``."""
        return DataIterator(self)

    @property
    def _auto_collation(self):
        """Whether examples are grouped into batches by a batch sampler."""
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        """The sampler iterated over: batch sampler if set, else ``sampler``."""
        if self._auto_collation:
            return self.batch_sampler
        return self.sampler

    def __len__(self):
        """Return the length of the active index sampler."""
        return len(self._index_sampler)
78 79


80 81 82 83
class DataIterator(object):
    """Iterator that turns the index groups of a loader into batches.

    Args:
        loader: object providing ``dataset``, ``batch_fn`` and
            ``_index_sampler`` attributes (e.g. a ``DataCargo``).
    """

    def __init__(self, loader):
        self.loader = loader
        self._dataset = loader.dataset
        self._batch_fn = loader.batch_fn
        self._index_sampler = loader._index_sampler
        # One pass over the sampler; create a new DataIterator to restart.
        self._sampler_iter = iter(self._index_sampler)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch.

        Raises:
            StopIteration: when the index sampler is exhausted.
        """
        # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size
        indices = self._next_index()
        # we can abstract it, too to use dynamic batch size
        examples = [self._dataset[i] for i in indices]
        return self._batch_fn(examples)  # list[Example] -> Batch

    next = __next__  # Python 2 compatibility

    def _next_index(self):
        """Return the next group of indices from the sampler iterator."""
        # The builtin next() works on both Python 2.6+ and Python 3, so no
        # version switch is needed. Inside a method the name resolves to the
        # builtin, not to the class-level ``next`` alias above.
        return next(self._sampler_iter)

    def __len__(self):
        """Return the length of the underlying index sampler."""
        return len(self._index_sampler)