# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import abc
import numpy as np
import paddle
from .utils import to_list
from paddle.fluid.layers.utils import flatten
from paddle.io import DataLoader, DistributedBatchSampler


class DistributedDataLoader(metaclass=abc.ABCMeta):
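    """Abstract base class for distributed data loaders.

    Holds the dataset and data-parallel configuration; concrete subclasses
    must implement `__iter__` and `__next__`.
    """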
    def __init__(self,
                 dataset,
                 batch_size=1,
                 epochs=1,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs
        self.data_parallel_world_size = data_parallel_world_size
        self.data_parallel_rank = data_parallel_rank
        self.drop_last = drop_last
        if data_parallel_world_size is not None and batch_size is not None:
            assert batch_size % data_parallel_world_size == 0, \
                "batch_size must be divisible by data_parallel_world_size"

    @abc.abstractmethod
    def __iter__(self):
        raise NotImplementedError

    @abc.abstractmethod
    def __next__(self):
        raise NotImplementedError


class NonIterableGeneratorLoader(DistributedDataLoader):
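    """Distributed loader backed by a non-iterable `DataLoader.from_generator`.

    Splits every batch along the batch dimension so that each data-parallel
    rank receives only its own shard.
    """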
    def __init__(self,
                 dataset,
                 feed_list,
                 places,
                 batch_size=1,
                 epochs=1,
                 steps_per_epoch=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 drop_last=False):
        self.feed_list = feed_list
        self.places = places
        self.steps_per_epoch = steps_per_epoch
        self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
        self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank

        super(NonIterableGeneratorLoader, self).__init__(
            dataset, batch_size, epochs, data_parallel_world_size,
            data_parallel_rank, drop_last)
        self._inner_dataloader = self._create_inner_dataloader()
        self._steps = self._infer_steps()

    def __iter__(self):
        self._cur_step = 0
        self._inner_dataloader.start()
        return self

    def __next__(self):
        if self._cur_step < self._steps:
            self._cur_step += 1
        else:
            self._inner_dataloader.reset()
            raise StopIteration

    def _infer_steps(self):
        if self.steps_per_epoch is not None:
            return self.steps_per_epoch
        try:
            if self.batch_size is None:
                steps_per_epoch = len(self.dataset)
            else:
                steps_per_epoch = len(self.dataset) // self.batch_size
        except Exception:
            raise ValueError(
                "Please set `steps_per_epoch` or implement the `__len__` method in the dataset class."
            )
        return steps_per_epoch

    def _create_inner_dataloader(self):
        def sample_data_generator():
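            # Accumulate `batch_size` samples, then split the assembled batch along
            # the batch dimension and yield only this rank's shard.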
            batch_data = None
            for step, data in enumerate(self.dataset):
                data = flatten(data)
                if batch_data is None:
                    batch_data = [[] for i in range(len(data))]
                for idx in range(len(data)):
                    batch_data[idx].append(data[idx])
                if (step + 1) % self.batch_size == 0:
                    partial_data = []
                    for d in batch_data:
                        array = np.array(d)
                        partial_data.append(
                            np.split(array, self.dp_world_size)[self.dp_rank])
                    yield partial_data[:len(self.feed_list)]
                    batch_data = None

        def batch_data_generator():
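            # The dataset already yields full batches; split each field along the
            # batch dimension and yield only this rank's shard.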
            for data in self.dataset:
                data = flatten(data)
                partial_data = []
                for d in data:
                    assert d.shape[0] % self.dp_world_size == 0, \
                        "Please pad the dataset so each batch is divisible by the data parallel world size"
                    partial_data.append(
                        np.split(d, self.dp_world_size)[self.dp_rank])
                yield partial_data[:len(self.feed_list)]

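        # Build a non-iterable DataLoader that feeds `feed_list` directly into the
        # program; use the per-sample generator when `batch_size` is set, otherwise
        # the dataset is assumed to yield pre-batched data.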
        dataloader = paddle.fluid.io.DataLoader.from_generator(
            feed_list=self.feed_list, capacity=70, iterable=False)
        if self.batch_size is not None:
            dataloader.set_batch_generator(sample_data_generator, self.places)
        else:
            dataloader.set_batch_generator(batch_data_generator, self.places)

        return dataloader
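
# Usage sketch (illustrative only, not part of the original module). Assumes the
# feed variables and places come from a static-graph program built elsewhere;
# `feed_vars`, `places`, and `my_dataset` are hypothetical names.
#
#   loader = NonIterableGeneratorLoader(
#       my_dataset, feed_vars, places,
#       batch_size=8, epochs=1,
#       data_parallel_world_size=2, data_parallel_rank=0)
#   for _ in loader:      # starts the inner reader and counts steps per epoch
#       ...               # run the executor for one training step here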