builder.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import paddle
import numbers
import numpy as np

from multiprocessing import Manager

from paddle.distributed import ParallelEnv

from paddle.io import DistributedBatchSampler

from ..utils.registry import Registry

DATASETS = Registry("DATASETS")


class DictDataset(paddle.io.Dataset):
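    """Dataset wrapper that lets dict-style samples flow through DataLoader.

    Number / numpy.ndarray fields of each sample are returned as a plain
    tuple (with the sample index appended last) so the default collation can
    batch them; every other field is stored in a per-key multiprocessing
    ``Manager().dict()`` keyed by sample index for later lookup.
    """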
    def __init__(self, dataset):
        self.dataset = dataset
        self.tensor_keys_set = set()
        self.non_tensor_keys_set = set()
        self.non_tensor_dict = Manager().dict()
        single_item = dataset[0]
        self.keys = single_item.keys()

        for k, v in single_item.items():
            if not isinstance(v, (numbers.Number, np.ndarray)):
                setattr(self, k, Manager().dict())
                self.non_tensor_keys_set.add(k)
            else:
                self.tensor_keys_set.add(k)

    def __getitem__(self, index):

        ori_map = self.dataset[index]

        tmp_list = []

        for k, v in ori_map.items():
            if isinstance(v, (numbers.Number, np.ndarray)):
                tmp_list.append(v)
            else:
                getattr(self, k).update({index: v})

        tmp_list.append(index)
        return tuple(tmp_list)

    def __len__(self):
        return len(self.dataset)

    def reset(self):
        for k in self.non_tensor_keys_set:
            setattr(self, k, Manager().dict())


class DictDataLoader():
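    """DataLoader wrapper that reassembles each batch back into a dict.

    Tensor-like fields come straight from ``paddle.io.DataLoader``; the
    remaining fields are fetched from the wrapped ``DictDataset``'s shared
    dicts using the sample indices carried in the last element of the batch.
    """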
    def __init__(self,
                 dataset,
                 batch_size,
                 is_train,
                 num_workers=4,
                 use_shared_memory=True,
                 distributed=True):

        self.dataset = DictDataset(dataset)

        # Select the device for this process: its assigned GPU when running
        # multi-GPU (nranks > 1), otherwise GPU 0.
        place = paddle.CUDAPlace(ParallelEnv().dev_id) \
                    if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)

        if distributed:
            sampler = DistributedBatchSampler(
                self.dataset,
                batch_size=batch_size,
                shuffle=True if is_train else False,
                drop_last=True if is_train else False)

            self.dataloader = paddle.io.DataLoader(
                self.dataset,
                batch_sampler=sampler,
                places=place,
                num_workers=num_workers,
                use_shared_memory=use_shared_memory)
        else:
            self.dataloader = paddle.io.DataLoader(
                self.dataset,
                batch_size=batch_size,
                shuffle=True if is_train else False,
                drop_last=True if is_train else False,
                places=place,
                use_shared_memory=False,
                num_workers=num_workers)

        self.batch_size = batch_size

    def __iter__(self):
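        """Clear the shared dicts, then yield one dict per batch."""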

        self.dataset.reset()

        for i, data in enumerate(self.dataloader):
            return_dict = {}
            j = 0
            for k in self.dataset.keys:
                if k in self.dataset.tensor_keys_set:
                    return_dict[k] = data[j] if isinstance(data,
                                                           (list,
                                                            tuple)) else data
                    j += 1
                else:
                    return_dict[k] = self.get_items_by_indexs(k, data[-1])
            yield return_dict

    def __len__(self):
        return len(self.dataloader)

    def get_items_by_indexs(self, key, indexs):
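        """Look up the non-tensor values stored for ``key`` at ``indexs``."""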
        if isinstance(indexs, paddle.Tensor):
            indexs = indexs.numpy()
        current_items = []
        items = getattr(self.dataset, key)

        for index in indexs:
            current_items.append(items[index])

        return current_items


def build_dataloader(cfg, is_train=True, distributed=True):
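    """Create a DictDataLoader from a config mapping.

    ``batch_size``, ``num_workers`` and ``use_shared_memory`` are popped from
    a copy of ``cfg`` (with defaults), ``name`` selects the dataset class from
    the DATASETS registry, and all remaining keys are passed to its
    constructor.
    """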
    cfg_ = cfg.copy()

    batch_size = cfg_.pop('batch_size', 1)
    num_workers = cfg_.pop('num_workers', 0)
    use_shared_memory = cfg_.pop('use_shared_memory', True)

    name = cfg_.pop('name')

    dataset = DATASETS.get(name)(**cfg_)
    dataloader = DictDataLoader(dataset,
                                batch_size,
                                is_train,
                                num_workers,
                                use_shared_memory=use_shared_memory,
                                distributed=distributed)

    return dataloader
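
# A minimal usage sketch (illustrative only): the dataset name and its
# 'dataroot' argument below are hypothetical placeholders for whatever
# dataset class has been registered in DATASETS; they are not part of
# this module.
#
#     cfg = {
#         'name': 'UnpairedDataset',      # hypothetical registered dataset
#         'dataroot': 'data/horse2zebra',
#         'batch_size': 1,
#         'num_workers': 4,
#     }
#     dataloader = build_dataloader(cfg, is_train=True, distributed=False)
#     for batch in dataloader:
#         # ``batch`` is a dict keyed by the dataset's sample keys
#         ...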