import paddle
import numbers
import numpy as np
from multiprocessing import Manager
from paddle.imperative import ParallelEnv

from paddle.incubate.hapi.distributed import DistributedBatchSampler
from ..utils.registry import Registry


DATASETS = Registry("DATASETS")


class DictDataset(paddle.io.Dataset):
    """Adapt a dataset whose samples are dicts for use with paddle.io.DataLoader.

    Values that can become tensors (numbers / numpy arrays) are returned
    positionally from ``__getitem__``; values that cannot are stashed in a
    ``multiprocessing.Manager`` dict (shared with DataLoader worker processes)
    keyed by sample index, and only the index is returned so a caller can
    fetch them back after batching.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.tensor_keys_set = set()
        self.non_tensor_keys_set = set()
        # Manager dict so writes made inside DataLoader worker processes are
        # visible to the main process.
        self.non_tensor_dict = Manager().dict()

        single_item = dataset[0]
        # Materialize as a list: a dict keys-view stays tied to the sample
        # dict and may not pickle cleanly into worker processes.
        self.keys = list(single_item.keys())

        for k, v in single_item.items():
            if not isinstance(v, (numbers.Number, np.ndarray)):
                self.non_tensor_dict.update({k: {}})
                self.non_tensor_keys_set.add(k)
            else:
                self.tensor_keys_set.add(k)

    def __getitem__(self, index):
        """Return the tensor-able values of sample ``index`` plus the index itself."""
        ori_map = self.dataset[index]

        tmp_list = []
        for k, v in ori_map.items():
            if isinstance(v, (numbers.Number, np.ndarray)):
                tmp_list.append(v)
            else:
                # Reassign the whole sub-dict: in-place mutation of a value
                # stored inside a Manager().dict() is not propagated back.
                tmp_dict = self.non_tensor_dict[k]
                tmp_dict.update({index: v})
                self.non_tensor_dict[k] = tmp_dict

        # The index rides along as the last element so non-tensor values can
        # be looked up after batching.
        tmp_list.append(index)
        return tuple(tmp_list)

    def __len__(self):
        return len(self.dataset)

    def reset(self):
        """Clear the cached non-tensor values (call before each epoch)."""
        for k in self.non_tensor_keys_set:
            self.non_tensor_dict[k] = {}


class DictDataLoader():
    """DataLoader over dict-style samples.

    Tensor-able values flow through ``paddle.io.DataLoader`` as positional
    tensors; non-tensor values are fetched back from the wrapped
    :class:`DictDataset` using the sample indexes that ride along as the last
    element of each batch.
    """

    def __init__(self, dataset, batch_size, is_train, num_workers=0):
        self.dataset = DictDataset(dataset)

        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) \
                    if ParallelEnv().nranks > 1 else paddle.fluid.CUDAPlace(0)

        # Shuffle and drop the incomplete last batch only during training.
        sampler = DistributedBatchSampler(
                                        self.dataset,
                                        batch_size=batch_size,
                                        shuffle=is_train,
                                        drop_last=is_train)

        self.dataloader = paddle.io.DataLoader(
                                        self.dataset,
                                        batch_sampler=sampler,
                                        places=place,
                                        num_workers=num_workers)

        self.batch_size = batch_size

    def __iter__(self):
        """Yield one dict per batch, re-attaching non-tensor values by index."""
        # Drop values cached during the previous epoch.
        self.dataset.reset()

        for i, data in enumerate(self.dataloader):
            return_dict = {}
            j = 0
            for k in self.dataset.keys:
                if k in self.dataset.tensor_keys_set:
                    return_dict[k] = data[j] if isinstance(data, (list, tuple)) else data
                    j += 1
                else:
                    # data[-1] is the batch of sample indexes appended by
                    # DictDataset.__getitem__.
                    return_dict[k] = self.get_items_by_indexs(k, data[-1])
            yield return_dict

    def __len__(self):
        return len(self.dataloader)

    def get_items_by_indexs(self, key, indexs):
        """Look up the non-tensor values stored for ``key`` at ``indexs``."""
        if isinstance(indexs, paddle.Variable):
            indexs = indexs.numpy()
        # Bug fix: non-tensor values live in the dataset's non_tensor_dict,
        # not as an attribute named after the sample key —
        # getattr(self.dataset, key) would raise AttributeError.
        items = self.dataset.non_tensor_dict[key]
        return [items[index] for index in indexs]



def build_dataloader(cfg, is_train=True):
    """Build a DictDataLoader from a dataset config.

    Args:
        cfg: config object with a ``name`` attribute (registered dataset key)
            and optional ``batch_size`` (default 1) / ``num_workers``
            (default 0) entries.
        is_train: enables shuffling and drop_last in the underlying sampler.

    Returns:
        DictDataLoader over the dataset built from ``cfg``.
    """
    dataset = DATASETS.get(cfg.name)(cfg)

    batch_size = cfg.get('batch_size', 1)
    num_workers = cfg.get('num_workers', 0)

    # Bug fix: num_workers was read from the config but never forwarded,
    # silently ignoring the setting.
    dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)

    return dataloader