Commit f04e0d77 authored by Megvii Engine Team

feat(mge/data): dpflow dataset, stream sampler and loader

GitOrigin-RevId: cbb4510a13625e7c2203cd1358a96208849029ca
Parent be511a56
@@ -14,4 +14,5 @@ from .sampler import (
    ReplacementSampler,
    Sampler,
    SequentialSampler,
    StreamSampler,
)
@@ -19,8 +19,8 @@ import numpy as np
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
from .dataset import Dataset, MapDataset, StreamDataset
from .sampler import Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

logger = get_logger(__name__)
@@ -82,13 +82,21 @@ class DataLoader:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout
        self.divide = divide

        if sampler is None:
            if isinstance(dataset, MapDataset):
                self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
            elif isinstance(dataset, StreamDataset):
                self.sampler = StreamSampler(batch_size=1)
            else:
                raise TypeError(
                    "can not recognize this kind of dataset: %s" % type(dataset)
                )
        else:
            self.sampler = sampler
@@ -120,16 +128,26 @@ class DataLoader:
                "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
            )
            self.num_workers = 0

        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self)
            else:
                return _ParallelStreamDataLoaderIter(self)
        elif isinstance(self.dataset, MapDataset):
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self)
            else:
                return _ParallelMapDataLoaderIter(self)
        else:
            raise TypeError(
                "can not recognize this kind of dataset: %s" % type(self.dataset)
            )

    def __len__(self):
        return len(self.sampler)
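With this dispatch, the iterator flavor is decided entirely by the dataset type: map-style datasets keep the old task-queue pipeline, stream datasets get the new stream iterators. A minimal usage sketch (`EndlessStream` and its shapes are illustrative assumptions, not part of this commit; it presumes the `StreamDataset` contract used below, i.e. an `__iter__` yielding tuples of parallel fields plus a `post_process` hook):

```python
import numpy as np

from megengine.data import DataLoader, StreamSampler
from megengine.data.dataset import StreamDataset


class EndlessStream(StreamDataset):
    """Hypothetical stream: each yielded item is a tuple of parallel
    fields, every field holding the same number of samples (4 here)."""

    def __iter__(self):
        while True:
            images = np.random.rand(4, 3, 8, 8).astype("float32")
            labels = np.random.randint(0, 10, size=(4,)).astype("int32")
            yield (images, labels)

    def post_process(self, batch):
        return batch  # identity hook; called once per collated batch


loader = DataLoader(EndlessStream(), sampler=StreamSampler(batch_size=16))
images, labels = next(iter(loader))  # images has shape (16, 3, 8, 8)
```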
class _BaseMapDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
@@ -158,9 +176,9 @@ class _BaseDataLoaderIter:
        return minibatch


class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader):
        super(_SerialMapDataLoaderIter, self).__init__(loader)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
@@ -170,11 +188,11 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter):
        return self.collator.apply(trans_items)


class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
        super(_ParallelMapDataLoaderIter, self).__init__(loader)
        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
@@ -326,6 +344,175 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
        self._shutdown()


class _BaseStreamDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
        self.transform = loader.transform
        self.collator = loader.collator
        self.num_workers = loader.num_workers
        self.timeout = loader.timeout
        self.post_process = self.dataset.post_process

    def _get_next_batch(self):
        raise NotImplementedError

    def __iter__(self):
        return self

    def __next__(self):
        return self.post_process(self._get_next_batch())
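Every batch thus passes through the dataset's `post_process` hook after collation. A sketch of such a hook (hypothetical; the commit only requires a callable that takes and returns a collated batch):

```python
def post_process(batch):
    # hypothetical hook: rescale collated images, pass labels through
    images, labels = batch
    return images / 255.0, labels
```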

class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
    def __init__(self, loader):
        super().__init__(loader)
        self.dataset_iter = iter(self.dataset)

    def _get_next_batch(self):
        ret = []
        start_time = time.time()
        while len(ret) != self.sampler.batch_size:
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                raise RuntimeError("get_next_batch timeout!")
            item = next(self.dataset_iter)
            for idx in range(len(item[0])):
                trans_item = self.transform.apply(tuple(e[idx] for e in item))
                ret.append(trans_item)
                if len(ret) == self.sampler.batch_size:
                    break
        return self.collator.apply(ret)
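The inner loop treats each item pulled from the stream as a tuple of parallel fields and slices one sample per index out of every field; samples still left in an item when the batch fills up are dropped rather than carried over. The slicing in isolation (plain Python):

```python
# An item is a tuple of parallel fields; index idx selects one sample
# across all fields, exactly as tuple(e[idx] for e in item) does above.
item = (["img0", "img1", "img2"], [0, 1, 2])
samples = [tuple(field[idx] for field in item) for idx in range(len(item[0]))]
assert samples == [("img0", 0), ("img1", 1), ("img2", 2)]
```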

class _ParallelStreamDataLoaaderIter(_BaseStreamDataLoaderIter) if False else None  # noqa

    def _gen_batch(self):
        cnt = -1
        trans_items = []
        while True:
            if self.shutdown_flag.value == 1:
                break
            cnt += 1
            queue_id = cnt % self.num_workers
            try:
                trans_item = self.worker_queues[queue_id].get(
                    timeout=MP_QUEUE_GET_TIMEOUT
                )
            except queue.Empty:
                continue
            trans_items.append(trans_item)
            if len(trans_items) == self.sampler.batch_size:
                batch_data = self.collator.apply(trans_items)
                while True:
                    try:
                        self.batch_queue.put(batch_data, timeout=1)
                        break
                    except queue.Full:
                        if self.shutdown_flag.value == 1:
                            break
                        logger.debug("batch queue is full")
                trans_items = []
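The collator process polls the per-worker queues round-robin via `cnt % self.num_workers`, and because the `queue.Empty` handler falls through to the next loop iteration (which increments `cnt` again), an empty queue is skipped rather than retried. The polling order in isolation:

```python
num_workers = 3
order = [cnt % num_workers for cnt in range(7)]
assert order == [0, 1, 2, 0, 1, 2, 0]  # queues polled in rotation
```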

    def _check_workers(self):
        if not self.collator_worker.is_alive():
            exitcode = self.collator_worker.exitcode
            if exitcode != 0:
                raise RuntimeError("collator worker died. {}".format(exitcode))

        for worker_id, worker in enumerate(self.workers):
            if not worker.is_alive():
                exitcode = worker.exitcode
                if exitcode != 0:
                    raise RuntimeError(
                        "worker: {} died. {}".format(worker_id, exitcode)
                    )

    def _try_get_next_batch(self):
        start_time = time.time()
        while True:
            self._check_workers()
            try:
                return self.batch_queue.get(timeout=1)
            except queue.Empty:
                logger.debug("batch queue empty!")
            waited_time = time.time() - start_time
            if self.timeout > 0 and waited_time > self.timeout:
                raise RuntimeError("get_next_batch timeout!")

    def _get_next_batch(self):
        return self._try_get_next_batch()

    def _shutdown(self):
        with self.shutdown_flag.get_lock():
            self.shutdown_flag.value = 1

        if self.collator_worker.is_alive():
            self.collator_worker.terminate()
        self.collator_worker.join()

        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
            worker.join()

        for q in self.worker_queues:
            q.cancel_join_thread()
            q.close()
        self.batch_queue.cancel_join_thread()
        self.batch_queue.close()

    def __del__(self):
        if self.__initialized:
            self._shutdown()
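Teardown is cooperative first, forceful second: raise the shared flag so the loops above can exit on their own, then terminate and join whatever is still alive, and finally close the queues. The idiom in isolation (standard library only):

```python
import multiprocessing
import time


def loop(flag):
    # cooperative worker: exits once the shared flag is raised
    while flag.value == 0:
        time.sleep(0.01)


if __name__ == "__main__":
    flag = multiprocessing.Value("i", 0)
    p = multiprocessing.Process(target=loop, args=(flag,), daemon=True)
    p.start()
    with flag.get_lock():
        flag.value = 1  # cooperative shutdown request
    p.join(timeout=1)
    if p.is_alive():
        p.terminate()  # forceful fallback, as _shutdown does
        p.join()
```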

def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
......
@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import math
from abc import ABC, abstractmethod
from typing import Any, Generator, Iterator, List, Union

import numpy as np
@@ -17,6 +17,16 @@ import megengine.distributed as dist
class Sampler(ABC):
    r"""
    An abstract base class for all samplers.
    """

    @abstractmethod
    def __init__(self):
        pass


class MapSampler(Sampler):
    def __init__(
        self,
        dataset,
@@ -145,7 +155,29 @@ class Sampler(ABC):
        return iter(batch_index)


class StreamSampler(Sampler):
    """
    Sampler for stream dataset.

    .. warning::

        In the case of multiple workers, the sampler should ensure that each
        worker gets distinct data. This class cannot guarantee that yet; please
        build your own dataset and sampler to achieve this goal.
    """

    def __init__(self, batch_size=1):
        self.batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        return range(self.batch_size)
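`StreamSampler` carries no dataset state; it only fixes the batch size, and every `__next__` returns the same placeholder indices, forever. A quick check:

```python
sampler = StreamSampler(batch_size=4)
it = iter(sampler)
assert list(next(it)) == [0, 1, 2, 3]
assert list(next(it)) == [0, 1, 2, 3]  # never raises StopIteration
```
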
class SequentialSampler(MapSampler):
    def __init__(
        self,
        dataset,
@@ -176,7 +208,7 @@ class SequentialSampler(Sampler):
        return self.indices


class RandomSampler(MapSampler):
    def __init__(
        self,
        dataset,
@@ -205,7 +237,7 @@ class RandomSampler(Sampler):
        return self.rng.permutation(self.indices).tolist()


class ReplacementSampler(MapSampler):
    def __init__(
        self,
        dataset,
@@ -249,7 +281,7 @@ class ReplacementSampler(Sampler):
        return self.rng.multinomial(n, self.weights, self.num_samples).tolist()


class Infinite(MapSampler):
    r"""Infinite Sampler wrapper for basic sampler."""

    def sample(self):
......