# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
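"""Distributed data loaders for static-graph training.

`DistributedDataLoader` holds the shared sampling and fetching state;
`NonIterableGeneratorLoader` wraps it around a generator-fed
`paddle.fluid.io.DataLoader` so each data parallel rank can be fed its own
shard of every batch.
"""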

import abc
from functools import wraps

import numpy as np

import paddle
from paddle.io import DataLoader, BatchSampler, IterableDataset
from paddle.fluid.layers.utils import flatten
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn

from .utils import to_list


class DistributedDataLoader(metaclass=abc.ABCMeta):
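    """Abstract base class for distributed data loaders.

    Builds a `BatchSampler` for map-style datasets (or an infinite sampler
    for iterable ones) and exposes it through `index_sampler`; subclasses
    must implement `__iter__` and `__next__`.
    """
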
    def __init__(self,
                 dataset,
                 batch_size=1,
                 epochs=1,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False,
                 split_data=True):
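        """
        Args:
            dataset: a map-style or iterable-style dataset.
            batch_size: global batch size, or None if the dataset already
                yields batched data.
            epochs: number of passes over the dataset.
            data_parallel_world_size: per-input list of data parallel world
                sizes (entries may be None).
            data_parallel_rank: per-input list of this process's data
                parallel ranks.
            drop_last: whether to drop the final incomplete batch.
            split_data: whether to split each batch across data parallel
                ranks.
        """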
        if isinstance(dataset, IterableDataset):
            self.dataset_kind = _DatasetKind.ITER
        else:
            self.dataset_kind = _DatasetKind.MAP

        self.dataset = dataset
        self.epochs = epochs
        self.drop_last = drop_last
        self.data_parallel_world_size = data_parallel_world_size
        self.data_parallel_rank = data_parallel_rank
        self.split_data = split_data

        if batch_size is None:
            self.batch_size = None
            self.batch_sampler = None
        else:
            if data_parallel_world_size is not None:
                for dp_world_size in data_parallel_world_size:
                    if dp_world_size is not None:
                        assert batch_size % dp_world_size == 0, \
                            "batch_size [{}] must be divisible by data parallel world size [{}]".format(
                                batch_size, dp_world_size)
            self.batch_size = batch_size
            if isinstance(dataset, IterableDataset):
                self.batch_sampler = _InfiniteIterableSampler(
                    dataset, batch_size)
            else:
                self.batch_sampler = BatchSampler(dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  drop_last=drop_last)

        self.auto_collate_batch = self.batch_sampler is not None
        self.sampler_iter = iter(self.index_sampler)

    @abc.abstractmethod
    def __iter__(self):
        raise NotImplementedError

    @abc.abstractmethod
    def __next__(self):
        raise NotImplementedError

    @property
    def index_sampler(self):
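        """The object that yields indices for the dataset fetcher: the
        batch sampler when batching is enabled, otherwise a plain index
        range for map-style datasets or an infinite sampler for iterable
        ones.
        """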
        if self.auto_collate_batch:
            return self.batch_sampler
        else:
            if self.dataset_kind == _DatasetKind.MAP:
                return list(range(len(self.dataset)))
            else:
                return _InfiniteIterableSampler(self.dataset, 1)


class NonIterableGeneratorLoader(DistributedDataLoader):
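    """Data loader backed by a non-iterable `DataLoader.from_generator`.

    Every fetched batch is truncated to the variables in `feed_list` and,
    when `split_data` is set, split along the batch axis so that each data
    parallel rank receives its own shard.
    """
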
    def __init__(self,
                 dataset,
                 feed_list,
                 places,
                 batch_size=1,
                 epochs=1,
                 steps_per_epoch=None,
                 collate_fn=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False,
                 split_data=True):
        self.feed_list = feed_list
        self.places = places
        self.steps_per_epoch = steps_per_epoch

        super(NonIterableGeneratorLoader,
              self).__init__(dataset, batch_size, epochs,
                             data_parallel_world_size, data_parallel_rank,
                             drop_last, split_data)

        if self.auto_collate_batch:
            self.collate_fn = collate_fn or default_collate_fn
        else:
            self.collate_fn = collate_fn or default_convert_fn
        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collate_batch,
            self.collate_fn, self.drop_last)

        self._steps = self._infer_steps()
        self._inner_dataloader = self._create_inner_dataloader()

    def __iter__(self):
        self._cur_step = 0
        self._inner_dataloader.start()
        return self

    def __next__(self):
        if not self._steps:
            self._cur_step += 1
        elif self._cur_step < self._steps:
            self._cur_step += 1
        else:
            self._inner_dataloader.reset()
            self.sampler_iter = iter(self.index_sampler)
            raise StopIteration

    def _infer_steps(self):
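        """Resolve the number of steps per epoch: an explicit
        `steps_per_epoch` takes precedence; otherwise the value is derived
        from the dataset length (iterable datasets run unbounded).
        """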
        if self.steps_per_epoch is not None:
            return self.steps_per_epoch
        try:
            if isinstance(self.dataset, IterableDataset):
                steps_per_epoch = None
            elif self.batch_size is None:
                steps_per_epoch = len(self.dataset)
            else:
                steps_per_epoch = len(self.dataset) // self.batch_size
        except Exception:
            raise ValueError(
                "Please set `steps_per_epoch` or implement `__len__` method in dataset class."
            )
        return steps_per_epoch

    def _create_inner_dataloader(self):
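        # `sample_data_generator` assembles batches from sampled indices and
        # re-creates the fetcher when an epoch ends; `batch_data_generator`
        # is used when `batch_size` is None and makes a single pass over
        # data that is already batched.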
        def sample_data_generator():
            while True:
                try:
                    indices = next(self.sampler_iter)
                    batch = self.dataset_fetcher.fetch(indices)
                    if batch is None:
                        break
                except StopIteration:
                    # Re-create the fetcher so the next epoch starts from a
                    # fresh dataset pass.
                    self.dataset_fetcher = _DatasetKind.create_fetcher(
                        self.dataset_kind, self.dataset,
                        self.auto_collate_batch, self.collate_fn,
                        self.drop_last)
                    break

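                # Keep only the inputs named in `feed_list`; when splitting
                # is enabled, shard each array along the batch axis and
                # keep this rank's piece.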
                partial_data = []
                for i, d in enumerate(batch[:len(self.feed_list)]):
                    array = np.array(d)
                    if not self.split_data:
                        partial_data.append(array)
                    elif self.dp_world_sizes[i] is not None:
                        partial_data.append(
                            np.split(array,
                                     self.dp_world_sizes[i])[self.dp_ranks[i]])
                    else:
                        partial_data.append(array)
                yield partial_data

        def batch_data_generator():
            while True:
                try:
                    indices = next(self.sampler_iter)
                    batch = self.dataset_fetcher.fetch(indices)
                    if batch is None:
                        break
                except StopIteration:
                    break

                partial_data = []
                for i, d in enumerate(batch[:len(self.feed_list)]):
                    array = np.array(d)
                    if not self.split_data:
                        partial_data.append(array)
                    elif self.dp_world_sizes[i] is not None:
                        partial_data.append(
                            np.split(array,
                                     self.dp_world_sizes[i])[self.dp_ranks[i]])
                    else:
                        partial_data.append(array)
                yield partial_data

        self.dp_world_sizes = [
            1 for _ in range(len(self.feed_list))
        ] if self.data_parallel_world_size is None else self.data_parallel_world_size
        self.dp_ranks = [
            0 for _ in range(len(self.feed_list))
        ] if self.data_parallel_rank is None else self.data_parallel_rank

        dataloader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=self.feed_list, capacity=70, iterable=False)
        if self.batch_size is not None:
            dataloader.set_batch_generator(sample_data_generator, self.places)
        else:
            dataloader.set_batch_generator(batch_data_generator, self.places)

        return dataloader
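
# A minimal usage sketch (illustrative only, not part of this module). It
# assumes a static-graph program with feed variables `image` and `label`,
# a map-style dataset yielding (image, label) numpy pairs, and a started
# executor; the loop drives the generator-fed inner DataLoader through
# `exe.run`:
#
#     loader = NonIterableGeneratorLoader(
#         dataset, feed_list=[image, label],
#         places=paddle.static.cuda_places(), batch_size=4, epochs=1)
#     exe = paddle.static.Executor()
#     exe.run(paddle.static.default_startup_program())
#     for _ in loader:  # one iteration per training step
#         exe.run(paddle.static.default_main_program())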