Commit f04e0d77 authored by Megvii Engine Team

feat(mge/data): dpflow dataset, stream sampler and loader

GitOrigin-RevId: cbb4510a13625e7c2203cd1358a96208849029ca
Parent be511a56
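This commit splits the DataLoader pipeline along two dataset flavors: map-style datasets driven by index-producing samplers, and stream-style datasets that are simply iterated and re-chunked into batches. A rough end-to-end sketch under the interfaces introduced below; the MyStream class, its two-field chunk layout, and the assumption that StreamDataset supplies a default post_process hook are illustrative, not part of the commit:

# Sketch only: module paths follow the relative imports in this diff;
# MyStream and its field layout are made-up examples.
import numpy as np
from megengine.data import DataLoader, StreamSampler
from megengine.data.dataset import StreamDataset

class MyStream(StreamDataset):
    def __iter__(self):
        while True:
            # each item is a tuple of equal-length per-field chunks
            images = np.random.rand(4, 3, 32, 32).astype("float32")
            labels = np.random.randint(0, 10, size=(4,)).astype("int32")
            yield (images, labels)

# the stream iterators below slice items into samples and regroup them
loader = DataLoader(MyStream(), sampler=StreamSampler(batch_size=16))
images, labels = next(iter(loader))  # one re-batched minibatch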
@@ -14,4 +14,5 @@ from .sampler import (
     ReplacementSampler,
     Sampler,
     SequentialSampler,
+    StreamSampler,
 )
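With the re-export above (assuming this `__init__` is megengine/data/__init__.py; the capture does not show the file header), the new sampler is importable alongside the existing ones:

from megengine.data import SequentialSampler, StreamSampler  # path assumed as noted above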
@@ -19,8 +19,8 @@ import numpy as np
 from ..logger import get_logger
 from ..random.rng import _random_seed_generator
 from .collator import Collator
-from .dataset import Dataset
-from .sampler import Sampler, SequentialSampler
+from .dataset import Dataset, MapDataset, StreamDataset
+from .sampler import Sampler, SequentialSampler, StreamSampler
 from .transform import PseudoTransform, Transform
 
 logger = get_logger(__name__)
@@ -82,13 +82,21 @@ class DataLoader:
             raise ValueError("divide should not be set to True when num_workers <= 1")
 
         self.dataset = dataset
         self.num_workers = num_workers
         self.timeout = timeout
         self.divide = divide
 
         if sampler is None:
-            self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
+            if isinstance(dataset, MapDataset):
+                self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
+            elif isinstance(dataset, StreamDataset):
+                self.sampler = StreamSampler(batch_size=1)
+            else:
+                raise TypeError(
+                    "can not recognize this kind of dataset: %s" % type(dataset)
+                )
         else:
             self.sampler = sampler
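So omitting `sampler` now picks a per-dataset default instead of unconditionally building a SequentialSampler. Reusing the illustrative MyStream from the sketch above:

DataLoader(MyStream())  # defaults to StreamSampler(batch_size=1)
# map-style dataset     -> SequentialSampler(dataset, batch_size=1, drop_last=False)
# anything else         -> TypeError("can not recognize this kind of dataset: ...")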
@@ -120,16 +128,26 @@ class DataLoader:
                 "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
             )
             self.num_workers = 0
-        if self.num_workers == 0:
-            return _SerialDataLoaderIter(self)
+        if isinstance(self.dataset, StreamDataset):
+            if not self.num_workers:
+                return _SerialStreamDataLoaderIter(self)
+            else:
+                return _ParallelStreamDataLoaderIter(self)
+        elif isinstance(self.dataset, MapDataset):
+            if not self.num_workers:
+                return _SerialMapDataLoaderIter(self)
+            else:
+                return _ParallelMapDataLoaderIter(self)
         else:
-            return _ParallelDataLoaderIter(self)
+            raise TypeError(
+                "can not recognize this kind of dataset: %s" % type(self.dataset)
+            )
 
     def __len__(self):
         return len(self.sampler)
 
 
-class _BaseDataLoaderIter:
+class _BaseMapDataLoaderIter:
     def __init__(self, loader):
         self.dataset = loader.dataset
         self.sampler = loader.sampler
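One side effect worth noting: `__len__` still delegates to the sampler, and StreamSampler (added below in sampler.py) defines no `__len__`, so asking a stream loader for its length fails. A sketch, again with the illustrative MyStream:

loader = DataLoader(MyStream())  # stream default: StreamSampler(batch_size=1)
len(loader)  # raises TypeError: object of type 'StreamSampler' has no len()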
@@ -158,9 +176,9 @@ class _BaseDataLoaderIter:
         return minibatch
 
 
-class _SerialDataLoaderIter(_BaseDataLoaderIter):
+class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
     def __init__(self, loader):
-        super(_SerialDataLoaderIter, self).__init__(loader)
+        super(_SerialMapDataLoaderIter, self).__init__(loader)
         self.indices_iter = iter(self.sampler)
 
     def _get_next_batch(self):
@@ -170,11 +188,11 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter):
         return self.collator.apply(trans_items)
 
 
-class _ParallelDataLoaderIter(_BaseDataLoaderIter):
+class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
     __initialized = False
 
     def __init__(self, loader):
-        super(_ParallelDataLoaderIter, self).__init__(loader)
+        super(_ParallelMapDataLoaderIter, self).__init__(loader)
 
         self.task_queues = [
             multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
@@ -326,6 +344,175 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
             self._shutdown()
 
 
+class _BaseStreamDataLoaderIter:
+    def __init__(self, loader):
+        self.dataset = loader.dataset
+        self.sampler = loader.sampler
+        self.transform = loader.transform
+        self.collator = loader.collator
+        self.num_workers = loader.num_workers
+        self.timeout = loader.timeout
+        self.post_process = self.dataset.post_process
+
+    def _get_next_batch(self):
+        raise NotImplementedError
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.post_process(self._get_next_batch())
+
+
+class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
+    def __init__(self, loader):
+        super().__init__(loader)
+        self.dataset_iter = iter(self.dataset)
+
+    def _get_next_batch(self):
+        ret = []
+        start_time = time.time()
+        while len(ret) != self.sampler.batch_size:
+            waited_time = time.time() - start_time
+            if self.timeout > 0 and waited_time > self.timeout:
+                raise RuntimeError("get_next_batch timeout!")
+            item = next(self.dataset_iter)
+            for idx in range(len(item[0])):
+                trans_item = self.transform.apply(tuple(e[idx] for e in item))
+                ret.append(trans_item)
+                if len(ret) == self.sampler.batch_size:
+                    break
+        return self.collator.apply(ret)
+
+
+class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
+    __initialized = False
+
+    def __init__(self, loader):
+        super().__init__(loader)
+
+        self.shutdown_flag = multiprocessing.Value("i", 0)
+
+        # shared-memory queue implemented by pyarrow plasma store
+        from ._queue import PlasmaShmQueue
+
+        self.batch_queue = PlasmaShmQueue(maxsize=2)
+
+        self.workers = []
+        self.worker_queues = [
+            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
+        ]
+        for worker_id in range(self.num_workers):
+            worker = multiprocessing.Process(
+                target=self._gen_data, args=(worker_id,), daemon=True
+            )
+            worker.start()
+            self.workers.append(worker)
+
+        self.collator_worker = multiprocessing.Process(
+            target=self._gen_batch, daemon=True
+        )
+        self.collator_worker.start()
+
+        self.__initialized = True
+
+    def _gen_data(self, worker_id):
+        dataset_iter = iter(self.dataset)
+        while True:
+            if self.shutdown_flag.value == 1:
+                break
+            item = next(dataset_iter)
+            for idx in range(len(item[0])):
+                trans_item = self.transform.apply(tuple(e[idx] for e in item))
+                while True:
+                    try:
+                        self.worker_queues[worker_id].put(trans_item)
+                        break
+                    except queue.Full:
+                        if self.shutdown_flag.value == 1:
+                            break
+                        logger.debug("batch part queue is full")
+
+    def _gen_batch(self):
+        cnt = -1
+        trans_items = []
+        while True:
+            if self.shutdown_flag.value == 1:
+                break
+            cnt += 1
+            queue_id = cnt % self.num_workers
+            try:
+                trans_item = self.worker_queues[queue_id].get(
+                    timeout=MP_QUEUE_GET_TIMEOUT
+                )
+            except queue.Empty:
+                continue
+            trans_items.append(trans_item)
+            if len(trans_items) == self.sampler.batch_size:
+                batch_data = self.collator.apply(trans_items)
+                while True:
+                    try:
+                        self.batch_queue.put(batch_data, timeout=1)
+                        break
+                    except queue.Full:
+                        if self.shutdown_flag.value == 1:
+                            break
+                        logger.debug("batch queue is full")
+                trans_items = []
+
+    def _check_workers(self):
+        if not self.collator_worker.is_alive():
+            exitcode = self.collator_worker.exitcode
+            if exitcode != 0:
+                raise RuntimeError("collator worker died. {}".format(exitcode))
+
+        for worker_id, worker in enumerate(self.workers):
+            if not worker.is_alive():
+                exitcode = worker.exitcode
+                if exitcode != 0:
+                    raise RuntimeError(
+                        "worker: {} died. {}".format(worker_id, exitcode)
+                    )
+
+    def _try_get_next_batch(self):
+        start_time = time.time()
+        while True:
+            self._check_workers()
+            try:
+                return self.batch_queue.get(timeout=1)
+            except queue.Empty:
+                logger.debug("batch queue empty!")
+            waited_time = time.time() - start_time
+            if self.timeout > 0 and waited_time > self.timeout:
+                raise RuntimeError("get_next_batch timeout!")
+
+    def _get_next_batch(self):
+        batch_data = self._try_get_next_batch()
+        return batch_data
+
+    def _shutdown(self):
+        with self.shutdown_flag.get_lock():
+            self.shutdown_flag.value = 1
+
+        if self.collator_worker.is_alive():
+            self.collator_worker.terminate()
+        self.collator_worker.join()
+
+        for worker in self.workers:
+            if worker.is_alive():
+                worker.terminate()
+            worker.join()
+
+        for q in self.worker_queues:
+            q.cancel_join_thread()
+            q.close()
+        self.batch_queue.cancel_join_thread()
+        self.batch_queue.close()
+
+    def __del__(self):
+        if self.__initialized:
+            self._shutdown()
+
 
 def _task_feeding_loop(
     indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
 ):
......
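Both stream iterators rely on one item convention: each element pulled from the dataset iterator is a tuple of equal-length per-field chunks, len(item[0]) is the chunk length, and tuple(e[idx] for e in item) re-slices a chunk into per-sample tuples before the transform runs. In the parallel path, _gen_batch then gathers transformed samples round-robin across the worker queues (queue_id = cnt % num_workers) and ships collated batches through the plasma-backed queue. A plain in-process sketch of that flow, with made-up data and no multiprocessing:

import numpy as np

# one stream item: two fields, chunk length 4
item = (np.zeros((4, 3, 8, 8), dtype="float32"), np.arange(4))

# per-sample re-slicing, as in _gen_data and the serial iterator
samples = [tuple(field[idx] for field in item) for idx in range(len(item[0]))]
assert len(samples) == 4 and len(samples[0]) == 2  # 4 samples x 2 fields

# round-robin gather across two fake "workers", as in _gen_batch
worker_outputs = [iter(samples[0::2]), iter(samples[1::2])]
gathered, cnt = [], -1
while len(gathered) < len(samples):
    cnt += 1
    gathered.append(next(worker_outputs[cnt % 2]))  # queue_id = cnt % num_workers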
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections.abc
 import math
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Any, Generator, Iterator, List, Union
 
 import numpy as np
@@ -17,6 +17,16 @@ import megengine.distributed as dist
 
 class Sampler(ABC):
     r"""
     An abstract class for all samplers.
     """
 
+    @abstractmethod
+    def __init__(self):
+        pass
+
+
+class MapSampler(Sampler):
     def __init__(
         self,
         dataset,
@@ -145,7 +155,29 @@ class Sampler(ABC):
         return iter(batch_index)
 
 
-class SequentialSampler(Sampler):
+class StreamSampler(Sampler):
+    """
+    Sampler for stream datasets.
+
+    .. warning::
+
+        In the case of multiple workers, the sampler should ensure that each
+        worker gets different data. This class cannot do that yet; please build
+        your own dataset and sampler to achieve that goal.
+    """
+
+    def __init__(self, batch_size=1):
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return range(self.batch_size)
+
+
+class SequentialSampler(MapSampler):
     def __init__(
         self,
         dataset,
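StreamSampler carries no real indices: every __next__ returns the same range(batch_size) placeholder, and the stream loader only ever reads its batch_size attribute. For example:

sampler = StreamSampler(batch_size=4)
assert next(iter(sampler)) == range(4)  # same placeholder every step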
@@ -176,7 +208,7 @@ class SequentialSampler(Sampler):
         return self.indices
 
 
-class RandomSampler(Sampler):
+class RandomSampler(MapSampler):
     def __init__(
         self,
         dataset,
@@ -205,7 +237,7 @@ class RandomSampler(Sampler):
         return self.rng.permutation(self.indices).tolist()
 
 
-class ReplacementSampler(Sampler):
+class ReplacementSampler(MapSampler):
     def __init__(
         self,
         dataset,
@@ -249,7 +281,7 @@ class ReplacementSampler(Sampler):
         return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
 
 
-class Infinite(Sampler):
+class Infinite(MapSampler):
     r"""Infinite Sampler wrapper for basic sampler."""
 
     def sample(self):
......