dist_loader.py 6.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import abc
import numpy as np
17 18
from functools import wraps

19
import paddle
20
from .utils import to_list
21
from paddle.fluid.layers.utils import flatten
22
from paddle.io import DataLoader, BatchSampler, IterableDataset
23
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler
24
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn
25 26 27


class DistributedDataLoader(metaclass=abc.ABCMeta):
    """Abstract base class for the distributed data loaders in this module.

    On construction it records whether ``dataset`` is map-style or iterable,
    builds a batch sampler (unless ``batch_size`` is ``None``) and opens
    ``sampler_iter`` over :attr:`index_sampler`. Subclasses must implement
    ``__iter__`` and ``__next__``.

    Args:
        dataset: A map-style dataset or a ``paddle.io.IterableDataset``.
        batch_size (int|None): Samples per batch. ``None`` disables automatic
            batching; the index sampler then yields single samples.
        epochs (int): Number of epochs; only stored on the instance here.
        drop_last (bool): Forwarded to ``BatchSampler`` for map-style
            datasets.
    """

    def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False):
        iterable = isinstance(dataset, IterableDataset)
        self.dataset_kind = _DatasetKind.ITER if iterable else _DatasetKind.MAP

        self.dataset = dataset
        self.epochs = epochs
        self.drop_last = drop_last
        self.batch_size = batch_size

        if batch_size is None:
            self.batch_sampler = None
        elif iterable:
            self.batch_sampler = _InfiniteIterableSampler(dataset, batch_size)
        else:
            self.batch_sampler = BatchSampler(dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              drop_last=drop_last)

        # Automatic collation is enabled exactly when a batch sampler exists.
        self.auto_collate_batch = self.batch_sampler is not None
        self.sampler_iter = iter(self.index_sampler)

    @abc.abstractmethod
    def __iter__(self):
        """Start an epoch and return the iterator (implemented by subclasses)."""
        raise NotImplementedError

    @abc.abstractmethod
    def __next__(self):
        """Advance one step (implemented by subclasses)."""
        raise NotImplementedError

    @property
    def index_sampler(self):
        """Sampler that ``sampler_iter`` is drawn from.

        Returns the batch sampler when automatic batching is on. Otherwise it
        returns every index of a map-style dataset, or a size-1 infinite
        sampler for an iterable dataset.
        """
        if self.auto_collate_batch:
            return self.batch_sampler
        if self.dataset_kind != _DatasetKind.MAP:
            return _InfiniteIterableSampler(self.dataset, 1)
        return list(range(len(self.dataset)))

