diff --git a/ppgan/datasets/builder.py b/ppgan/datasets/builder.py
index da582bb6d76efc5829003740af67abd69f7e91a9..ae2369ab283cbde23984077a682d8c504859e724 100644
--- a/ppgan/datasets/builder.py
+++ b/ppgan/datasets/builder.py
@@ -16,125 +16,14 @@ import time
 import paddle
 import numbers
 import numpy as np
-from multiprocessing import Manager
-from paddle.distributed import ParallelEnv
+from paddle.distributed import ParallelEnv
 from paddle.io import DistributedBatchSampler
 
 from ..utils.registry import Registry
 
 DATASETS = Registry("DATASETS")
 
 
-class DictDataset(paddle.io.Dataset):
-    def __init__(self, dataset):
-        self.dataset = dataset
-        self.tensor_keys_set = set()
-        self.non_tensor_keys_set = set()
-        self.non_tensor_dict = Manager().dict()
-        single_item = dataset[0]
-        self.keys = single_item.keys()
-
-        for k, v in single_item.items():
-            if not isinstance(v, (numbers.Number, np.ndarray)):
-                setattr(self, k, Manager().dict())
-                self.non_tensor_keys_set.add(k)
-            else:
-                self.tensor_keys_set.add(k)
-
-    def __getitem__(self, index):
-
-        ori_map = self.dataset[index]
-
-        tmp_list = []
-
-        for k, v in ori_map.items():
-            if isinstance(v, (numbers.Number, np.ndarray)):
-                tmp_list.append(v)
-            else:
-                getattr(self, k).update({index: v})
-
-        tmp_list.append(index)
-        return tuple(tmp_list)
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def reset(self):
-        for k in self.non_tensor_keys_set:
-            setattr(self, k, Manager().dict())
-
-
-class DictDataLoader():
-    def __init__(self,
-                 dataset,
-                 batch_size,
-                 is_train,
-                 num_workers=4,
-                 use_shared_memory=True,
-                 distributed=True):
-
-        self.dataset = DictDataset(dataset)
-
-        place = paddle.CUDAPlace(ParallelEnv().dev_id) \
-            if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)
-
-        if distributed:
-            sampler = DistributedBatchSampler(
-                self.dataset,
-                batch_size=batch_size,
-                shuffle=True if is_train else False,
-                drop_last=True if is_train else False)
-
-            self.dataloader = paddle.io.DataLoader(
-                self.dataset,
-                batch_sampler=sampler,
-                places=place,
-                num_workers=num_workers,
-                use_shared_memory=use_shared_memory)
-        else:
-            self.dataloader = paddle.io.DataLoader(
-                self.dataset,
-                batch_size=batch_size,
-                shuffle=True if is_train else False,
-                drop_last=True if is_train else False,
-                places=place,
-                use_shared_memory=False,
-                num_workers=num_workers)
-
-        self.batch_size = batch_size
-
-    def __iter__(self):
-
-        self.dataset.reset()
-
-        for i, data in enumerate(self.dataloader):
-            return_dict = {}
-            j = 0
-            for k in self.dataset.keys:
-                if k in self.dataset.tensor_keys_set:
-                    return_dict[k] = data[j] if isinstance(data,
-                                                           (list,
-                                                            tuple)) else data
-                    j += 1
-                else:
-                    return_dict[k] = self.get_items_by_indexs(k, data[-1])
-            yield return_dict
-
-    def __len__(self):
-        return len(self.dataloader)
-
-    def get_items_by_indexs(self, key, indexs):
-        if isinstance(indexs, paddle.Tensor):
-            indexs = indexs.numpy()
-        current_items = []
-        items = getattr(self.dataset, key)
-
-        for index in indexs:
-            current_items.append(items[index])
-
-        return current_items
-
-
 def build_dataloader(cfg, is_train=True, distributed=True):
     cfg_ = cfg.copy()
@@ -145,11 +34,27 @@ def build_dataloader(cfg, is_train=True, distributed=True):
     name = cfg_.pop('name')
 
     dataset = DATASETS.get(name)(**cfg_)
-    dataloader = DictDataLoader(dataset,
-                                batch_size,
-                                is_train,
-                                num_workers,
-                                use_shared_memory=use_shared_memory,
-                                distributed=distributed)
+    place = paddle.CUDAPlace(ParallelEnv().dev_id) \
+        if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)
+
+    if distributed:
+        sampler = DistributedBatchSampler(dataset,
+                                          batch_size=batch_size,
+                                          shuffle=True if is_train else False,
+                                          drop_last=True if is_train else False)
+
+        dataloader = paddle.io.DataLoader(dataset,
+                                          batch_sampler=sampler,
+                                          places=place,
+                                          num_workers=num_workers,
+                                          use_shared_memory=use_shared_memory)
+    else:
+        dataloader = paddle.io.DataLoader(dataset,
+                                          batch_size=batch_size,
+                                          shuffle=True if is_train else False,
+                                          drop_last=True if is_train else False,
+                                          places=place,
+                                          use_shared_memory=False,
+                                          num_workers=num_workers)
 
     return dataloader
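
A minimal usage sketch of the simplified build_dataloader after this change (not part of the diff). The dataset name 'SingleDataset' and the cfg keys other than 'name' are assumptions inferred from the identifiers the function uses (batch_size, num_workers, use_shared_memory); any remaining keys are forwarded to the registered dataset's constructor.

    # Hypothetical config: 'SingleDataset' stands in for any dataset registered
    # in DATASETS; batch_size / num_workers / use_shared_memory are assumed to
    # be popped from cfg before the dataset kwargs are forwarded.
    cfg = {
        'name': 'SingleDataset',
        'batch_size': 4,
        'num_workers': 0,
        'use_shared_memory': False,
        # ...dataset-specific kwargs go to DATASETS.get(name)(**cfg_)
    }

    loader = build_dataloader(cfg, is_train=True, distributed=False)
    for batch in loader:
        # Each batch is whatever the dataset's __getitem__ returns, collated
        # directly by paddle.io.DataLoader (no DictDataLoader wrapper anymore).
        pass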