# dist_loader.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
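
"""Distributed data loading for static-graph training.

Defines an abstract `DistributedDataLoader` plus `NonIterableGeneratorLoader`,
which feeds each data-parallel rank its shard of every batch through a
generator-based `DataLoader`.
"""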

import abc

import numpy as np

import paddle
from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn


class DistributedDataLoader(metaclass=abc.ABCMeta):
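    """Abstract base class for distributed data loaders.

    Resolves the dataset kind (map-style vs. iterable), builds a batch
    sampler when `batch_size` is given, and leaves `__iter__` and
    `__next__` to concrete subclasses.
    """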

    def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False):
        # Pick the fetcher kind from the dataset style.
        if isinstance(dataset, IterableDataset):
            self.dataset_kind = _DatasetKind.ITER
        else:
            self.dataset_kind = _DatasetKind.MAP

        self.dataset = dataset
        self.epochs = epochs
        self.drop_last = drop_last

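        # `batch_size=None` disables batching; otherwise choose a batch
        # sampler that matches the dataset kind.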
        if batch_size is None:
            self.batch_size = None
            self.batch_sampler = None
        else:
            self.batch_size = batch_size
            if isinstance(dataset, IterableDataset):
                self.batch_sampler = _InfiniteIterableSampler(
                    dataset, batch_size)
            else:
                self.batch_sampler = BatchSampler(dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  drop_last=drop_last)

        # Batches are auto-collated iff a batch sampler exists.
        self.auto_collate_batch = self.batch_sampler is not None
        self.sampler_iter = iter(self.index_sampler)

    @abc.abstractmethod
    def __iter__(self):
        raise NotImplementedError

    @abc.abstractmethod
    def __next__(self):
        raise NotImplementedError

    @property
    def index_sampler(self):
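        """Return the object whose iterator yields the indices to fetch."""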
        if self.auto_collate_batch:
            return self.batch_sampler
        else:
            if self.dataset_kind == _DatasetKind.MAP:
                return list(range(len(self.dataset)))
            else:
                return _InfiniteIterableSampler(self.dataset, 1)


class NonIterableGeneratorLoader(DistributedDataLoader):
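    """Generator-fed loader that shards each batch across data-parallel ranks.

    Unless `split_data` is False, field `i` of every batch is split into
    `data_parallel_world_size[i]` parts and only the `data_parallel_rank[i]`-th
    part is fed to this process.
    """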

    def __init__(self,
                 dataset,
                 feed_list,
                 places,
                 batch_size=1,
                 epochs=1,
                 steps_per_epoch=None,
                 collate_fn=None,
                 data_parallel_world_size=[],
                 data_parallel_rank=[],
                 drop_last=False,
                 split_data=True):
        self.feed_list = feed_list
        self.places = places
        self.steps_per_epoch = steps_per_epoch

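        # One data-parallel world size and rank per feed variable.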
        assert len(data_parallel_world_size) == len(feed_list)
        assert len(data_parallel_rank) == len(feed_list)
        self.dp_world_sizes = data_parallel_world_size
        self.dp_ranks = data_parallel_rank
        self.split_data = split_data

        super(NonIterableGeneratorLoader,
              self).__init__(dataset, batch_size, epochs, drop_last)

        if self.auto_collate_batch:
            self.collate_fn = collate_fn or default_collate_fn
        else:
            self.collate_fn = collate_fn or default_convert_fn
        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collate_batch,
            self.collate_fn, self.drop_last)

        self._steps = self._infer_steps()
        self._inner_dataloader = self._create_inner_dataloader()

    def __iter__(self):
        self._cur_step = 0
        self._inner_dataloader.start()
        return self

    def __next__(self):
        # Only step counting happens here; the inner DataLoader pushes the
        # actual data to the executor. Reset everything at the epoch end.
        if not self._steps:
            self._cur_step += 1
        elif self._cur_step < self._steps:
            self._cur_step += 1
        else:
            self._inner_dataloader.reset()
            self.sampler_iter = iter(self.index_sampler)
            raise StopIteration

    def _infer_steps(self):
        if self.steps_per_epoch is not None:
            return self.steps_per_epoch
        try:
            if isinstance(self.dataset, IterableDataset):
                steps_per_epoch = None
            elif self.batch_size is None:
                steps_per_epoch = len(self.dataset)
            else:
                steps_per_epoch = len(self.dataset) // self.batch_size
        except TypeError:
            raise ValueError(
                "Please set `steps_per_epoch` or implement the `__len__` method in the dataset class."
            )
        return steps_per_epoch

    def _create_inner_dataloader(self):

        def data_generator():
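            # Pull indices from the sampler, fetch/collate a batch, and shard
            # it across the data-parallel ranks before yielding.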
            while True:
                try:
                    indices = next(self.sampler_iter)
                    batch = self.dataset_fetcher.fetch(indices)
                    if batch is None:
                        break
                except StopIteration:
                    # Epoch exhausted: rebuild the fetcher for the next epoch.
                    self.dataset_fetcher = _DatasetKind.create_fetcher(
                        self.dataset_kind, self.dataset,
                        self.auto_collate_batch, self.collate_fn,
                        self.drop_last)
                    break

                partial_data = []
                for i, d in enumerate(batch):
                    array = np.array(d)
                    if not self.split_data:
                        partial_data.append(array)
                        continue

                    # Split the batch evenly over the data-parallel group and
                    # keep only this rank's shard.
                    batch_size = array.shape[0]
                    assert batch_size % self.dp_world_sizes[i] == 0, \
                        "batch_size [{}] is not divisible by dp_world_size [{}]".format(
                            str(batch_size), str(self.dp_world_sizes[i]))
                    partial_data.append(
                        np.split(array,
                                 self.dp_world_sizes[i])[self.dp_ranks[i]])

                yield partial_data

        dataloader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=self.feed_list, capacity=70, iterable=False)
        dataloader.set_batch_generator(data_generator, self.places)

        return dataloader
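

# A minimal usage sketch, assuming an already-built static program with feed
# variables `input_var`/`label_var` and an executor `exe` (these names are
# illustrative, not part of this module):
#
#     loader = NonIterableGeneratorLoader(
#         dataset,
#         feed_list=[input_var, label_var],
#         places=paddle.static.cuda_places(),
#         batch_size=8,
#         data_parallel_world_size=[2, 2],
#         data_parallel_rank=[0, 0])
#     for _ in loader:
#         exe.run(main_program)  # data is fed by the inner DataLoader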