From e082e27780451a9c235debe26e72e398dd2b9bc2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 5 Nov 2020 20:39:04 +0800 Subject: [PATCH] refactor(mge/data): Refactor megeninge.data.dataset GitOrigin-RevId: 1d9c61ce70059de9e7e0f804a67f48e54d3891a6 --- imperative/python/megengine/data/__init__.py | 1 + .../python/megengine/data/dataloader.py | 156 +++++++++++++----- imperative/python/megengine/data/sampler.py | 7 +- .../python/test/unit/data/test_dataloader.py | 86 +++++++++- 4 files changed, 202 insertions(+), 48 deletions(-) diff --git a/imperative/python/megengine/data/__init__.py b/imperative/python/megengine/data/__init__.py index 11398efe0..308a89077 100644 --- a/imperative/python/megengine/data/__init__.py +++ b/imperative/python/megengine/data/__init__.py @@ -10,6 +10,7 @@ from .collator import Collator from .dataloader import DataLoader from .sampler import ( Infinite, + MapSampler, RandomSampler, ReplacementSampler, Sampler, diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index d6c55422e..31cce8da4 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -20,7 +20,7 @@ from ..logger import get_logger from ..random.rng import _random_seed_generator from .collator import Collator from .dataset import Dataset, MapDataset, StreamDataset -from .sampler import Sampler, SequentialSampler, StreamSampler +from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler from .transform import PseudoTransform, Transform logger = get_logger(__name__) @@ -88,17 +88,24 @@ class DataLoader: self.divide = divide - if sampler is None: - if isinstance(dataset, MapDataset): - self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False) - elif isinstance(dataset, StreamDataset): - self.sampler = StreamSampler(batch_size=1) - else: - raise TypeError( - "can not recognize this kind of dataset: %s" % type(dataset) - ) + if isinstance(dataset, MapDataset): + self.sampler = ( + sampler + if sampler + else SequentialSampler(dataset, batch_size=1, drop_last=False) + ) + assert isinstance( + self.sampler, MapSampler + ), "types of dataset and sampler do not match" + elif isinstance(dataset, StreamDataset): + self.sampler = sampler if sampler else StreamSampler(batch_size=1) + assert isinstance( + self.sampler, StreamSampler + ), "types of dataset and sampler do not match" else: - self.sampler = sampler + raise TypeError( + "can not recognize this kind of dataset: %s" % type(dataset) + ) if divide: if self.sampler.batch_size <= self.num_workers: @@ -352,7 +359,6 @@ class _BaseStreamDataLoaderIter: self.collator = loader.collator self.num_workers = loader.num_workers self.timeout = loader.timeout - self.post_process = self.dataset.post_process def _get_next_batch(self): raise NotImplementedError @@ -361,13 +367,15 @@ class _BaseStreamDataLoaderIter: return self def __next__(self): - return self.post_process(self._get_next_batch()) + return self._get_next_batch() class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): def __init__(self, loader): super().__init__(loader) self.dataset_iter = iter(self.dataset) + self.idx = 0 + self.data = None def _get_next_batch(self): ret = [] @@ -376,11 +384,30 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): waited_time = time.time() - start_time if self.timeout > 0 and waited_time > self.timeout: raise RuntimeError("get_next_batch timeout!") - item = next(self.dataset_iter) - for idx in range(len(item[0])): - trans_item = self.transform.apply(tuple(e[idx] for e in item)) - ret.append(trans_item) + if self.idx != 0: + data = self.data + else: + try: + raw_data = next(self.dataset_iter) + except: + continue + assert len(raw_data) == 2 and isinstance( + raw_data[0], bool + ), "raw_data must be a tuple" + if not raw_data[0]: + data = list((x,) for x in raw_data[1]) + else: + data = raw_data[1] + for idx in range(self.idx, len(data[0])): + trans_data = self.transform.apply(tuple(e[idx] for e in data)) + ret.append(trans_data) if len(ret) == self.sampler.batch_size: + if idx + 1 == len(data[0]): + self.idx = 0 + self.data = None + else: + self.idx = idx + self.data = data break return self.collator.apply(ret) @@ -393,45 +420,80 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): self.shutdown_flag = multiprocessing.Value("i", 0) + self.raw_data_queues = [ + multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) + ] + + self.trans_data_queues = [ + multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) + ] + # shared-memory queue implemented by pyarrow plasma store from ._queue import PlasmaShmQueue self.batch_queue = PlasmaShmQueue(maxsize=2) - self.workers = [] - self.worker_queues = [ - multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) - ] + + self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True) + self.recieve_worker.start() + + self.transform_workers = [] for worker_id in range(self.num_workers): worker = multiprocessing.Process( - target=self._gen_data, args=(worker_id,), daemon=True + target=self._transform, args=(worker_id,), daemon=True ) worker.start() - self.workers.append(worker) - self.collator_worker = multiprocessing.Process( - target=self._gen_batch, daemon=True - ) - self.collator_worker.start() + self.transform_workers.append(worker) + + self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True) + self.collect_worker.start() self.__initialized = True - def _gen_data(self, worker_id): + def _recieve(self): dataset_iter = iter(self.dataset) + cnt = -1 while True: if self.shutdown_flag.value == 1: break - item = next(dataset_iter) - for idx in range(len(item[0])): - trans_item = self.transform.apply(tuple(e[idx] for e in item)) + raw_data = next(dataset_iter) + assert len(raw_data) == 2 and isinstance( + raw_data[0], bool + ), "raw_data must be a tuple" + if not raw_data[0]: + data = list((x,) for x in raw_data[1]) + else: + data = raw_data[1] + for idx in range(len(data[0])): while True: + cnt += 1 + qid = cnt % self.num_workers try: - self.worker_queues[worker_id].put(trans_item) + self.raw_data_queues[qid].put(tuple(e[idx] for e in data)) break except queue.Full: if self.shutdown_flag.value == 1: break - logger.debug("batch part queue is full") + logger.debug("raw data queue is full") - def _gen_batch(self): + def _transform(self, worker_id): + while True: + if self.shutdown_flag.value == 1: + break + try: + data = self.raw_data_queues[worker_id].get(timeout=MP_QUEUE_GET_TIMEOUT) + except queue.Empty: + continue + trans_data = self.transform.apply(data) + while True: + try: + self.trans_data_queues[worker_id].put(trans_data) + break + except queue.Full: + if self.shutdown_flag.value == 1: + break + logger.debug("batch queue if full") + + def _collect(self): cnt = -1 trans_items = [] while True: @@ -440,7 +502,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): cnt += 1 queue_id = cnt % self.num_workers try: - trans_item = self.worker_queues[queue_id].get( + trans_item = self.trans_data_queues[queue_id].get( timeout=MP_QUEUE_GET_TIMEOUT ) except queue.Empty: @@ -459,12 +521,12 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): trans_items = [] def _check_workers(self): - if not self.collator_worker.is_alive(): - exitcode = self.collator_worker.exitcode + if not self.collect_worker.is_alive(): + exitcode = self.collect_worker.exitcode if exitcode != 0: raise RuntimeError("collator worker died. {}".format(exitcode)) - for worker_id, worker in enumerate(self.workers): + for worker_id, worker in enumerate(self.transform_workers): if not worker.is_alive(): exitcode = worker.exitcode if exitcode != 0: @@ -492,16 +554,24 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): with self.shutdown_flag.get_lock(): self.shutdown_flag.value = 1 - if self.collator_worker.is_alive(): - self.collator_worker.terminate() - self.collator_worker.join() + if self.recieve_worker.is_alive(): + self.recieve_worker.terminate() + self.recieve_worker.join() - for worker in self.workers: + if self.collect_worker.is_alive(): + self.collect_worker.terminate() + self.collect_worker.join() + + for worker in self.transform_workers: if worker.is_alive(): worker.terminate() worker.join() - for q in self.worker_queues: + for q in self.raw_data_queues: + q.cancel_join_thread() + q.close() + + for q in self.trans_data_queues: q.cancel_join_thread() q.close() diff --git a/imperative/python/megengine/data/sampler.py b/imperative/python/megengine/data/sampler.py index a260fa87e..3d6f4bafa 100644 --- a/imperative/python/megengine/data/sampler.py +++ b/imperative/python/megengine/data/sampler.py @@ -161,10 +161,13 @@ class StreamSampler(Sampler): .. warning:: - In the case of multiple workers, sampler should ensure that each worker gets + In the case of multiple machines, sampler should ensure that each worker gets different data. But this class cannot do it yet, please build your own dataset and sampler to achieve this goal. + Usually, meth::`~.StreamDataset.__iter__` can return different iterator by + ``rank = dist.get_rank()``. So that they will get different data. + """ def __init__(self, batch_size=1): @@ -174,7 +177,7 @@ class StreamSampler(Sampler): return self def __next__(self): - return range(self.batch_size) + return iter(range(self.batch_size)) class SequentialSampler(MapSampler): diff --git a/imperative/python/test/unit/data/test_dataloader.py b/imperative/python/test/unit/data/test_dataloader.py index 446dce07c..a8f152d37 100644 --- a/imperative/python/test/unit/data/test_dataloader.py +++ b/imperative/python/test/unit/data/test_dataloader.py @@ -15,9 +15,15 @@ import pytest from megengine.data.collator import Collator from megengine.data.dataloader import DataLoader -from megengine.data.dataset import ArrayDataset -from megengine.data.sampler import RandomSampler, SequentialSampler -from megengine.data.transform import PseudoTransform, Transform +from megengine.data.dataset import ArrayDataset, StreamDataset +from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler +from megengine.data.transform import ( + Compose, + Normalize, + PseudoTransform, + ToMode, + Transform, +) def init_dataset(): @@ -54,6 +60,80 @@ def test_dataloader_init(): assert len(dataloader) == 16 +class MyStream(StreamDataset): + def __init__(self, number, batch=False, error=False): + self.number = number + self.batch = batch + self.error = error + + def __iter__(self): + for cnt in range(self.number): + if self.batch: + data = np.random.randint(0, 256, (2, 32, 32, 3), dtype="uint8") + yield (True, (data, [cnt, cnt - self.number])) + else: + data = np.random.randint(0, 256, (32, 32, 3), dtype="uint8") + if self.error: + yield (data, cnt) + else: + yield (False, (data, cnt)) + raise StopIteration + + +@pytest.mark.parametrize("batch", [True, False]) +@pytest.mark.parametrize("num_workers", [0, 2]) +def test_stream_dataloader(batch, num_workers): + dataset = MyStream(100, batch) + sampler = StreamSampler(batch_size=4) + dataloader = DataLoader( + dataset, + sampler, + Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]), + num_workers=num_workers, + ) + + check_set = set() + + for step, data in enumerate(dataloader): + if step == 10: + break + assert data[0].shape == (4, 3, 32, 32) + assert data[1].shape == (4,) + for i in data[1]: + assert i not in check_set + check_set.add(i) + + +def test_stream_dataloader_error(): + dataset = MyStream(100, error=True) + sampler = StreamSampler(batch_size=4) + dataloader = DataLoader(dataset, sampler) + with pytest.raises(AssertionError, match=r".*tuple.*"): + data_iter = iter(dataloader) + next(data_iter) + + +@pytest.mark.parametrize("num_workers", [0, 2]) +def test_stream_dataloader_timeout(num_workers): + dataset = MyStream(100, False) + sampler = StreamSampler(batch_size=4) + + class TimeoutTransform(Transform): + def __init__(self): + pass + + def apply(self, input): + time.sleep(10) + return input + + dataloader = DataLoader( + dataset, sampler, TimeoutTransform(), num_workers=num_workers, timeout=5 + ) + with pytest.raises(RuntimeError, match=r".*timeout.*"): + data_iter = iter(dataloader) + next(data_iter) + + def test_dataloader_serial(): dataset = init_dataset() dataloader = DataLoader( -- GitLab