# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import abc
import numpy as np
import paddle
from .utils import to_list
from paddle.fluid.layers.utils import flatten
from paddle.io import DataLoader, DistributedBatchSampler


class DistributedDataLoader(metaclass=abc.ABCMeta):
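    """Abstract base class of distributed data loaders for auto parallel.

    Subclasses must implement ``__iter__`` and ``__next__``.
    """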

    def __init__(self,
                 dataset,
                 batch_size=1,
                 epochs=1,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs
        self.data_parallel_world_size = data_parallel_world_size
        self.data_parallel_rank = data_parallel_rank
        self.drop_last = drop_last
        if data_parallel_world_size is not None and batch_size is not None:
            assert batch_size % data_parallel_world_size == 0, \
                "batch_size must be divisible by data_parallel_world_size"

    @abc.abstractmethod
    def __iter__(self):
        raise NotImplementedError

    @abc.abstractmethod
    def __next__(self):
        raise NotImplementedError


class NonIterableGeneratorLoader(DistributedDataLoader):
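    """Distributed loader built on a non-iterable generator-based DataLoader.

    Each batch is split along its first dimension and only the shard belonging
    to the current data parallel rank is fed to the program.
    """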

    def __init__(self,
                 dataset,
                 feed_list,
                 places,
                 batch_size=1,
                 epochs=1,
                 steps_per_epoch=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False):
        self.feed_list = feed_list
        self.places = places
        self.steps_per_epoch = steps_per_epoch
        self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
        self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank

        super(NonIterableGeneratorLoader,
              self).__init__(dataset, batch_size, epochs,
                             data_parallel_world_size, data_parallel_rank,
                             drop_last)
        self._inner_dataloader = self._create_inner_dataloader()
        self._steps = self._infer_steps()

    def __iter__(self):
        self._cur_step = 0
        self._inner_dataloader.start()
        return self

    def __next__(self):
        if self._cur_step < self._steps:
            self._cur_step += 1
        else:
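            # End of the epoch: reset the inner dataloader so iteration can be
            # restarted, then signal StopIteration.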
            self._inner_dataloader.reset()
            raise StopIteration

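    # Number of iterations per epoch, either given explicitly via
    # `steps_per_epoch` or derived from the dataset length and batch size.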
    def _infer_steps(self):
        if self.steps_per_epoch is not None:
            return self.steps_per_epoch
        try:
            if self.batch_size is None:
                steps_per_epoch = len(self.dataset)
            else:
                steps_per_epoch = len(self.dataset) // self.batch_size
        except Exception:
            raise ValueError(
                "Please set `steps_per_epoch` or implement the `__len__` method in the dataset class."
            )
        return steps_per_epoch

    def _create_inner_dataloader(self):

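        # Used when `batch_size` is set: the dataset is assumed to yield single
        # samples, which are grouped into batches and then split so that each
        # data parallel rank receives its own shard.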
        def sample_data_generator():
            batch_data = None
            for step, data in enumerate(self.dataset):
                data = flatten(data)
                if batch_data is None:
                    batch_data = [[] for i in range(len(data))]
                for idx in range(len(data)):
                    batch_data[idx].append(data[idx])
                if (step + 1) % self.batch_size == 0:
                    partial_data = []
                    for d in batch_data:
                        array = np.array(d)
                        partial_data.append(
                            np.split(array, self.dp_world_size)[self.dp_rank])
                    yield partial_data[:len(self.feed_list)]
                    batch_data = None

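        # Used when `batch_size` is None: the dataset is assumed to yield
        # already-batched data, which is only split across data parallel ranks.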
        def batch_data_generator():
            for data in self.dataset:
                data = flatten(data)
                partial_data = []
                for d in data:
                    assert d.shape[0] % self.dp_world_size == 0, \
                        "Please pad the dataset so its batch dimension is divisible by the data parallel world size"
                    partial_data.append(
                        np.split(d, self.dp_world_size)[self.dp_rank])
                yield partial_data[:len(self.feed_list)]

        dataloader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=self.feed_list, capacity=70, iterable=False)
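        # Select the generator according to whether batching is done by this
        # loader (`batch_size` set) or by the dataset itself.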
        if self.batch_size is not None:
            dataloader.set_batch_generator(sample_data_generator, self.places)
        else:
            dataloader.set_batch_generator(batch_data_generator, self.places)

        return dataloader