import paddle
import numbers
import numpy as np
from multiprocessing import Manager
from paddle.imperative import ParallelEnv

from paddle.incubate.hapi.distributed import DistributedBatchSampler
from ..utils.registry import Registry


DATASETS = Registry("DATASETS")
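
# Concrete datasets register themselves here and are looked up by name in
# `build_dataloader`. A minimal sketch, assuming the usual register()
# decorator exposed by ..utils.registry (the class name is illustrative):
#
#   @DATASETS.register()
#   class PairedDataset(paddle.io.Dataset):
#       def __init__(self, cfg):
#           ...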


class DictDataset(paddle.io.Dataset):
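    """Wrap a dataset whose samples are dicts.

    DataLoader's default collation handles numbers and ``np.ndarray`` but
    not arbitrary Python objects, so each dict sample is flattened into a
    tuple. Values of any other type (e.g. file-path strings) are stashed
    in per-key ``Manager().dict()`` attributes, keyed by sample index, and
    re-attached after batching; the Manager dicts are shared across
    DataLoader worker processes.
    """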
    def __init__(self, dataset):
        self.dataset = dataset
        self.tensor_keys_set = set()
        self.non_tensor_keys_set = set()

        single_item = dataset[0]
        self.keys = single_item.keys()
        
        for k, v in single_item.items():
            if not isinstance(v, (numbers.Number, np.ndarray)):
                setattr(self, k, Manager().dict())
                self.non_tensor_keys_set.add(k)
            else:
                self.tensor_keys_set.add(k)

    def __getitem__(self, index):
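        """Flatten one dict sample into a collatable tuple.

        Tensor-like values are returned positionally; everything else is
        recorded under this sample's index, which is appended as the last
        tuple element so the values can be recovered after batching.
        """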

        ori_map = self.dataset[index]
        
        tmp_list = []
        
        for k, v in ori_map.items():
            if isinstance(v, (numbers.Number, np.ndarray)):
                tmp_list.append(v)
            else:
                getattr(self, k).update({index: v})

        tmp_list.append(index)
        return tuple(tmp_list)

    def __len__(self):
        return len(self.dataset)

    def reset(self):
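        """Drop all stashed non-tensor values (called at each epoch start)."""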
        for k in self.non_tensor_keys_set:
            setattr(self, k, Manager().dict())


class DictDataLoader:
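    """Iterable wrapper that restores dict-style batches.

    Feeds a ``DictDataset`` through ``paddle.io.DataLoader`` with a
    ``DistributedBatchSampler`` and, on iteration, rebuilds a
    ``{key: batch}`` dict for every batch, re-attaching the non-tensor
    values by sample index.
    """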
    def __init__(self, dataset, batch_size, is_train, num_workers=4):
        self.dataset = DictDataset(dataset)

        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \
                    if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0)

        sampler = DistributedBatchSampler(self.dataset,
                                          batch_size=batch_size,
                                          shuffle=is_train,
                                          drop_last=is_train)

        self.dataloader = paddle.io.DataLoader(self.dataset,
                                               batch_sampler=sampler,
                                               places=place,
                                               num_workers=num_workers)

        self.batch_size = batch_size

    def __iter__(self):
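        """Yield one ``{key: batch}`` dict per batch.

        The collated tuple carries tensor values in key order with the
        batch of sample indices as its last element; those indices recover
        the non-tensor values. The per-key dicts are cleared first so
        entries from a previous epoch are dropped.
        """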

        self.dataset.reset()

        for data in self.dataloader:
            return_dict = {}
            j = 0
            for k in self.dataset.keys:
                if k in self.dataset.tensor_keys_set:
                    return_dict[k] = data[j] if isinstance(data, (list, tuple)) else data
                    j += 1
                else:
                    return_dict[k] = self.get_items_by_indices(k, data[-1])
            yield return_dict

    def __len__(self):
        return len(self.dataloader)

    def get_items_by_indices(self, key, indices):
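        """Look up the stashed non-tensor values for `key` by sample index.

        `indices` is the batch of dataset indices collated into the last
        element of each tuple; paddle tensors are converted to numpy first.
        """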
        if isinstance(indices, paddle.Variable):
            indices = indices.numpy()
        current_items = []
        items = getattr(self.dataset, key)

        for index in indices:
            current_items.append(items[index])

        return current_items


def build_dataloader(cfg, is_train=True):
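    """Build a DictDataLoader from a config object.

    ``cfg`` must provide ``name`` (a key registered in DATASETS, whose
    class is called with the full config) and may provide ``batch_size``
    and ``num_workers``.
    """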
    dataset = DATASETS.get(cfg.name)(cfg)
    
    batch_size = cfg.get('batch_size', 1)
    num_workers = cfg.get('num_workers', 0)

    dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)

    return dataloader
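

# Typical use, assuming an attribute-style config; the key names in the
# loop body are purely illustrative:
#
#   dataloader = build_dataloader(cfg.dataset.train, is_train=True)
#   for batch in dataloader:
#       images, paths = batch['image'], batch['path']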