From 9029259d807dd1c5d97205f0d32b6dd669c45eb0 Mon Sep 17 00:00:00 2001
From: Webbley
Date: Tue, 22 Sep 2020 16:44:02 +0800
Subject: [PATCH] add stream shuffle

---
 pgl/utils/data/dataloader.py | 90 ++++++++++++++++++++++++++++++------
 1 file changed, 77 insertions(+), 13 deletions(-)

diff --git a/pgl/utils/data/dataloader.py b/pgl/utils/data/dataloader.py
index 71528bf..4100bc8 100644
--- a/pgl/utils/data/dataloader.py
+++ b/pgl/utils/data/dataloader.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """dataloader
 """
-
+import warnings
 import numpy as np
 from collections import namedtuple
 
@@ -64,15 +64,15 @@ class Dataloader(object):
                 print(batch_data)
     """
 
-    def __init__(
-            self,
-            dataset,
-            batch_size=1,
-            drop_last=False,
-            shuffle=False,
-            num_workers=1,
-            collate_fn=None,
-            buf_size=1000, ):
+    def __init__(self,
+                 dataset,
+                 batch_size=1,
+                 drop_last=False,
+                 shuffle=False,
+                 num_workers=1,
+                 collate_fn=None,
+                 buf_size=1000,
+                 stream_shuffle_size=0):
 
         self.dataset = dataset
         self.batch_size = batch_size
@@ -81,6 +81,22 @@ class Dataloader(object):
         self.collate_fn = collate_fn
         self.buf_size = buf_size
         self.drop_last = drop_last
+        self.stream_shuffle_size = stream_shuffle_size
+
+        if self.shuffle and isinstance(self.dataset, StreamDataset):
+            warn_msg = "[shuffle] should not be True with StreamDataset. " \
+                    "It will be ignored. " \
+                    "You might want to set [stream_shuffle_size] with StreamDataset."
+            warnings.warn(warn_msg)
+
+        if self.stream_shuffle_size > 0 and self.batch_size >= stream_shuffle_size:
+            raise ValueError("stream_shuffle_size must be larger than batch_size," \
+                    "but got [stream_shuffle_size=%s] smaller than [batch_size=%s]" \
+                    % (self.stream_shuffle_size, self.batch_size))
+
+        if self.num_workers < 1:
+            raise ValueError("num_workers(default: 1) should be larger than 0, " \
+                    "but got [num_workers=%s] < 1." % self.num_workers)
 
     def __len__(self):
         if not isinstance(self.dataset, StreamDataset):
@@ -129,6 +145,8 @@ class _DataLoaderIter(object):
         self.collate_fn = dataloader.collate_fn
         self.num_workers = dataloader.num_workers
         self.drop_last = dataloader.drop_last
+        self.batch_size = dataloader.batch_size
+        self.stream_shuffle_size = dataloader.stream_shuffle_size
         self.fid = fid
         self.count = 0
 
@@ -167,17 +185,63 @@ class _DataLoaderIter(object):
 
             # make sure do not repeat in multiprocessing
             self.count += 1
-            # if self.count % self.num_workers != self.fid:
-            #     continue
 
             if self.collate_fn is not None:
                 yield self.collate_fn(batch_data)
             else:
                 yield batch_data
 
+    def _stream_shuffle_data_generator(self):
+        def _stream_shuffle_index_generator():
+            shuffle_size = [i for i in range(self.stream_shuffle_size)]
+            while True:
+                yield shuffle_size
+
+        def _data_generator():
+            dataset = iter(self.dataset)
+            for shuffle_size in _stream_shuffle_index_generator():
+                shuffle_size_data = []
+                for idx in shuffle_size:
+                    try:
+                        shuffle_size_data.append(next(dataset))
+                    except StopIteration:
+                        break
+
+                if len(shuffle_size_data) == 0:
+                    break
+
+                yield shuffle_size_data
+
+        def _batch_data_generator():
+            batch_data = []
+            for shuffle_size_data in _data_generator():
+                np.random.shuffle(shuffle_size_data)
+
+                for d in shuffle_size_data:
+                    batch_data.append(d)
+                    if len(batch_data) == self.batch_size:
+                        yield batch_data
+                        batch_data = []
+
+            if not self.drop_last and len(batch_data) > 0:
+                yield batch_data
+
+        self._worker_info = WorkerInfo(
+            num_workers=self.num_workers, fid=self.fid)
+        self.dataset._set_worker_info(self._worker_info)
+
+        for batch_data in _batch_data_generator():
+            if self.collate_fn is not None:
+                yield self.collate_fn(batch_data)
+            else:
+                yield batch_data
+
     def __iter__(self):
         if isinstance(self.dataset, StreamDataset):
-            data_generator = self._streamdata_generator
+            if self.stream_shuffle_size > 0:
+                data_generator = self._stream_shuffle_data_generator
+            else:
+                data_generator = self._streamdata_generator
         else:
             data_generator = self._data_generator
 
-- 
GitLab