"""Dataset registry plus dict-style dataset/dataloader builders."""
import numbers
import time
from multiprocessing import Manager

import numpy as np

import paddle
from paddle.distributed import ParallelEnv
from paddle.io import DistributedBatchSampler

from ..utils.registry import Registry

# Registry of dataset classes; `build_dataloader` looks implementations up
# here by `cfg.name`.
DATASETS = Registry("DATASETS")
class DictDataset(paddle.io.Dataset):
    """Adapt a dataset whose samples are dicts to a tuple-style dataset.

    The paddle dataloader can batch numbers and ndarrays but not arbitrary
    Python objects, so each sample dict is split in two:

    * numeric / ``np.ndarray`` values are returned as a tuple (with the
      sample index appended last) and batched normally;
    * every other value is stashed in a ``multiprocessing.Manager`` dict
      attribute named after its key, keyed by sample index, so worker
      processes can hand it back to the main process.  The loader fetches
      it again by index after batching (see
      ``DictDataLoader.get_items_by_indexs``).
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.tensor_keys_set = set()
        self.non_tensor_keys_set = set()
        # Use ONE manager for all proxy dicts created here: every
        # ``Manager()`` call spawns its own server process, so the original
        # per-key ``Manager().dict()`` leaked a process per non-tensor key.
        # Kept as a local (not an attribute) so the dataset remains
        # picklable for dataloader workers; the proxies keep the manager
        # process alive.
        manager = Manager()
        self.non_tensor_dict = manager.dict()

        single_item = dataset[0]
        # dict view: preserves the sample's key order for __iter__ matching.
        self.keys = single_item.keys()

        for k, v in single_item.items():
            if not isinstance(v, (numbers.Number, np.ndarray)):
                # Non-batchable value: share it through a managed dict.
                setattr(self, k, manager.dict())
                self.non_tensor_keys_set.add(k)
            else:
                self.tensor_keys_set.add(k)

    def __getitem__(self, index):
        ori_map = self.dataset[index]

        tmp_list = []

        for k, v in ori_map.items():
            if isinstance(v, (numbers.Number, np.ndarray)):
                tmp_list.append(v)
            else:
                # Stash the object under its sample index so the loader can
                # retrieve it after batching.
                getattr(self, k).update({index: v})

        # The index travels with the batch so non-tensor values can be
        # matched back to their samples.
        tmp_list.append(index)
        return tuple(tmp_list)

    def __len__(self):
        return len(self.dataset)

    def reset(self):
        """Drop all stashed non-tensor values (called at epoch start)."""
        # Again a single shared manager instead of one per key.
        manager = Manager()
        for k in self.non_tensor_keys_set:
            setattr(self, k, manager.dict())


class DictDataLoader():
    """Batch a dict-sample dataset and yield dict-shaped batches.

    Wraps ``dataset`` in a ``DictDataset`` plus a distributed batch
    sampler; iteration re-assembles each batch into ``{key: value}`` form,
    taking tensor keys from the batched tuple and fetching non-tensor keys
    back from the managed dicts by sample index.

    Args:
        dataset: dataset whose ``__getitem__`` returns a dict.
        batch_size (int): samples per batch.
        is_train (bool): enables shuffling and drops the last partial batch.
        num_workers (int): dataloader worker processes. Default: 4.
    """

    def __init__(self, dataset, batch_size, is_train, num_workers=4):
        self.dataset = DictDataset(dataset)

        # NOTE(review): CUDA-only placement; dev_id is used in the
        # multi-rank case, device 0 otherwise.
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \
                    if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0)

        sampler = DistributedBatchSampler(self.dataset,
                                          batch_size=batch_size,
                                          shuffle=is_train,
                                          drop_last=is_train)

        self.dataloader = paddle.io.DataLoader(self.dataset,
                                               batch_sampler=sampler,
                                               places=place,
                                               num_workers=num_workers)

        self.batch_size = batch_size

    def __iter__(self):
        # Clear the per-epoch non-tensor caches before refilling them.
        self.dataset.reset()

        for data in self.dataloader:
            return_dict = {}
            j = 0
            # ``self.dataset.keys`` preserves the sample dict's key order,
            # which matches the tuple order that
            # ``DictDataset.__getitem__`` produced for the tensor keys.
            for k in self.dataset.keys:
                if k in self.dataset.tensor_keys_set:
                    return_dict[k] = data[j] if isinstance(data,
                                                           (list,
                                                            tuple)) else data
                    j += 1
                else:
                    # data[-1] holds the sample indices appended by
                    # ``DictDataset.__getitem__``.
                    return_dict[k] = self.get_items_by_indexs(k, data[-1])
            yield return_dict

    def __len__(self):
        return len(self.dataloader)

    def get_items_by_indexs(self, key, indexs):
        """Fetch the stashed non-tensor values for a batch of indices."""
        if isinstance(indexs, paddle.Variable):
            indexs = indexs.numpy()
        items = getattr(self.dataset, key)
        return [items[index] for index in indexs]


def build_dataloader(cfg, is_train=True):
    """Build a ``DictDataLoader`` from a dataset config.

    Looks up ``cfg.name`` in the ``DATASETS`` registry, instantiates that
    dataset with the config, and wraps it using the config's optional
    ``batch_size`` (default 1) and ``num_workers`` (default 0).
    """
    dataset_cls = DATASETS.get(cfg.name)
    return DictDataLoader(dataset_cls(cfg),
                          cfg.get('batch_size', 1),
                          is_train,
                          cfg.get('num_workers', 0))