# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import numpy as np

import paddle
from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn


class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
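    """Abstract base class for the distributed data loaders in this module.

    Subclasses must implement ``__iter__`` and ``__next__``.
    """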

    @abc.abstractmethod
    def __iter__(self):
        raise NotImplementedError

    @abc.abstractmethod
    def __next__(self):
        raise NotImplementedError

    @property
    def index_sampler(self):
        if self.auto_collate_batch:
            return self.batch_sampler
        else:
            if self.dataset_kind == _DatasetKind.MAP:
                return list(range(len(self.dataset)))
            else:
                return _InfiniteIterableSampler(self.dataset, 1)


class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
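    """Distributed data loader backed by ``DataLoader.from_generator``.

    Batches are pulled from a dataset fetcher and, when ``split_data`` is
    True, every feed variable is split along the batch dimension so that each
    data-parallel rank receives only its own shard.
    """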

    def __init__(self,
                 dataset,
                 feed_list=None,
                 capacity=None,
                 use_double_buffer=True,
                 iterable=True,
                 return_list=False,
                 use_multiprocess=False,
                 drop_last=True,
                 places=None,
                 batch_size=1,
                 epochs=1,
                 steps_per_epoch=None,
                 collate_fn=None,
                 split_data=True,
                 data_parallel_world_size=[],
                 data_parallel_rank=[]):
        self.dataset = dataset
        self.feed_list = feed_list
        self.capacity = capacity
        self.use_double_buffer = use_double_buffer
        self.iterable = iterable
        self.return_list = return_list
        self.use_multiprocess = use_multiprocess
        self.drop_last = drop_last
        self.places = places
        self.batch_size = batch_size
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.collate_fn = collate_fn
        self.split_data = split_data
        assert len(data_parallel_world_size) == len(feed_list)
        assert len(data_parallel_rank) == len(feed_list)
        self.dp_world_sizes = data_parallel_world_size
        self.dp_ranks = data_parallel_rank

        if isinstance(dataset, IterableDataset):
            self.dataset_kind = _DatasetKind.ITER
        else:
            self.dataset_kind = _DatasetKind.MAP

        if self.batch_size is None:
            self.batch_sampler = None
        else:
            if isinstance(dataset, IterableDataset):
                self.batch_sampler = _InfiniteIterableSampler(
                    dataset, batch_size)
            else:
                self.batch_sampler = BatchSampler(dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  drop_last=drop_last)

        self.auto_collate_batch = self.batch_sampler is not None
        self.sampler_iter = iter(self.index_sampler)

        if self.auto_collate_batch:
            self.collate_fn = collate_fn or default_collate_fn
        else:
            self.collate_fn = collate_fn or default_convert_fn

        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collate_batch,
            self.collate_fn, self.drop_last)

        self._steps = self._infer_steps()
        self._inner_dataloader = self._create_inner_dataloader()

    def __iter__(self):
        self._cur_step = 0
        self._inner_dataloader.start()
        return self

    def __next__(self):
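        # The inner reader is created with iterable=False and feeds data to the
        # executor by itself, so __next__ only advances the step counter and
        # raises StopIteration once the steps of an epoch are exhausted.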
        if not self._steps:
            self._cur_step += 1
            return None
        elif self._cur_step < self._steps:
            self._cur_step += 1
            return None
        else:
            self._inner_dataloader.reset()
            self.sampler_iter = iter(self.index_sampler)
            raise StopIteration

    def _infer_steps(self):
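        # Use the user-provided steps_per_epoch if any; otherwise derive the
        # number of steps per epoch from the dataset length and batch size.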
        if self.steps_per_epoch is not None:
            return self.steps_per_epoch
        try:
            if isinstance(self.dataset, IterableDataset):
                steps_per_epoch = None
            elif self.batch_size is None:
                steps_per_epoch = len(self.dataset)
            else:
                steps_per_epoch = len(self.dataset) // self.batch_size
        except Exception:
            raise ValueError(
                "Please set `steps_per_epoch` or implement the `__len__` method in the dataset class."
            )
        return steps_per_epoch

    @property
    def index_sampler(self):
        if self.auto_collate_batch:
            return self.batch_sampler
        else:
            if self.dataset_kind == _DatasetKind.MAP:
                return list(range(len(self.dataset)))
            else:
                return _InfiniteIterableSampler(self.dataset, 1)

    def _create_inner_dataloader(self):
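        # data_generator pulls batches from the dataset fetcher; when
        # split_data is True, every feed variable is split along the batch
        # dimension and only this rank's shard is yielded.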

        def data_generator():
            while True:
                try:
                    indices = next(self.sampler_iter)
                    batch = self.dataset_fetcher.fetch(indices)
                    if batch is None:
                        break
                except StopIteration:
                    self.dataset_fetcher = _DatasetKind.create_fetcher(
                        self.dataset_kind, self.dataset,
                        self.auto_collate_batch, self.collate_fn,
                        self.drop_last)
                    break

                partial_data = []
                for i, d in enumerate(batch):
                    array = np.array(d)
                    if not self.split_data:
                        partial_data.append(array)
                        continue

                    batch_size = array.shape[0]
                    assert batch_size % self.dp_world_sizes[i] == 0, \
                        "batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i]))
                    partial_data.append(
                        np.split(array,
                                 self.dp_world_sizes[i])[self.dp_ranks[i]])

                yield partial_data

        dataloader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=self.feed_list,
            capacity=self.capacity,
            use_double_buffer=self.use_double_buffer,
            # iterable=self.iterable,
            iterable=False,
            return_list=self.return_list,
            use_multiprocess=self.use_multiprocess,
            drop_last=self.drop_last)
        dataloader.set_batch_generator(data_generator, self.places)

        return dataloader


class DistributedDataLoader(DistributedDataLoaderBase):
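    """Distributed data loader built on ``DistributedBatchSampler``.

    Each data-parallel rank iterates over its own shard of the dataset through
    an inner ``paddle.fluid.io.DataLoader``.
    """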

    def __init__(self,
                 dataset,
                 feed_list=None,
                 places=None,
                 return_list=True,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 collate_fn=None,
                 num_workers=0,
                 use_buffer_reader=True,
                 use_shared_memory=True,
                 timeout=0,
                 worker_init_fn=None,
                 epochs=1,
                 steps_per_epoch=None,
                 split_data=True,
                 data_parallel_world_size=[],
                 data_parallel_rank=[]):
        self.dataset = dataset
        self.feed_list = feed_list
        self.return_list = return_list
        self.places = places
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.use_buffer_reader = use_buffer_reader
        self.use_shared_memory = use_shared_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.dp_world_sizes = data_parallel_world_size
        self.dp_ranks = data_parallel_rank
        self.split_data = split_data
        # TODO: rank info
        self.batch_sampler = DistributedBatchSampler(
            self.dataset, self.batch_size, self.dp_world_sizes[0],
            self.dp_ranks[0], self.shuffle, self.drop_last)
        self._inner_dataloader = self._create_inner_dataloader()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.data)

    def _create_inner_dataloader(self):
        dataloader = paddle.fluid.io.DataLoader(
            self.dataset,
            feed_list=self.feed_list,
            places=self.places,
            return_list=self.return_list,
            batch_sampler=self.batch_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            use_buffer_reader=self.use_buffer_reader,
            use_shared_memory=self.use_shared_memory,
            timeout=self.timeout,
            worker_init_fn=self.worker_init_fn)
        self.data = (x for x in dataloader)

        return dataloader
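

# Minimal usage sketch (illustrative only; the names below are placeholders,
# not part of this module):
#
#   loader = DistributedDataLoader(
#       my_dataset,                      # a map-style paddle.io.Dataset
#       places=paddle.static.cpu_places(),
#       batch_size=8,
#       data_parallel_world_size=[2],    # number of data-parallel ranks
#       data_parallel_rank=[0])          # this process's rank
#   for batch in loader:
#       ...                              # one sharded mini-batch per step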