未验证 提交 6094e441 编写于 作者: L LielinJiang 提交者: GitHub

rm dict dataloader (#312)

上级 9f77834d
...@@ -16,125 +16,14 @@ import time ...@@ -16,125 +16,14 @@ import time
import paddle import paddle
import numbers import numbers
import numpy as np import numpy as np
from multiprocessing import Manager
from paddle.distributed import ParallelEnv
from paddle.distributed import ParallelEnv
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from ..utils.registry import Registry from ..utils.registry import Registry
DATASETS = Registry("DATASETS") DATASETS = Registry("DATASETS")
class DictDataset(paddle.io.Dataset):
def __init__(self, dataset):
self.dataset = dataset
self.tensor_keys_set = set()
self.non_tensor_keys_set = set()
self.non_tensor_dict = Manager().dict()
single_item = dataset[0]
self.keys = single_item.keys()
for k, v in single_item.items():
if not isinstance(v, (numbers.Number, np.ndarray)):
setattr(self, k, Manager().dict())
self.non_tensor_keys_set.add(k)
else:
self.tensor_keys_set.add(k)
def __getitem__(self, index):
ori_map = self.dataset[index]
tmp_list = []
for k, v in ori_map.items():
if isinstance(v, (numbers.Number, np.ndarray)):
tmp_list.append(v)
else:
getattr(self, k).update({index: v})
tmp_list.append(index)
return tuple(tmp_list)
def __len__(self):
return len(self.dataset)
def reset(self):
for k in self.non_tensor_keys_set:
setattr(self, k, Manager().dict())
class DictDataLoader():
def __init__(self,
dataset,
batch_size,
is_train,
num_workers=4,
use_shared_memory=True,
distributed=True):
self.dataset = DictDataset(dataset)
place = paddle.CUDAPlace(ParallelEnv().dev_id) \
if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)
if distributed:
sampler = DistributedBatchSampler(
self.dataset,
batch_size=batch_size,
shuffle=True if is_train else False,
drop_last=True if is_train else False)
self.dataloader = paddle.io.DataLoader(
self.dataset,
batch_sampler=sampler,
places=place,
num_workers=num_workers,
use_shared_memory=use_shared_memory)
else:
self.dataloader = paddle.io.DataLoader(
self.dataset,
batch_size=batch_size,
shuffle=True if is_train else False,
drop_last=True if is_train else False,
places=place,
use_shared_memory=False,
num_workers=num_workers)
self.batch_size = batch_size
def __iter__(self):
self.dataset.reset()
for i, data in enumerate(self.dataloader):
return_dict = {}
j = 0
for k in self.dataset.keys:
if k in self.dataset.tensor_keys_set:
return_dict[k] = data[j] if isinstance(data,
(list,
tuple)) else data
j += 1
else:
return_dict[k] = self.get_items_by_indexs(k, data[-1])
yield return_dict
def __len__(self):
return len(self.dataloader)
def get_items_by_indexs(self, key, indexs):
if isinstance(indexs, paddle.Tensor):
indexs = indexs.numpy()
current_items = []
items = getattr(self.dataset, key)
for index in indexs:
current_items.append(items[index])
return current_items
def build_dataloader(cfg, is_train=True, distributed=True): def build_dataloader(cfg, is_train=True, distributed=True):
cfg_ = cfg.copy() cfg_ = cfg.copy()
...@@ -145,11 +34,27 @@ def build_dataloader(cfg, is_train=True, distributed=True): ...@@ -145,11 +34,27 @@ def build_dataloader(cfg, is_train=True, distributed=True):
name = cfg_.pop('name') name = cfg_.pop('name')
dataset = DATASETS.get(name)(**cfg_) dataset = DATASETS.get(name)(**cfg_)
dataloader = DictDataLoader(dataset, place = paddle.CUDAPlace(ParallelEnv().dev_id) \
batch_size, if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)
is_train,
num_workers, if distributed:
use_shared_memory=use_shared_memory, sampler = DistributedBatchSampler(dataset,
distributed=distributed) batch_size=batch_size,
shuffle=True if is_train else False,
drop_last=True if is_train else False)
dataloader = paddle.io.DataLoader(dataset,
batch_sampler=sampler,
places=place,
num_workers=num_workers,
use_shared_memory=use_shared_memory)
else:
dataloader = paddle.io.DataLoader(dataset,
batch_size=batch_size,
shuffle=True if is_train else False,
drop_last=True if is_train else False,
places=place,
use_shared_memory=False,
num_workers=num_workers)
return dataloader return dataloader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册