# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import abc
import numpy as np

import paddle
from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn


class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
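    """Abstract interface shared by the distributed data loaders in this module."""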

    @abc.abstractmethod
    def __iter__(self):
        raise NotImplementedError

    @abc.abstractmethod
    def __next__(self):
        raise NotImplementedError


class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
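    """Builds on ``paddle.fluid.io.DataLoader.from_generator`` and, when
    ``split_data`` is True, splits every fetched batch across the
    data-parallel ranks before it is fed.
    """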

    def __init__(self,
                 dataset,
                 feed_list=None,
                 capacity=None,
                 use_double_buffer=True,
                 iterable=True,
                 return_list=False,
                 use_multiprocess=False,
                 drop_last=True,
                 places=None,
                 batch_size=1,
                 epochs=1,
                 steps_per_epoch=None,
                 collate_fn=None,
                 split_data=True,
                 data_parallel_world_size=[],
                 data_parallel_rank=[]):
        self.dataset = dataset
        self.feed_list = feed_list
        self.capacity = capacity
        self.use_double_buffer = use_double_buffer
        self.iterable = iterable
        self.return_list = return_list
        self.use_multiprocess = use_multiprocess
        self.drop_last = drop_last
        self.places = places
        self.batch_size = batch_size
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.collate_fn = collate_fn
        self.split_data = split_data
        assert len(data_parallel_world_size) == len(feed_list)
        assert len(data_parallel_rank) == len(feed_list)
        self.dp_world_sizes = data_parallel_world_size
        self.dp_ranks = data_parallel_rank

        if isinstance(dataset, IterableDataset):
            self.dataset_kind = _DatasetKind.ITER
        else:
            self.dataset_kind = _DatasetKind.MAP

        if self.batch_size is None:
            self.batch_sampler = None
        else:
            if isinstance(dataset, IterableDataset):
                self.batch_sampler = _InfiniteIterableSampler(
                    dataset, batch_size)
            else:
                self.batch_sampler = BatchSampler(dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  drop_last=drop_last)

        self.auto_collate_batch = self.batch_sampler is not None
        self.sampler_iter = iter(self.index_sampler)

        if self.auto_collate_batch:
            self.collate_fn = collate_fn or default_collate_fn
        else:
            self.collate_fn = collate_fn or default_convert_fn

        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collate_batch,
            self.collate_fn, self.drop_last)

        self._steps = self._infer_steps()
        self._inner_dataloader = self._create_inner_dataloader()

    def __iter__(self):
        self._cur_step = 0
        self._inner_dataloader.start()
        return self

    def __next__(self):
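        # The inner dataloader is created with iterable=False, so batches are
        # consumed by the executor through the reader itself; __next__ only
        # tracks the step count and raises StopIteration at the end of an epoch.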
        if not self._steps:
            self._cur_step += 1
            return None
        elif self._cur_step < self._steps:
            self._cur_step += 1
            return None
        else:
            self._inner_dataloader.reset()
            self.sampler_iter = iter(self.index_sampler)
            raise StopIteration

    def _infer_steps(self):
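        # Prefer an explicit steps_per_epoch; otherwise derive it from the
        # dataset length (unknown for an IterableDataset, hence None).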
        if self.steps_per_epoch is not None:
            return self.steps_per_epoch
        try:
            if isinstance(self.dataset, IterableDataset):
                steps_per_epoch = None
            elif self.batch_size is None:
                steps_per_epoch = len(self.dataset)
            else:
                steps_per_epoch = len(self.dataset) // self.batch_size
        except:
            raise ValueError(
                "Please set `steps_per_epoch` or implement `__len__` method in dataset class."
            )
        return steps_per_epoch

    @property
    def index_sampler(self):
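        # Indices handed to the dataset fetcher: the batch sampler when this
        # loader does the batching, otherwise per-sample indices.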
        if self.auto_collate_batch:
            return self.batch_sampler
        else:
            if self.dataset_kind == _DatasetKind.MAP:
                return list(range(len(self.dataset)))
            else:
                return _InfiniteIterableSampler(self.dataset, 1)

    def _create_inner_dataloader(self):

        def data_generator():
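            # Yields one list of feed arrays per step; with split_data=True,
            # each array is split along the batch axis and only this rank's
            # shard is kept.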
            while True:
                try:
                    indices = next(self.sampler_iter)
                    batch = self.dataset_fetcher.fetch(indices)
                    if batch is None: break
                except StopIteration:
                    self.dataset_fetcher = _DatasetKind.create_fetcher(
                        self.dataset_kind, self.dataset,
                        self.auto_collate_batch, self.collate_fn,
                        self.drop_last)
                    break

                partial_data = []
                for i, d in enumerate(batch):
                    array = np.array(d)
                    if not self.split_data:
                        partial_data.append(array)
                        continue

                    batch_size = array.shape[0]
                    assert batch_size % self.dp_world_sizes[i] == 0, \
                        "batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i]))
                    partial_data.append(
                        np.split(array,
                                 self.dp_world_sizes[i])[self.dp_ranks[i]])

                yield partial_data

        dataloader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=self.feed_list,
            capacity=self.capacity,
            use_double_buffer=self.use_double_buffer,
            # iterable=self.iterable,
            iterable=False,
            return_list=self.return_list,
            use_multiprocess=self.use_multiprocess,
            drop_last=self.drop_last)
        dataloader.set_batch_generator(data_generator, self.places)

        return dataloader


class DistributedDataLoader(DistributedDataLoaderBase):
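    """Wraps ``paddle.fluid.io.DataLoader`` with a ``DistributedBatchSampler``
    so that each data-parallel rank iterates over its own shard of the dataset.
    """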

    def __init__(self,
                 dataset,
                 feed_list=None,
                 places=None,
                 return_list=True,
                 batch_size=1,
                 shuffle=False,
                 drop_last=False,
                 collate_fn=None,
                 num_workers=0,
                 use_buffer_reader=True,
                 use_shared_memory=True,
                 timeout=0,
                 worker_init_fn=None,
                 epochs=1,
                 steps_per_epoch=None,
                 split_data=True,
                 data_parallel_world_size=[],
                 data_parallel_rank=[]):
        self.dataset = dataset
        self.feed_list = feed_list
        self.return_list = return_list
        self.places = places
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.use_buffer_reader = use_buffer_reader
        self.use_shared_memory = use_shared_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.dp_world_sizes = data_parallel_world_size
        self.dp_ranks = data_parallel_rank
        self.split_data = split_data
        # TODO: rank info
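        # Shards the dataset indices using the first data-parallel group's
        # world size and rank.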
        self.batch_sampler = DistributedBatchSampler(
            self.dataset, self.batch_size, self.dp_world_sizes[0],
            self.dp_ranks[0], self.shuffle, self.drop_last)
        self._inner_dataloader = self._create_inner_dataloader()

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.data)

    def _create_inner_dataloader(self):
        dataloader = paddle.fluid.io.DataLoader(
            self.dataset,
            feed_list=self.feed_list,
            places=self.places,
            return_list=self.return_list,
            batch_sampler=self.batch_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            use_buffer_reader=self.use_buffer_reader,
            use_shared_memory=self.use_shared_memory,
            timeout=self.timeout,
            worker_init_fn=self.worker_init_fn)
        self.data = (x for x in dataloader)

        return dataloader
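

# A minimal usage sketch (hypothetical values; in practice these loaders are
# constructed by the auto-parallel machinery rather than by user code):
#
#     loader = DistributedDataLoader(my_dataset,
#                                    batch_size=8,
#                                    data_parallel_world_size=[2],
#                                    data_parallel_rank=[0])
#     for batch in loader:
#         ...  # each rank sees only its own shard of every batch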