class NonIterableGeneratorLoader(DistributedDataLoader):
    """Loader that feeds batches through a non-iterable generator DataLoader.

    Batches come from the dataset fetcher. Each field is converted to a numpy
    array and, when ``split_data`` is true, split along its first axis into
    ``dp_world_size`` equal chunks; this rank keeps chunk ``dp_rank``. The
    resulting lists are handed to
    ``paddle.fluid.io.DataLoader.from_generator(..., iterable=False)``.

    Args:
        dataset: A map-style dataset or a ``paddle.io.IterableDataset``.
        feed_list: Variables fed by the inner DataLoader. Presumably one per
            field of a batch (fields are indexed into the per-feed dp lists
            below) -- confirm against callers.
        places: Places passed to ``set_batch_generator``.
        batch_size (int|None): Samples per batch; ``None`` disables batching.
        epochs (int): Stored by the base class.
        steps_per_epoch (int|None): Steps per epoch. When ``None`` it is
            inferred from ``len(dataset)``, and is unlimited for iterable
            datasets.
        collate_fn (callable|None): Custom collate (batched) or convert
            (unbatched) function; defaults to Paddle's.
        data_parallel_world_size (list[int]|None): Data-parallel degree for
            each entry of ``feed_list``. ``None`` means an empty list.
        data_parallel_rank (list[int]|None): Data-parallel rank for each
            entry of ``feed_list``. ``None`` means an empty list.
        drop_last (bool): Drop the trailing incomplete batch (map-style).
        split_data (bool): Whether to split each field across data-parallel
            ranks.
    """

    def __init__(self,
                 dataset,
                 feed_list,
                 places,
                 batch_size=1,
                 epochs=1,
                 steps_per_epoch=None,
                 collate_fn=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False,
                 split_data=True):
        # ``None`` replaces the former shared ``[]`` defaults (same meaning).
        if data_parallel_world_size is None:
            data_parallel_world_size = []
        if data_parallel_rank is None:
            data_parallel_rank = []

        self.feed_list = feed_list
        self.places = places
        self.steps_per_epoch = steps_per_epoch

        assert len(data_parallel_world_size) == len(feed_list), \
            "data_parallel_world_size needs one entry per feed variable"
        assert len(data_parallel_rank) == len(feed_list), \
            "data_parallel_rank needs one entry per feed variable"
        self.dp_world_sizes = data_parallel_world_size
        self.dp_ranks = data_parallel_rank
        self.split_data = split_data

        super().__init__(dataset, batch_size, epochs, drop_last)

        if self.auto_collate_batch:
            self.collate_fn = collate_fn or default_collate_fn
        else:
            self.collate_fn = collate_fn or default_convert_fn
        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collate_batch,
            self.collate_fn, self.drop_last)

        self._steps = self._infer_steps()
        self._inner_dataloader = self._create_inner_dataloader()

    def __iter__(self):
        """Reset the step counter, start the inner DataLoader, return self."""
        self._cur_step = 0
        self._inner_dataloader.start()
        return self

    def __next__(self):
        """Advance one step; raise ``StopIteration`` at the end of the epoch.

        Returns ``None``. Data is delivered through the non-iterable inner
        DataLoader, not returned here. At the end of the epoch, the inner
        DataLoader is reset and ``sampler_iter`` is re-created before
        ``StopIteration`` is raised.
        """
        if self._steps is None:
            # No step limit (iterable dataset without ``steps_per_epoch``).
            # ``is None`` rather than falsiness: 0 steps must stop at once,
            # not loop forever.
            self._cur_step += 1
        elif self._cur_step < self._steps:
            self._cur_step += 1
        else:
            self._inner_dataloader.reset()
            self.sampler_iter = iter(self.index_sampler)
            raise StopIteration

    def _infer_steps(self):
        """Return the number of steps per epoch, or ``None`` for unlimited.

        Raises:
            ValueError: if ``steps_per_epoch`` was not given and the
                map-style dataset does not support ``len()``.
        """
        if self.steps_per_epoch is not None:
            return self.steps_per_epoch
        if isinstance(self.dataset, IterableDataset):
            return None
        try:
            num_samples = len(self.dataset)
        except (TypeError, NotImplementedError) as err:
            # TypeError: no ``__len__``; NotImplementedError: the base
            # ``paddle.io.Dataset.__len__`` was not overridden.
            raise ValueError(
                "Please set `steps_per_epoch` or implement `__len__` method in dataset class."
            ) from err
        if self.batch_size is None:
            return num_samples
        # Floor division: a trailing partial batch is not counted as a step.
        return num_samples // self.batch_size

    def _create_inner_dataloader(self):
        """Build the non-iterable generator DataLoader that feeds ``feed_list``."""

        def data_generator():
            # Yields one list of numpy arrays (one per batch field) per step.
            while True:
                try:
                    indices = next(self.sampler_iter)
                    batch = self.dataset_fetcher.fetch(indices)
                    if batch is None:
                        break
                except StopIteration:
                    # Sampler or fetcher exhausted: re-create the fetcher
                    # before ending this generator run.
                    self.dataset_fetcher = _DatasetKind.create_fetcher(
                        self.dataset_kind, self.dataset,
                        self.auto_collate_batch, self.collate_fn,
                        self.drop_last)
                    break

                partial_data = []
                for i, field in enumerate(batch):
                    array = np.array(field)
                    if not self.split_data:
                        partial_data.append(array)
                        continue

                    # Keep only this rank's equal share of the first axis.
                    batch_size = array.shape[0]
                    assert batch_size % self.dp_world_sizes[i] == 0, \
                        "batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i]))
                    partial_data.append(
                        np.split(array,
                                 self.dp_world_sizes[i])[self.dp_ranks[i]])

                yield partial_data

        dataloader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=self.feed_list, capacity=70, iterable=False)
        dataloader.set_batch_generator(data_generator, self.places)

        return dataloader