提交 9029259d 编写于 作者: W Webbley

add stream shuffle

上级 d7d96a89
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""dataloader """dataloader
""" """
import warnings
import numpy as np import numpy as np
from collections import namedtuple from collections import namedtuple
...@@ -64,15 +64,15 @@ class Dataloader(object): ...@@ -64,15 +64,15 @@ class Dataloader(object):
print(batch_data) print(batch_data)
""" """
def __init__( def __init__(self,
self, dataset,
dataset, batch_size=1,
batch_size=1, drop_last=False,
drop_last=False, shuffle=False,
shuffle=False, num_workers=1,
num_workers=1, collate_fn=None,
collate_fn=None, buf_size=1000,
buf_size=1000, ): stream_shuffle_size=0):
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
...@@ -81,6 +81,22 @@ class Dataloader(object): ...@@ -81,6 +81,22 @@ class Dataloader(object):
self.collate_fn = collate_fn self.collate_fn = collate_fn
self.buf_size = buf_size self.buf_size = buf_size
self.drop_last = drop_last self.drop_last = drop_last
self.stream_shuffle_size = stream_shuffle_size
if self.shuffle and isinstance(self.dataset, StreamDataset):
warn_msg = "[shuffle] should not be True with StreamDataset. " \
"It will be ignored. " \
"You might want to set [stream_shuffle_size] with StreamDataset."
warnings.warn(warn_msg)
if self.stream_shuffle_size > 0 and self.batch_size >= stream_shuffle_size:
raise ValueError("stream_shuffle_size must be larger than batch_size," \
"but got [stream_shuffle_size=%s] smaller than [batch_size=%s]" \
% (self.stream_shuffle_size, self.batch_size))
if self.num_workers < 1:
raise ValueError("num_workers(default: 1) should be larger than 0, " \
"but got [num_workers=%s] < 1." % self.num_workers)
def __len__(self): def __len__(self):
if not isinstance(self.dataset, StreamDataset): if not isinstance(self.dataset, StreamDataset):
...@@ -129,6 +145,8 @@ class _DataLoaderIter(object): ...@@ -129,6 +145,8 @@ class _DataLoaderIter(object):
self.collate_fn = dataloader.collate_fn self.collate_fn = dataloader.collate_fn
self.num_workers = dataloader.num_workers self.num_workers = dataloader.num_workers
self.drop_last = dataloader.drop_last self.drop_last = dataloader.drop_last
self.batch_size = dataloader.batch_size
self.stream_shuffle_size = dataloader.stream_shuffle_size
self.fid = fid self.fid = fid
self.count = 0 self.count = 0
...@@ -167,17 +185,63 @@ class _DataLoaderIter(object): ...@@ -167,17 +185,63 @@ class _DataLoaderIter(object):
# make sure do not repeat in multiprocessing # make sure do not repeat in multiprocessing
self.count += 1 self.count += 1
# if self.count % self.num_workers != self.fid:
# continue
if self.collate_fn is not None: if self.collate_fn is not None:
yield self.collate_fn(batch_data) yield self.collate_fn(batch_data)
else: else:
yield batch_data yield batch_data
def _stream_shuffle_data_generator(self):
    """Yield batches from the stream dataset with window-local shuffling.

    Samples are pulled from ``self.dataset`` in windows of at most
    ``self.stream_shuffle_size`` items. Each window is shuffled in place
    with ``np.random.shuffle`` and its samples are then packed into
    batches of ``self.batch_size``. Samples left over at the end of one
    window are carried into the first batch of the next window.

    A trailing batch that is smaller than ``self.batch_size`` is emitted
    only when ``self.drop_last`` is falsy. When ``self.collate_fn`` is
    not ``None`` every batch is passed through it before being yielded.

    Yields:
        Either a list of samples, or the return value of
        ``self.collate_fn`` applied to that list.
    """

    def _read_windows():
        # Read the stream window by window; stop at the first window
        # that turns out to be empty.
        stream = iter(self.dataset)
        slots = list(range(self.stream_shuffle_size))
        while True:
            window = []
            for _ in slots:
                try:
                    sample = next(stream)
                except StopIteration:
                    # Stream ran dry: keep whatever was collected so far.
                    break
                window.append(sample)
            if not window:
                return
            yield window

    def _pack_batches():
        # ``pending`` outlives a single window so that leftovers are
        # completed with samples from the following window.
        pending = []
        for window in _read_windows():
            np.random.shuffle(window)
            for sample in window:
                pending.append(sample)
                if len(pending) == self.batch_size:
                    yield pending
                    # Rebind instead of clearing: the consumer still
                    # holds the list that was just yielded.
                    pending = []
        if not self.drop_last and pending:
            yield pending

    # Hand the worker identity to the dataset before reading anything;
    # the helper generators above are lazy and have not started yet.
    self._worker_info = WorkerInfo(
        num_workers=self.num_workers, fid=self.fid)
    self.dataset._set_worker_info(self._worker_info)

    for batch in _pack_batches():
        if self.collate_fn is None:
            yield batch
        else:
            yield self.collate_fn(batch)
def __iter__(self): def __iter__(self):
if isinstance(self.dataset, StreamDataset): if isinstance(self.dataset, StreamDataset):
data_generator = self._streamdata_generator if self.stream_shuffle_size > 0:
data_generator = self._stream_shuffle_data_generator
else:
data_generator = self._streamdata_generator
else: else:
data_generator = self._data_generator data_generator = self._data_generator
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册