提交 9029259d 编写于 作者: W Webbley

add stream shuffle

上级 d7d96a89
......@@ -14,7 +14,7 @@
# limitations under the License.
"""dataloader
"""
import warnings
import numpy as np
from collections import namedtuple
......@@ -64,15 +64,15 @@ class Dataloader(object):
print(batch_data)
"""
def __init__(
self,
dataset,
batch_size=1,
drop_last=False,
shuffle=False,
num_workers=1,
collate_fn=None,
buf_size=1000, ):
def __init__(self,
dataset,
batch_size=1,
drop_last=False,
shuffle=False,
num_workers=1,
collate_fn=None,
buf_size=1000,
stream_shuffle_size=0):
self.dataset = dataset
self.batch_size = batch_size
......@@ -81,6 +81,22 @@ class Dataloader(object):
self.collate_fn = collate_fn
self.buf_size = buf_size
self.drop_last = drop_last
self.stream_shuffle_size = stream_shuffle_size
if self.shuffle and isinstance(self.dataset, StreamDataset):
warn_msg = "[shuffle] should not be True with StreamDataset. " \
"It will be ignored. " \
"You might want to set [stream_shuffle_size] with StreamDataset."
warnings.warn(warn_msg)
if self.stream_shuffle_size > 0 and self.batch_size >= stream_shuffle_size:
raise ValueError("stream_shuffle_size must be larger than batch_size," \
"but got [stream_shuffle_size=%s] smaller than [batch_size=%s]" \
% (self.stream_shuffle_size, self.batch_size))
if self.num_workers < 1:
raise ValueError("num_workers(default: 1) should be larger than 0, " \
"but got [num_workers=%s] < 1." % self.num_workers)
def __len__(self):
if not isinstance(self.dataset, StreamDataset):
......@@ -129,6 +145,8 @@ class _DataLoaderIter(object):
self.collate_fn = dataloader.collate_fn
self.num_workers = dataloader.num_workers
self.drop_last = dataloader.drop_last
self.batch_size = dataloader.batch_size
self.stream_shuffle_size = dataloader.stream_shuffle_size
self.fid = fid
self.count = 0
......@@ -167,17 +185,63 @@ class _DataLoaderIter(object):
# make sure do not repeat in multiprocessing
self.count += 1
# if self.count % self.num_workers != self.fid:
# continue
if self.collate_fn is not None:
yield self.collate_fn(batch_data)
else:
yield batch_data
def _stream_shuffle_data_generator(self):
def _stream_shuffle_index_generator():
shuffle_size = [i for i in range(self.stream_shuffle_size)]
while True:
yield shuffle_size
def _data_generator():
dataset = iter(self.dataset)
for shuffle_size in _stream_shuffle_index_generator():
shuffle_size_data = []
for idx in shuffle_size:
try:
shuffle_size_data.append(next(dataset))
except StopIteration:
break
if len(shuffle_size_data) == 0:
break
yield shuffle_size_data
def _batch_data_generator():
batch_data = []
for shuffle_size_data in _data_generator():
np.random.shuffle(shuffle_size_data)
for d in shuffle_size_data:
batch_data.append(d)
if len(batch_data) == self.batch_size:
yield batch_data
batch_data = []
if not self.drop_last and len(batch_data) > 0:
yield batch_data
self._worker_info = WorkerInfo(
num_workers=self.num_workers, fid=self.fid)
self.dataset._set_worker_info(self._worker_info)
for batch_data in _batch_data_generator():
if self.collate_fn is not None:
yield self.collate_fn(batch_data)
else:
yield batch_data
def __iter__(self):
if isinstance(self.dataset, StreamDataset):
data_generator = self._streamdata_generator
if self.stream_shuffle_size > 0:
data_generator = self._stream_shuffle_data_generator
else:
data_generator = self._streamdata_generator
else:
data_generator = self._data_generator
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册