提交 e082e277 编写于 作者: M Megvii Engine Team

refactor(mge/data): Refactor megengine.data.dataset

GitOrigin-RevId: 1d9c61ce70059de9e7e0f804a67f48e54d3891a6
上级 f04e0d77
...@@ -10,6 +10,7 @@ from .collator import Collator ...@@ -10,6 +10,7 @@ from .collator import Collator
from .dataloader import DataLoader from .dataloader import DataLoader
from .sampler import ( from .sampler import (
Infinite, Infinite,
MapSampler,
RandomSampler, RandomSampler,
ReplacementSampler, ReplacementSampler,
Sampler, Sampler,
......
...@@ -20,7 +20,7 @@ from ..logger import get_logger ...@@ -20,7 +20,7 @@ from ..logger import get_logger
from ..random.rng import _random_seed_generator from ..random.rng import _random_seed_generator
from .collator import Collator from .collator import Collator
from .dataset import Dataset, MapDataset, StreamDataset from .dataset import Dataset, MapDataset, StreamDataset
from .sampler import Sampler, SequentialSampler, StreamSampler from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform from .transform import PseudoTransform, Transform
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -88,17 +88,24 @@ class DataLoader: ...@@ -88,17 +88,24 @@ class DataLoader:
self.divide = divide self.divide = divide
if sampler is None: if isinstance(dataset, MapDataset):
if isinstance(dataset, MapDataset): self.sampler = (
self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) sampler
elif isinstance(dataset, StreamDataset): if sampler
self.sampler = StreamSampler(batch_size=1) else SequentialSampler(dataset, batch_size=1, drop_last=False)
else: )
raise TypeError( assert isinstance(
"can not recognize this kind of dataset: %s" % type(dataset) self.sampler, MapSampler
) ), "types of dataset and sampler do not match"
elif isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
assert isinstance(
self.sampler, StreamSampler
), "types of dataset and sampler do not match"
else: else:
self.sampler = sampler raise TypeError(
"can not recognize this kind of dataset: %s" % type(dataset)
)
if divide: if divide:
if self.sampler.batch_size <= self.num_workers: if self.sampler.batch_size <= self.num_workers:
...@@ -352,7 +359,6 @@ class _BaseStreamDataLoaderIter: ...@@ -352,7 +359,6 @@ class _BaseStreamDataLoaderIter:
self.collator = loader.collator self.collator = loader.collator
self.num_workers = loader.num_workers self.num_workers = loader.num_workers
self.timeout = loader.timeout self.timeout = loader.timeout
self.post_process = self.dataset.post_process
def _get_next_batch(self): def _get_next_batch(self):
raise NotImplementedError raise NotImplementedError
...@@ -361,13 +367,15 @@ class _BaseStreamDataLoaderIter: ...@@ -361,13 +367,15 @@ class _BaseStreamDataLoaderIter:
return self return self
def __next__(self): def __next__(self):
return self.post_process(self._get_next_batch()) return self._get_next_batch()
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
def __init__(self, loader): def __init__(self, loader):
super().__init__(loader) super().__init__(loader)
self.dataset_iter = iter(self.dataset) self.dataset_iter = iter(self.dataset)
self.idx = 0
self.data = None
def _get_next_batch(self): def _get_next_batch(self):
ret = [] ret = []
...@@ -376,11 +384,30 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): ...@@ -376,11 +384,30 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
waited_time = time.time() - start_time waited_time = time.time() - start_time
if self.timeout > 0 and waited_time > self.timeout: if self.timeout > 0 and waited_time > self.timeout:
raise RuntimeError("get_next_batch timeout!") raise RuntimeError("get_next_batch timeout!")
item = next(self.dataset_iter) if self.idx != 0:
for idx in range(len(item[0])): data = self.data
trans_item = self.transform.apply(tuple(e[idx] for e in item)) else:
ret.append(trans_item) try:
raw_data = next(self.dataset_iter)
except:
continue
assert len(raw_data) == 2 and isinstance(
raw_data[0], bool
), "raw_data must be a tuple"
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(self.idx, len(data[0])):
trans_data = self.transform.apply(tuple(e[idx] for e in data))
ret.append(trans_data)
if len(ret) == self.sampler.batch_size: if len(ret) == self.sampler.batch_size:
if idx + 1 == len(data[0]):
self.idx = 0
self.data = None
else:
self.idx = idx
self.data = data
break break
return self.collator.apply(ret) return self.collator.apply(ret)
...@@ -393,45 +420,80 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): ...@@ -393,45 +420,80 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
self.shutdown_flag = multiprocessing.Value("i", 0) self.shutdown_flag = multiprocessing.Value("i", 0)
self.raw_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]
self.trans_data_queues = [
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
]
# shared-memory queue implemented by pyarrow plasma store # shared-memory queue implemented by pyarrow plasma store
from ._queue import PlasmaShmQueue from ._queue import PlasmaShmQueue
self.batch_queue = PlasmaShmQueue(maxsize=2) self.batch_queue = PlasmaShmQueue(maxsize=2)
self.workers = []
self.worker_queues = [ self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True)
multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) self.recieve_worker.start()
]
self.transform_workers = []
for worker_id in range(self.num_workers): for worker_id in range(self.num_workers):
worker = multiprocessing.Process( worker = multiprocessing.Process(
target=self._gen_data, args=(worker_id,), daemon=True target=self._transform, args=(worker_id,), daemon=True
) )
worker.start() worker.start()
self.workers.append(worker) self.transform_workers.append(worker)
self.collator_worker = multiprocessing.Process(
target=self._gen_batch, daemon=True self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True)
) self.collect_worker.start()
self.collator_worker.start()
self.__initialized = True self.__initialized = True
def _gen_data(self, worker_id): def _recieve(self):
dataset_iter = iter(self.dataset) dataset_iter = iter(self.dataset)
cnt = -1
while True: while True:
if self.shutdown_flag.value == 1: if self.shutdown_flag.value == 1:
break break
item = next(dataset_iter) raw_data = next(dataset_iter)
for idx in range(len(item[0])): assert len(raw_data) == 2 and isinstance(
trans_item = self.transform.apply(tuple(e[idx] for e in item)) raw_data[0], bool
), "raw_data must be a tuple"
if not raw_data[0]:
data = list((x,) for x in raw_data[1])
else:
data = raw_data[1]
for idx in range(len(data[0])):
while True: while True:
cnt += 1
qid = cnt % self.num_workers
try: try:
self.worker_queues[worker_id].put(trans_item) self.raw_data_queues[qid].put(tuple(e[idx] for e in data))
break break
except queue.Full: except queue.Full:
if self.shutdown_flag.value == 1: if self.shutdown_flag.value == 1:
break break
logger.debug("batch part queue is full") logger.debug("raw data queue is full")
def _gen_batch(self): def _transform(self, worker_id):
while True:
if self.shutdown_flag.value == 1:
break
try:
data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT)
except queue.Empty:
continue
trans_data = self.transform.apply(data)
while True:
try:
self.trans_data_queues[worker_id].put(trans_data)
break
except queue.Full:
if self.shutdown_flag.value == 1:
break
logger.debug("batch queue if full")
def _collect(self):
cnt = -1 cnt = -1
trans_items = [] trans_items = []
while True: while True:
...@@ -440,7 +502,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): ...@@ -440,7 +502,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
cnt += 1 cnt += 1
queue_id = cnt % self.num_workers queue_id = cnt % self.num_workers
try: try:
trans_item = self.worker_queues[queue_id].get( trans_item = self.trans_data_queues[queue_id].get(
timeout=MP_QUEUE_GET_TIMEOUT timeout=MP_QUEUE_GET_TIMEOUT
) )
except queue.Empty: except queue.Empty:
...@@ -459,12 +521,12 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): ...@@ -459,12 +521,12 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
trans_items = [] trans_items = []
def _check_workers(self): def _check_workers(self):
if not self.collator_worker.is_alive(): if not self.collect_worker.is_alive():
exitcode = self.collator_worker.exitcode exitcode = self.collect_worker.exitcode
if exitcode != 0: if exitcode != 0:
raise RuntimeError("collator worker died. {}".format(exitcode)) raise RuntimeError("collator worker died. {}".format(exitcode))
for worker_id, worker in enumerate(self.workers): for worker_id, worker in enumerate(self.transform_workers):
if not worker.is_alive(): if not worker.is_alive():
exitcode = worker.exitcode exitcode = worker.exitcode
if exitcode != 0: if exitcode != 0:
...@@ -492,16 +554,24 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): ...@@ -492,16 +554,24 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
with self.shutdown_flag.get_lock(): with self.shutdown_flag.get_lock():
self.shutdown_flag.value = 1 self.shutdown_flag.value = 1
if self.collator_worker.is_alive(): if self.recieve_worker.is_alive():
self.collator_worker.terminate() self.recieve_worker.terminate()
self.collator_worker.join() self.recieve_worker.join()
for worker in self.workers: if self.collect_worker.is_alive():
self.collect_worker.terminate()
self.collect_worker.join()
for worker in self.transform_workers:
if worker.is_alive(): if worker.is_alive():
worker.terminate() worker.terminate()
worker.join() worker.join()
for q in self.worker_queues: for q in self.raw_data_queues:
q.cancel_join_thread()
q.close()
for q in self.trans_data_queues:
q.cancel_join_thread() q.cancel_join_thread()
q.close() q.close()
......
...@@ -161,10 +161,13 @@ class StreamSampler(Sampler): ...@@ -161,10 +161,13 @@ class StreamSampler(Sampler):
.. warning:: .. warning::
In the case of multiple workers, sampler should ensure that each worker gets In the case of multiple machines, sampler should ensure that each worker gets
different data. But this class cannot do it yet, please build your own different data. But this class cannot do it yet, please build your own
dataset and sampler to achieve this goal. dataset and sampler to achieve this goal.
Usually, :meth:`~.StreamDataset.__iter__` can return a different iterator depending on
``rank = dist.get_rank()``, so that each worker gets different data.
""" """
def __init__(self, batch_size=1): def __init__(self, batch_size=1):
...@@ -174,7 +177,7 @@ class StreamSampler(Sampler): ...@@ -174,7 +177,7 @@ class StreamSampler(Sampler):
return self return self
def __next__(self): def __next__(self):
return range(self.batch_size) return iter(range(self.batch_size))
class SequentialSampler(MapSampler): class SequentialSampler(MapSampler):
......
...@@ -15,9 +15,15 @@ import pytest ...@@ -15,9 +15,15 @@ import pytest
from megengine.data.collator import Collator from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import PseudoTransform, Transform from megengine.data.transform import (
Compose,
Normalize,
PseudoTransform,
ToMode,
Transform,
)
def init_dataset(): def init_dataset():
...@@ -54,6 +60,80 @@ def test_dataloader_init(): ...@@ -54,6 +60,80 @@ def test_dataloader_init():
assert len(dataloader) == 16 assert len(dataloader) == 16
class MyStream(StreamDataset):
    """Synthetic stream dataset used by the stream ``DataLoader`` tests.

    Args:
        number: how many raw items ``__iter__`` yields before finishing.
        batch: if true, every raw item is ``(True, (images, labels))`` holding
            two samples; otherwise every raw item holds a single sample.
        error: only consulted when ``batch`` is false. If true, the raw item
            is a bare ``(data, cnt)`` pair without the leading bool flag.
    """

    def __init__(self, number, batch=False, error=False):
        self.number = number
        self.batch = batch
        self.error = error

    def __iter__(self):
        """Yield ``number`` raw items of random ``uint8`` image data."""
        for cnt in range(self.number):
            if self.batch:
                # Two 32x32x3 images per raw item, with two distinct labels.
                data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8")
                yield (True, (data, [cnt, cnt - self.number]))
            else:
                data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
                if self.error:
                    # Deliberately omits the bool flag.
                    yield (data, cnt)
                else:
                    yield (False, (data, cnt))
        # The generator just ends here. An explicit ``raise StopIteration``
        # inside a generator body becomes ``RuntimeError`` (PEP 479).
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
    """Check batch shapes and label uniqueness over the first ten batches."""
    preprocess = Compose(
        [Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]
    )
    loader = DataLoader(
        MyStream(100, batch),
        StreamSampler(batch_size=4),
        preprocess,
        num_workers=num_workers,
    )

    seen_labels = set()
    for step, batch_data in enumerate(loader):
        if step == 10:
            break
        images, labels = batch_data[0], batch_data[1]
        assert images.shape == (4, 3, 32, 32)
        assert labels.shape == (4,)
        # No label may be delivered twice across batches.
        for label in labels:
            assert label not in seen_labels
            seen_labels.add(label)
def test_stream_dataloader_error():
    """Expect an ``AssertionError`` mentioning "tuple" for a raw item built with ``error=True``."""
    loader = DataLoader(MyStream(100, error=True), StreamSampler(batch_size=4))
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        next(iter(loader))
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
    """Expect a timeout ``RuntimeError`` when the transform sleeps past ``timeout``."""

    class TimeoutTransform(Transform):
        """Transform whose ``apply`` sleeps for 10 seconds, then returns its input."""

        def __init__(self):
            pass

        def apply(self, input):
            time.sleep(10)
            return input

    loader = DataLoader(
        MyStream(100, False),
        StreamSampler(batch_size=4),
        TimeoutTransform(),
        num_workers=num_workers,
        timeout=5,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        next(iter(loader))
def test_dataloader_serial(): def test_dataloader_serial():
dataset = init_dataset() dataset = init_dataset()
dataloader = DataLoader( dataloader = DataLoader(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册