# fetcher.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from ..log_helper import get_logger
17
from collections.abc import Sequence, Mapping
18

19
_WARNING_TO_LOG = True
20

21 22

class _DatasetFetcher(object):
23
    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
24
        self.dataset = dataset
25
        self.auto_collate_batch = auto_collate_batch
26 27 28 29 30 31 32
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def fetch(self, batch_indices):
        raise NotImplementedError("'fetch' not implement for class {}".format(
            self.__class__.__name__))

33
    def _log_warning(self):
34 35 36 37 38
        # only log warning on GPU 0 when distributed launch
        from ...distributed import get_world_size, get_rank
        if get_world_size() >= 2 and get_rank() != 0:
            return

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
        warn_str = "Detect dataset only contains single fileds, return format " \
                   "changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \
                   "a list surround output data(e.g. return [data]), and in " \
                   "Paddle >= 2.1, DataLoader return the single filed directly " \
                   "(e.g. return data). For example, in following code: \n\n"
        warn_str += \
                "import numpy as np\n" \
                "from paddle.io import DataLoader, Dataset\n\n" \
                "class RandomDataset(Dataset):\n" \
                "    def __getitem__(self, idx):\n" \
                "        data = np.random.random((2, 3)).astype('float32')\n\n" \
                "        return data\n\n" \
                "    def __len__(self):\n" \
                "        return 10\n\n" \
                "dataset = RandomDataset()\n" \
                "loader = DataLoader(dataset, batch_size=1)\n" \
                "data = next(loader())\n\n"

        warn_str += "In Paddle <= 2.0, data is in format '[Tensor(shape=(1, 2, 3), " \
                    "dtype=float32)]', and in Paddle >= 2.1, data is in format" \
                    " 'Tensor(shape=(1, 2, 3), dtype=float32)'\n"

        logger = get_logger(
            "DataLoader", logging.INFO, fmt='%(levelname)s: %(message)s')
        logger.warning(warn_str)

65 66

class _IterableDatasetFetcher(_DatasetFetcher):
    """Fetcher for iterable-style datasets.

    Samples are drawn sequentially from one iterator created from the
    dataset at construction time. The values in ``batch_indices`` are never
    used as indices.
    """

    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(
            dataset, auto_collate_batch, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, batch_indices):
        """Fetch the next batch (or single sample) from the dataset iterator.

        Args:
            batch_indices: when ``auto_collate_batch`` is True, an iterable
                whose length is the requested batch size (it is iterated
                once per sample and ``len()`` is taken when ``drop_last`` is
                set; its values are ignored). Unused otherwise.

        Returns:
            The fetched list of samples (or the single sample), passed
            through ``collate_fn`` when one is set.

        Raises:
            StopIteration: when the iterator is exhausted, or when
                ``drop_last`` is set and only a partial batch remains.
        """
        global _WARNING_TO_LOG

        if self.auto_collate_batch:
            data = []
            for _ in batch_indices:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break

            # Nothing left, or a short final batch that must be dropped.
            if not data or (self.drop_last and
                            len(data) < len(batch_indices)):
                raise StopIteration

            # Test the cheap once-only flag first so the isinstance check is
            # skipped on every batch after the warning has been issued.
            if _WARNING_TO_LOG and \
                    not isinstance(data[0], (Sequence, Mapping)):
                self._log_warning()
                _WARNING_TO_LOG = False
        else:
            data = next(self.dataset_iter)

        if self.collate_fn:
            data = self.collate_fn(data)
        return data


class _MapDatasetFetcher(_DatasetFetcher):
    """Fetcher for map-style datasets read with ``dataset[index]``."""

    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch,
                                                 collate_fn, drop_last)

    def fetch(self, batch_indices):
        """Look up the samples for ``batch_indices`` in the dataset.

        Args:
            batch_indices: when ``auto_collate_batch`` is True, an iterable
                of indices, each looked up as ``dataset[idx]``; otherwise it
                is used directly as a single index.

        Returns:
            The list of samples (or the single sample), passed through
            ``collate_fn`` when one is set.
        """
        global _WARNING_TO_LOG

        if self.auto_collate_batch:
            data = [self.dataset[idx] for idx in batch_indices]

            # ``data`` guard: an empty batch used to raise IndexError on
            # data[0]. The once-only flag is tested before the isinstance
            # check so later batches skip the type test entirely.
            if data and _WARNING_TO_LOG and \
                    not isinstance(data[0], (Sequence, Mapping)):
                self._log_warning()
                _WARNING_TO_LOG = False
        else:
            data = self.dataset[batch_indices]

        if self.collate_fn:
            data = self.collate_fn(data)
        return data