# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import numpy as np
import paddle
from .utils import to_list
from paddle.fluid.layers.utils import flatten
from paddle.io import DataLoader, BatchSampler, IterableDataset
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn


class DistributedDataLoader(metaclass=abc.ABCMeta):
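    """Abstract base class for distributed data loaders.

    Wraps a map-style dataset and, when ``batch_size`` is given, builds a
    ``BatchSampler`` so batches can later be split evenly across the
    ``data_parallel_world_size`` ranks; subclasses must implement
    ``__iter__`` and ``__next__``.
    """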

    def __init__(self,
                 dataset,
                 batch_size=1,
                 epochs=1,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False):
        if isinstance(dataset, IterableDataset):
            raise TypeError("IterableDataset is not supported.")
        else:
            self.dataset_kind = _DatasetKind.MAP

        self.dataset = dataset
        self.epochs = epochs
        self.drop_last = drop_last

        if batch_size is None:
            self.batch_size = None
            self.batch_sampler = None
        else:
            if data_parallel_world_size is not None:
                assert batch_size % data_parallel_world_size == 0, \
                    "'batch_size' must be divisible by data parallel size"
            self.batch_size = batch_size
            self.batch_sampler = BatchSampler(dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              drop_last=drop_last)

        self.auto_collate_batch = self.batch_sampler is not None
        self.sampler_iter = iter(self.index_sampler)
        self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
        self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank

    @abc.abstractmethod
    def __iter__(self):
        raise NotImplementedError

    @abc.abstractmethod
    def __next__(self):
        raise NotImplementedError

    @property
    def index_sampler(self):
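        """Return the index source iterated by ``sampler_iter``: the
        ``BatchSampler`` when batches are collated here, otherwise a flat
        list of sample indices for the map-style dataset."""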
        if self.auto_collate_batch:
            return self.batch_sampler
        else:
            if self.dataset_kind == _DatasetKind.MAP:
                return list(range(len(self.dataset)))
            else:
                raise TypeError("Only support datasets in map-style.")


class NonIterableGeneratorLoader(DistributedDataLoader):
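    """Distributed loader built on a non-iterable ``DataLoader.from_generator``.

    Each data-parallel rank fetches only its shard of every batch and feeds it
    to the variables in ``feed_list``; iteration raises ``StopIteration`` once
    the number of steps inferred from the dataset length (or the given
    ``steps_per_epoch``) is reached.

    A minimal usage sketch, assuming ``dataset``, ``feed_list``, and ``places``
    are prepared by the caller::

        loader = NonIterableGeneratorLoader(
            dataset, feed_list, places, batch_size=8)
        for _ in loader:
            pass  # the caller typically runs the program/executor per step
    """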

    def __init__(self,
                 dataset,
                 feed_list,
                 places,
                 batch_size=1,
                 epochs=1,
                 steps_per_epoch=None,
                 collate_fn=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False):
        self.feed_list = feed_list
        self.places = places
        self.steps_per_epoch = steps_per_epoch

        super(NonIterableGeneratorLoader,
              self).__init__(dataset, batch_size, epochs,
                             data_parallel_world_size, data_parallel_rank,
                             drop_last)

        if self.auto_collate_batch:
            self.collate_fn = collate_fn or default_collate_fn
        else:
            self.collate_fn = collate_fn or default_convert_fn
        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collate_batch,
            self.collate_fn, self.drop_last)

        self._steps = self._infer_steps()
        self._inner_dataloader = self._create_inner_dataloader()

    def __iter__(self):
        self._cur_step = 0
        self._inner_dataloader.start()
        return self

    def __next__(self):
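        # The underlying non-iterable dataloader feeds data when the program
        # runs, so __next__ only advances the step counter and resets the
        # inner dataloader once the epoch's steps are exhausted.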
        if self._cur_step < self._steps:
            self._cur_step += 1
        else:
            self._inner_dataloader.reset()
            raise StopIteration

    def _infer_steps(self):
        if self.steps_per_epoch is not None:
            return self.steps_per_epoch
        try:
            if self.batch_size is None:
                steps_per_epoch = len(self.dataset)
            else:
                steps_per_epoch = len(self.dataset) // self.batch_size
        except:
            raise ValueError(
                "Please set `steps_per_epoch` or implement `__len__` method in the dataset class."
            )
        return steps_per_epoch

    def _create_inner_dataloader(self):

        def sample_data_generator():
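            # Split each batch of indices produced by the batch sampler evenly
            # across data-parallel ranks and fetch only this rank's slice.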
            for indices in self.sampler_iter:
                assert len(indices) % self.dp_world_size == 0, \
                    "Please set batch_size to be divisible by data parallel size"
                n = len(indices) // self.dp_world_size
                cur_indices = [
                    indices[i:i + n] for i in range(0, len(indices), n)
                ]
                batch = self.dataset_fetcher.fetch(cur_indices[self.dp_rank])
                yield batch[:len(self.feed_list)]

        def batch_data_generator():
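            # The dataset itself yields whole batches; split every field along
            # the batch dimension and keep only this rank's shard.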
            for indices in self.sampler_iter:
                partial_data = []
                batch = self.dataset_fetcher.fetch(indices)
                for data in batch:
                    assert data.shape[0] % self.dp_world_size == 0, \
                        "Please pad the dataset's batch size to be divisible by the data parallel size"
                    partial_data.append(
                        np.split(data, self.dp_world_size)[self.dp_rank])
                yield partial_data[:len(self.feed_list)]

        dataloader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=self.feed_list, capacity=70, iterable=False)
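        # With a batch_size, samples are collated into batches here via the
        # batch sampler; otherwise the dataset is assumed to yield data that
        # is already batched.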
        if self.batch_size is not None:
            dataloader.set_batch_generator(sample_data_generator, self.places)
        else:
            dataloader.set_batch_generator(batch_data_generator, self.places)

        return dataloader