Commit e082e277 authored by Megvii Engine Team

refactor(mge/data): Refactor megengine.data.dataset

GitOrigin-RevId: 1d9c61ce70059de9e7e0f804a67f48e54d3891a6
Parent f04e0d77
......@@ -10,6 +10,7 @@ from .collator import Collator
from .dataloader import DataLoader
from .sampler import (
Infinite,
MapSampler,
RandomSampler,
ReplacementSampler,
Sampler,
......
......@@ -20,7 +20,7 @@ from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset, MapDataset, StreamDataset
from .sampler import Sampler, SequentialSampler, StreamSampler
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform
logger = get_logger(__name__)
......@@ -88,17 +88,24 @@ class DataLoader:
self.divide = divide
if sampler is None:
if isinstance(dataset, MapDataset):
self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
elif isinstance(dataset, StreamDataset):
self.sampler = StreamSampler(batch_size=1)
else:
raise TypeError(
"can not recognize this kind of dataset: %s" % type(dataset)
)
if isinstance(dataset, MapDataset):
self.sampler = (
sampler
if sampler
else SequentialSampler(dataset, batch_size=1, drop_last=False)
)
assert isinstance(
self.sampler, MapSampler
), "types of dataset and sampler do not match"
elif isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else:
self.sampler = sampler
raise TypeError(
"cannot recognize this kind of dataset: %s" % type(dataset)
)
if divide:
if self.sampler.batch_size <= self.num_workers:
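Note: after this refactor, DataLoader validates that the sampler type matches the dataset type — a MapDataset requires a MapSampler and a StreamDataset requires a StreamSampler. A minimal sketch of the intended pairing, assuming only what this diff shows (``ToyStream`` and its payload are placeholders):

```python
from megengine.data import DataLoader
from megengine.data.dataset import StreamDataset
from megengine.data.sampler import StreamSampler

class ToyStream(StreamDataset):
    def __iter__(self):
        while True:
            # (is_batched, data) protocol introduced by this diff
            yield (False, (0, 1))

# Matching pair: a StreamDataset goes with a StreamSampler.
loader = DataLoader(ToyStream(), sampler=StreamSampler(batch_size=2))

# A MapSampler (e.g. SequentialSampler) with this dataset would now
# trip the assertion "types of dataset and sampler do not match".
```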
......@@ -352,7 +359,6 @@ class _BaseStreamDataLoaderIter:
self.collator = loader.collator
self.num_workers = loader.num_workers
self.timeout = loader.timeout
self.post_process = self.dataset.post_process
def _get_next_batch(self):
raise NotImplementedError
......@@ -361,13 +367,15 @@ class _BaseStreamDataLoaderIter:
return self
def __next__(self):
return self.post_process(self._get_next_batch())
return self._get_next_batch()
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
def __init__(self, loader):
super().__init__(loader)
self.dataset_iter = iter(self.dataset)
self.idx = 0
self.data = None
def _get_next_batch(self):
ret = []
......@@ -376,11 +384,30 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
waited_time = time.time() - start_time
if self.timeout > 0 and waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!")
item = next(self.dataset_iter)
for idx in range(len(item[0])):
trans_item = self.transform.apply(tuple(e[idx] for e in item))
ret.append(trans_item)
if self.idx != 0:
data = self.data
else:
try:
raw_data = next(self.dataset_iter)
except StopIteration:
continue
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "raw_data must be a tuple of (is_batched, data)"
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(self.idx, len(data[0])):
trans_data = self.transform.apply(tuple(e[idx] for e in data))
ret.append(trans_data)
if len(ret) == self.sampler.batch_size:
if idx + 1 == len(data[0]):
self.idx = 0
self.data = None
else:
self.idx = idx + 1
self.data = data
break
return self.collator.apply(ret)
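The serial iterator above relies on the new StreamDataset yield protocol: every item is a 2-tuple ``(is_batched, data)``. When ``is_batched`` is False, ``data`` is a single sample and each field is wrapped into a length-1 tuple; when True, every field already carries a leading batch axis. Because one yielded batch may be larger than ``sampler.batch_size``, the iterator stashes the partially consumed batch in ``self.data`` and the resume position in ``self.idx``. A hedged sketch of the two yield forms (shapes are illustrative only):

```python
import numpy as np
from megengine.data.dataset import StreamDataset

class ExampleStream(StreamDataset):
    def __iter__(self):
        # Unbatched form: one sample per yield; the iterator wraps
        # each field into a length-1 tuple before indexing.
        image = np.zeros((32, 32, 3), dtype="uint8")
        yield (False, (image, 0))
        # Batched form: every field has a leading batch dimension.
        images = np.zeros((2, 32, 32, 3), dtype="uint8")
        yield (True, (images, [1, 2]))
```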
......@@ -393,45 +420,80 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
self.shutdown_flag = multiprocessing.Value("i", 0)
self.raw_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]
self.trans_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]
# shared-memory queue implemented by pyarrow plasma store
from ._queue import PlasmaShmQueue
self.batch_queue = PlasmaShmQueue(maxsize=2)
self.workers = []
self.worker_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]
self.receive_worker = multiprocessing.Process(target=self._receive, daemon=True)
self.receive_worker.start()
self.transform_workers = []
for worker_id in range(self.num_workers):
worker = multiprocessing.Process(
target=self._gen_data, args=(worker_id,), daemon=True
target=self._transform, args=(worker_id,), daemon=True
)
worker.start()
self.workers.append(worker)
self.collator_worker = multiprocessing.Process(
target=self._gen_batch, daemon=True
)
self.collator_worker.start()
self.transform_workers.append(worker)
self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True)
self.collect_worker.start()
self.__initialized = True
def _gen_data(self, worker_id):
def _receive(self):
dataset_iter = iter(self.dataset)
cnt = -1
while True:
if self.shutdown_flag.value == 1:
break
item = next(dataset_iter)
for idx in range(len(item[0])):
trans_item = self.transform.apply(tuple(e[idx] for e in item))
raw_data = next(dataset_iter)
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "raw_data must be a tuple of (is_batched, data)"
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(len(data[0])):
while True:
cnt += 1
qid = cnt % self.num_workers
try:
self.worker_queues[worker_id].put(trans_item)
self.raw_data_queues[qid].put(tuple(e[idx] for e in data))
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch part queue is full")
logger.debug("raw data queue is full")
def _gen_batch(self):
def _transform(self, worker_id):
while True:
if self.shutdown_flag.value == 1:
break
try:
data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT)
except queue.Empty:
continue
trans_data = self.transform.apply(data)
while True:
try:
self.trans_data_queues[worker_id].put(trans_data)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch queue if full")
def _collect(self):
cnt = -1
trans_items = []
while True:
......@@ -440,7 +502,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
cnt += 1
queue_id = cnt % self.num_workers
try:
trans_item = self.worker_queues[queue_id].get(
trans_item = self.trans_data_queues[queue_id].get(
timeout=MP_QUEUE_GET_TIMEOUT
)
except queue.Empty:
......@@ -459,12 +521,12 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
trans_items = []
def _check_workers(self):
if not self.collator_worker.is_alive():
exitcode = self.collator_worker.exitcode
if not self.collect_worker.is_alive():
exitcode = self.collect_worker.exitcode
if exitcode != 0:
raise RuntimeError("collator worker died. {}".format(exitcode))
for worker_id, worker in enumerate(self.workers):
for worker_id, worker in enumerate(self.transform_workers):
if not worker.is_alive():
exitcode = worker.exitcode
if exitcode != 0:
......@@ -492,16 +554,24 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
with self.shutdown_flag.get_lock():
self.shutdown_flag.value = 1
if self.collator_worker.is_alive():
self.collator_worker.terminate()
self.collator_worker.join()
if self.receive_worker.is_alive():
self.receive_worker.terminate()
self.receive_worker.join()
for worker in self.workers:
if self.collect_worker.is_alive():
self.collect_worker.terminate()
self.collect_worker.join()
for worker in self.transform_workers:
if worker.is_alive():
worker.terminate()
worker.join()
for q in self.worker_queues:
for q in self.raw_data_queues:
q.cancel_join_thread()
q.close()
for q in self.trans_data_queues:
q.cancel_join_thread()
q.close()
......
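With this refactor the parallel stream iterator becomes a three-stage pipeline: one receive worker pulls raw data from the dataset and deals samples round-robin onto the per-worker ``raw_data_queues``; ``num_workers`` transform workers apply the transform and push results to ``trans_data_queues``; one collect worker reads those queues back in the same round-robin order (``cnt % num_workers``), which preserves sample order, and collates full batches into the shared-memory ``batch_queue``. A minimal sketch of the order-preserving round-robin pattern, with plain lists standing in for the multiprocessing queues (toy data, not the MegEngine API):

```python
num_workers = 2
raw_queues = [[] for _ in range(num_workers)]

# Dispatch: the receive stage deals items to queues in round-robin order.
for cnt, item in enumerate(range(8)):
    raw_queues[cnt % num_workers].append(item)

# Collect: reading the queues in the same order restores the sequence.
collected = [raw_queues[cnt % num_workers].pop(0) for cnt in range(8)]
assert collected == list(range(8))
```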
......@@ -161,10 +161,13 @@ class StreamSampler(Sampler):
.. warning::
In the case of multiple workers, sampler should ensure that each worker gets
In the case of multiple machines, the sampler should ensure that each worker
gets different data. But this class cannot do it yet; please build your own
dataset and sampler to achieve this goal.
Usually, :meth:`~.StreamDataset.__iter__` can return a different iterator
per worker based on ``rank = dist.get_rank()``, so that each worker gets
different data.
"""
def __init__(self, batch_size=1):
......@@ -174,7 +177,7 @@ class StreamSampler(Sampler):
return self
def __next__(self):
return range(self.batch_size)
return iter(range(self.batch_size))
class SequentialSampler(MapSampler):
......
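As the updated docstring says, multi-machine sharding is left to the user: a StreamDataset can consult the distributed rank inside ``__iter__`` so that each process walks a different slice of the stream. A hedged sketch of that pattern, assuming a distributed group is already initialized and ``source`` is a placeholder iterable of samples:

```python
import megengine.distributed as dist
from megengine.data.dataset import StreamDataset

class ShardedStream(StreamDataset):
    def __init__(self, source):
        self.source = source

    def __iter__(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        for i, sample in enumerate(self.source):
            # Keep every world_size-th sample, offset by this rank,
            # so no two processes see the same data.
            if i % world_size == rank:
                yield (False, sample)
```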
......@@ -15,9 +15,15 @@ import pytest
from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import RandomSampler, SequentialSampler
from megengine.data.transform import PseudoTransform, Transform
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
Compose,
Normalize,
PseudoTransform,
ToMode,
Transform,
)
def init_dataset():
......@@ -54,6 +60,80 @@ def test_dataloader_init():
assert len(dataloader) == 16
class MyStream(StreamDataset):
def __init__(self, number, batch=False, error=False):
self.number = number
self.batch = batch
self.error = error
def __iter__(self):
for cnt in range(self.number):
if self.batch:
data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
if self.error:
yield (data, cnt)
else:
yield (False, (data, cnt))
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
num_workers=num_workers,
)
check_set = set()
for step, data in enumerate(dataloader):
if step == 10:
break
assert data[0].shape == (4, 3, 32, 32)
assert data[1].shape == (4,)
for i in data[1]:
assert i not in check_set
check_set.add(i)
def test_stream_dataloader_error():
dataset = MyStream(100, error=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False)
sampler = StreamSampler(batch_size=4)
class TimeoutTransform(Transform):
def __init__(self):
pass
def apply(self, input):
time.sleep(10)
return input
dataloader = DataLoader(
dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5
)
with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader)
next(data_iter)
def test_dataloader_serial():
dataset = init_dataset()
dataloader = DataLoader(
......