# datacargo.py (2.4 KB)
1
from .sampler import SequentialSampler, RandomSampler, BatchSampler
2

3 4 5
class DataCargo(object):
    """Iterable that serves batches of examples drawn from a dataset.

    Batches of indices come either from a caller-supplied ``batch_sampler``
    or from a ``BatchSampler`` built around a (given or default) sampler.
    Iterating over the cargo yields whatever ``DataIterator`` produces.

    Args:
        dataset: the dataset to draw examples from; stored as-is.
        batch_size: forwarded to ``BatchSampler`` when no ``batch_sampler``
            is given. Must not be None in that case.
        sampler: optional index sampler. When omitted, a ``RandomSampler``
            (``shuffle=True``) or ``SequentialSampler`` over ``dataset`` is
            used instead.
        shuffle: only consulted when ``sampler`` is None and no
            ``batch_sampler`` is given.
        batch_sampler: optional iterable of index batches. Mutually
            exclusive with non-default ``batch_size``, ``shuffle``,
            ``sampler`` and ``drop_last``.
        drop_last: forwarded to ``BatchSampler`` when no ``batch_sampler``
            is given.

    Raises:
        ValueError: if ``batch_sampler`` is combined with any of the
            mutually exclusive options, or if both ``batch_sampler`` and
            ``batch_size`` are None.
    """

    def __init__(self, dataset, batch_size=1, sampler=None,
                 shuffle=False, batch_sampler=None, drop_last=False):
        self.dataset = dataset

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            # The custom batch sampler alone decides the batching, so the
            # batching-related attributes are reset to neutral values.
            batch_size = None
            drop_last = False
        else:
            if batch_size is None:
                raise ValueError('batch sampler is none. then batch size must not be none.')
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            # auto_collation without custom batch_sampler. This also covers
            # a caller-supplied sampler; previously that case left
            # batch_sampler as None, so iteration received single indices
            # instead of index batches and failed.
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        """Return a fresh ``DataIterator`` over this cargo."""
        return DataIterator(self)

    @property
    def _auto_collation(self):
        """True when a batch sampler is available to drive batching."""
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        """The sampler that iteration draws from.

        Returns the batch sampler when one is set, otherwise the plain
        sampler.
        """
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self):
        """Return the length of the sampler that drives iteration."""
        return len(self._index_sampler)
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
    
class DataIterator(object):
    """Iterator that turns index batches from a loader into data batches.

    Args:
        loader: object exposing ``dataset`` and ``_index_sampler``. The
            dataset must support indexing and provide ``_batch_examples``;
            the index sampler must be iterable and yield iterables of
            indices.
    """

    def __init__(self, loader):
        self.loader = loader
        self._dataset = loader.dataset

        self._index_sampler = loader._index_sampler
        # A single pass over the index sampler; exhausting it ends iteration.
        self._sampler_iter = iter(self._index_sampler)

    def __iter__(self):
        """Return self, as required by the iterator protocol."""
        return self

    def __next__(self):
        """Fetch the next batch of indices and assemble the examples.

        Raises:
            StopIteration: propagated from the index sampler once it is
                exhausted.
        """
        # may raise StopIteration
        # TODO(chenfeiyu): use dynamic batch size
        index = self._next_index()
        # This lookup could be abstracted too, to support dynamic batch size.
        minibatch = [self._dataset[i] for i in index]
        # list[Example] -> Batch
        minibatch = self._dataset._batch_examples(minibatch)
        return minibatch

    def _next_index(self):
        """Return the next item produced by the index sampler."""
        return next(self._sampler_iter)

    def __len__(self):
        """Return the length of the underlying index sampler."""
        return len(self._index_sampler)