From dff7719e87e7fd2fe7721042995e0f6b775a7a28 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 10 Sep 2021 11:28:25 +0800
Subject: [PATCH] feat(mge/distributed): add preload host data with op

fix(mge/distributed): change api name with preload

fix(mge/distributed): fix recursive model in preload tensor

fix(mge/distributed): fix recursive when cache contain None

GitOrigin-RevId: 80e2a6dd7010f48a279b0d0be8ae6d5aad53f011
---
 .../python/megengine/data/dataloader.py       | 134 ++++++--
 imperative/python/megengine/device.py         |  13 +
 .../python/megengine/distributed/group.py     |  10 +-
 .../test/unit/data/test_pre_dataloader.py     | 308 ++++++++++++++++++
 4 files changed, 436 insertions(+), 29 deletions(-)
 create mode 100644 imperative/python/test/unit/data/test_pre_dataloader.py

diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py
index 69a835ccb..e63b36b94 100644
--- a/imperative/python/megengine/data/dataloader.py
+++ b/imperative/python/megengine/data/dataloader.py
@@ -15,12 +15,15 @@ import queue
 import random
 import threading
 import time
-from typing import Callable
+from typing import Callable, Union
 
 import numpy as np
 
+from ..device import _sh, get_default_device
+from ..functional.tensor import copy
 from ..logger import get_logger
 from ..random.rng import _random_seed_generator
+from ..tensor import Tensor
 from .collator import Collator
 from .dataset import Dataset, StreamDataset
 from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
@@ -44,7 +47,7 @@ def raise_timeout_error():
 
 class DataLoader:
     r"""Provides a convenient way to iterate on a given dataset.
-    
+
     DataLoader combines a dataset with :class:`~.Sampler`, :class:`~.Transform` and
     :class:`~.Collator`, make it flexible to get minibatch continually from a
     dataset.
@@ -66,6 +69,8 @@ class DataLoader:
             ``True`` means one batch is divided into :attr:`num_workers` pieces, and the workers
             will process these pieces parallelly. ``False`` means different sub-process will
             process different batch. Default: False
+        preload: whether to enable the preloading strategy of the dataloader, which overlaps the host-to-device copy of the next minibatch with computation on the current one to improve loading speed. Default: False
+            When preload is enabled, each output changes from np.ndarray to Tensor. The data types supported by preload are int, float, list[int, float] and tuple[int, float]; other types are not supported.
     """
 
     __initialized = False
@@ -79,6 +84,7 @@ class DataLoader:
         timeout: int = 0,
         timeout_event: Callable = raise_timeout_error,
         divide: bool = False,
+        preload: bool = False,
     ):
         if num_workers < 0:
             raise ValueError("num_workers should not be negative")
@@ -96,6 +102,7 @@ class DataLoader:
         self.timeout_event = timeout_event
         self.divide = divide
+        self.preload = preload
 
         if isinstance(dataset, StreamDataset):
             self.sampler = sampler if sampler else StreamSampler(batch_size=1)
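
[Note: the sketch below is not part of the patch. It shows how a caller would enable the new flag documented above, mirroring the dataset setup used by the new test file further down; it assumes this patch is applied on a device build of MegEngine.]

    import numpy as np

    from megengine.data.dataloader import DataLoader
    from megengine.data.dataset import ArrayDataset
    from megengine.data.sampler import RandomSampler

    # Illustrative data: 100 single-channel 32x32 images with integer labels.
    data = np.random.randint(0, 255, size=(100, 1, 32, 32), dtype=np.uint8)
    label = np.random.randint(0, 10, size=(100,), dtype=int)
    dataset = ArrayDataset(data, label)

    loader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        preload=True,  # minibatches arrive as device Tensors, not np.ndarray
    )
    for batch_data, batch_label in loader:
        # With preload=True, batch_data / batch_label are megengine Tensors
        # whose host-to-device copy was started one iteration in advance.
        pass
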
""" __initialized = False @@ -79,6 +84,7 @@ class DataLoader: timeout: int = 0, timeout_event: Callable = raise_timeout_error, divide: bool = False, + preload: bool = False, ): if num_workers < 0: raise ValueError("num_workers should not be negative") @@ -96,6 +102,7 @@ class DataLoader: self.timeout_event = timeout_event self.divide = divide + self.preload = preload if isinstance(dataset, StreamDataset): self.sampler = sampler if sampler else StreamSampler(batch_size=1) @@ -145,24 +152,74 @@ class DataLoader: self.num_workers = 0 if isinstance(self.dataset, StreamDataset): if not self.num_workers: - return _SerialStreamDataLoaderIter(self) + return _SerialStreamDataLoaderIter(self, self.preload) else: - return _ParallelStreamDataLoaderIter(self) + return _ParallelStreamDataLoaderIter(self, self.preload) else: assert isinstance( self.dataset, Dataset ), "Can not recognize this kind of dataset: %s" % type(self.dataset) if not self.num_workers: - return _SerialMapDataLoaderIter(self) + return _SerialMapDataLoaderIter(self, self.preload) else: - return _ParallelMapDataLoaderIter(self) + return _ParallelMapDataLoaderIter(self, self.preload) def __len__(self): return len(self.sampler) -class _BaseMapDataLoaderIter: - def __init__(self, loader): +class PreLoader: + def __init__(self, preload): + if preload: + self.default_device = get_default_device() + self.pre_load_device = self.default_device + ":" + str(_sh.get_next()) + self.pre_load_device_cache = None + self.preload = preload + + """ + strategy one: load from numpy data, and generate dtype tensor + """ + + def _load_tensor(self, batch, cached=True): + if isinstance(batch, np.ndarray): + device = self.pre_load_device if cached else self.default_device + return Tensor(batch, device=device) + elif isinstance(batch, collections.abc.Mapping): + return {k: self._load_tensor(v, cached) for k, v in batch.items()} + elif isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple + return type(batch)(*(self._load_tensor(value, cached) for value in batch)) + elif isinstance(batch, collections.abc.Sequence): + return [self._load_tensor(value, cached) for value in batch] + else: + return batch + + """ + strategy two: load from cache that is already tensor just do d2d copy + """ + + def _load_cache(self, data): + if isinstance(data, Tensor): + if data.device == self.default_device: + return data + return copy(data, device=self.default_device) + elif isinstance(data, collections.abc.Mapping): + return {k: self._load_cache(v) for k, v in data.items()} + elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple + return type(data)(*(self._load_cache(value) for value in data)) + elif isinstance(data, collections.abc.Sequence): + return [self._load_cache(value) for value in data] + else: + return data + + def _swap_out_cache(self): + out = self._load_cache(self.pre_load_device_cache) + self.pre_load_device_cache = None # clean cache + return out + + +class _BaseMapDataLoaderIter(PreLoader): + def __init__(self, loader, preload): + super().__init__(preload) self.dataset = loader.dataset self.sampler = loader.sampler self.seed = _random_seed_generator().__next__() @@ -184,16 +241,35 @@ class _BaseMapDataLoaderIter: return self def __next__(self): + if self.preload: + cached = self.pre_load_device_cache + if cached is None: # first and last + if self.num_processed >= len(self): # last + raise StopIteration + elif self.num_processed == 0: # first + self._try_load_tensor(cached=False) # first do the h2d + out = self._swap_out_cache() + 
+            self._try_load_tensor()
+            return out
+        else:
+            if self.num_processed >= len(self):
+                raise StopIteration
+            minibatch = self._get_next_batch()
+            self.num_processed += 1
+            return minibatch
+
+    def _try_load_tensor(self, cached=True):
         if self.num_processed >= len(self):
-            raise StopIteration
-        minibatch = self._get_next_batch()
-        self.num_processed += 1
-        return minibatch
+            return
+        else:
+            self.num_processed += 1
+            batch = self._get_next_batch()
+            self.pre_load_device_cache = self._load_tensor(batch, cached)
 
 
 class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
-    def __init__(self, loader):
-        super(_SerialMapDataLoaderIter, self).__init__(loader)
+    def __init__(self, loader, preload):
+        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
         self.indices_iter = iter(self.sampler)
 
     def _get_next_batch(self):
@@ -206,8 +282,8 @@ class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
     __initialized = False
 
-    def __init__(self, loader):
-        super(_ParallelMapDataLoaderIter, self).__init__(loader)
+    def __init__(self, loader, preload):
+        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)
 
         self.task_queues = [
             multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
@@ -358,8 +434,9 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
         self._shutdown()
 
 
-class _BaseStreamDataLoaderIter:
-    def __init__(self, loader):
+class _BaseStreamDataLoaderIter(PreLoader):
+    def __init__(self, loader, preload):
+        super().__init__(preload)
         self.dataset = loader.dataset
         self.sampler = loader.sampler
         self.transform = loader.transform
@@ -388,12 +465,23 @@ class _BaseStreamDataLoaderIter:
         return self
 
     def __next__(self):
-        return self._get_next_batch()
+        if self.preload:
+            if self.pre_load_device_cache is None:
+                self._try_load_tensor(cached=False)  # first batch: load on the current device
+            out = self._swap_out_cache()
+            self._try_load_tensor()  # stage the next batch on the staging stream
+            return out
+        else:
+            return self._get_next_batch()
+
+    def _try_load_tensor(self, cached=True):
+        batch = self._get_next_batch()
+        self.pre_load_device_cache = self._load_tensor(batch, cached)
 
 
 class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
-    def __init__(self, loader):
-        super().__init__(loader)
+    def __init__(self, loader, preload):
+        super().__init__(loader, preload)
         self.dataset_iter = iter(self.dataset)
         self.idx = 0
         self.unused = []
@@ -439,8 +527,8 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
 class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
     __initialized = False
 
-    def __init__(self, loader):
-        super().__init__(loader)
+    def __init__(self, loader, preload):
+        super().__init__(loader, preload)
 
         self.shutdown_flag = multiprocessing.Value("i", 0)
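
[Note: the device.py hunk below introduces _stream_helper, a process-global counter that hands out device stream indices. Appending ":<index>" to a device string selects that stream, which is exactly how PreLoader builds its staging device above. A small sketch, illustrative only: _sh is private API, the device names are examples, and get_next() takes no lock, so it is intended for single-threaded setup code.]

    from megengine.device import _sh, get_default_device

    dev = get_default_device()                 # e.g. "gpu0"
    staging = dev + ":" + str(_sh.get_next())  # e.g. "gpu0:1"
    another = dev + ":" + str(_sh.get_next())  # e.g. "gpu0:2"; an index is never reused
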
diff --git a/imperative/python/megengine/device.py b/imperative/python/megengine/device.py
index 56ab04281..c47af3996 100644
--- a/imperative/python/megengine/device.py
+++ b/imperative/python/megengine/device.py
@@ -29,6 +29,19 @@ __all__ = [
 ]
 
 
+class _stream_helper:
+    def __init__(self):
+        self.stream = 1  # stream indices start at 1; the unsuffixed default stream stays untouched
+
+    def get_next(self):
+        out = self.stream
+        self.stream += 1
+        return out
+
+
+_sh = _stream_helper()
+
+
 def _valid_device(inp):
     if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
         return True
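
[Note: before this patch, each process kept next_stream in StaticData, initialized by init_process_group, so only distributed groups could allocate streams. The group.py hunk below moves allocation to the shared counter in device.py, so a preloading DataLoader and a collective Group draw from one sequence and cannot claim the same stream. Illustrative sketch only:]

    from megengine.device import _sh

    # One global sequence serves both consumers:
    dl_stream = _sh.get_next()     # e.g. claimed by a DataLoader's PreLoader
    group_stream = _sh.get_next()  # e.g. claimed by Group.reset()
    assert dl_stream != group_stream

diff --git a/imperative/python/megengine/distributed/group.py b/imperative/python/megengine/distributed/group.py
index b9b437834..94d725e07 100644
--- a/imperative/python/megengine/distributed/group.py
+++ b/imperative/python/megengine/distributed/group.py
@@ -12,7 +12,7 @@ from typing import List, Optional, Tuple
 
 from mprop import mproperty
 
-from ..device import set_default_device, what_is_xpu
+from ..device import _sh, set_default_device, what_is_xpu
 from ..random import seed
 from .server import Client, Server
 
@@ -27,7 +27,6 @@ class StaticData:
     proc_rank = None
     device = None
     backend = None
-    next_stream = None
     device_type = None
     machine_ranks = None
 
@@ -43,6 +42,8 @@ class Group:
 
     Args:
        proc_ranks: rank list of the group, the first one is root rank.
+
+
     """
 
     def __init__(self, proc_ranks):
@@ -55,9 +56,7 @@ class Group:
     def reset(self, proc_ranks):
        self.check(proc_ranks)
        self.proc_ranks = proc_ranks
-        self.stream = _sd.next_stream
-        _sd.next_stream += 1
-        self.is_single_machine_cache = None
+        self.stream = _sh.get_next()
 
     def check(self, proc_ranks):
        assert _sd is not None, "please call init_process_group first"
@@ -160,7 +159,6 @@ def init_process_group(
     _sd.proc_rank = rank
     _sd.device = device
     _sd.backend = backend
-    _sd.next_stream = 1
     _sd.device_type = device_type
 
     WORLD.reset(list(range(world_size)))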
diff --git a/imperative/python/test/unit/data/test_pre_dataloader.py b/imperative/python/test/unit/data/test_pre_dataloader.py
new file mode 100644
index 000000000..75cd197f3
--- /dev/null
+++ b/imperative/python/test/unit/data/test_pre_dataloader.py
@@ -0,0 +1,308 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import gc
+import os
+import platform
+import time
+
+import numpy as np
+import pytest
+
+from megengine.data.collator import Collator
+from megengine.data.dataloader import DataLoader
+from megengine.data.dataset import ArrayDataset, StreamDataset
+from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
+from megengine.data.transform import (
+    Compose,
+    Normalize,
+    PseudoTransform,
+    ToMode,
+    Transform,
+)
+
+
+def init_dataset():
+    sample_num = 100
+    rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
+    label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
+    dataset = ArrayDataset(rand_data, label)
+    return dataset
+
+
+def test_dataloader_init():
+    dataset = init_dataset()
+    with pytest.raises(ValueError):
+        dataloader = DataLoader(dataset, num_workers=2, divide=True)
+    with pytest.raises(ValueError):
+        dataloader = DataLoader(dataset, num_workers=-1)
+    with pytest.raises(ValueError):
+        dataloader = DataLoader(dataset, timeout=-1)
+    with pytest.raises(ValueError):
+        dataloader = DataLoader(dataset, num_workers=0, divide=True)
+
+    dataloader = DataLoader(dataset, preload=True)
+    assert isinstance(dataloader.sampler, SequentialSampler)
+    assert isinstance(dataloader.transform, PseudoTransform)
+    assert isinstance(dataloader.collator, Collator)
+
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=6, drop_last=False),
+        preload=True,
+    )
+    assert len(dataloader) == 17
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=6, drop_last=True),
+        preload=True,
+    )
+    assert len(dataloader) == 16
+
+
+class MyStream(StreamDataset):
+    def __init__(self, number, batch=False, error_format=False, block=False):
+        self.number = number
+        self.batch = batch
+        self.error_format = error_format
+        self.block = block
+
+    def __iter__(self):
+        for cnt in range(self.number):
+            if self.block:
+                for _ in range(10):
+                    time.sleep(1)
+            if self.batch:
+                data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
+                yield (True, (data, [cnt, cnt - self.number]))
+            else:
+                data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
+                if self.error_format:
+                    yield (data, cnt)
+                else:
+                    yield (False, (data, cnt))
+        return  # a bare return ends a generator; raising StopIteration here would be a RuntimeError under PEP 479
+
+
+@pytest.mark.parametrize("batch", [True, False])
+@pytest.mark.parametrize("num_workers", [0, 2])
+def test_stream_dataloader(batch, num_workers):
+    dataset = MyStream(100, batch=batch)
+    sampler = StreamSampler(batch_size=4)
+    dataloader = DataLoader(
+        dataset,
+        sampler,
+        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
+        num_workers=num_workers,
+        preload=True,
+    )
+
+    check_set = set()
+
+    for step, data in enumerate(dataloader):
+        if step == 10:
+            break
+        assert data[0]._tuple_shape == (4, 3, 2, 2)
+        assert data[1]._tuple_shape == (4,)
+        for i in data[1]:
+            assert i not in check_set
+            check_set.add(i)
+
+
+def test_stream_dataloader_error():
+    dataset = MyStream(100, error_format=True)
+    sampler = StreamSampler(batch_size=4)
+    dataloader = DataLoader(dataset, sampler, preload=True)
+    with pytest.raises(AssertionError, match=r".*tuple.*"):
+        data_iter = iter(dataloader)
+        next(data_iter)
+
+
+@pytest.mark.parametrize("num_workers", [0, 2])
+def test_stream_dataloader_timeout(num_workers):
+    dataset = MyStream(100, False, block=True)
+    sampler = StreamSampler(batch_size=4)
+
+    dataloader = DataLoader(
+        dataset, sampler, num_workers=num_workers, timeout=2, preload=True
+    )
+    with pytest.raises(RuntimeError, match=r".*timeout.*"):
+        data_iter = iter(dataloader)
+        next(data_iter)
+
+
+def test_dataloader_serial():
+    dataset = init_dataset()
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+        preload=True,
+    )
+    for (data, label) in dataloader:
+        assert data._tuple_shape == (4, 1, 32, 32)
+        assert label._tuple_shape == (4,)
+
+
+def test_dataloader_parallel():
+    # set max shared memory to 100M
+    os.environ["MGE_PLASMA_MEMORY"] = "100000000"
+
+    dataset = init_dataset()
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+        num_workers=2,
+        divide=False,
+        preload=True,
+    )
+    for (data, label) in dataloader:
+        assert data._tuple_shape == (4, 1, 32, 32)
+        assert label._tuple_shape == (4,)
+
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+        num_workers=2,
+        divide=True,
+        preload=True,
+    )
+    for (data, label) in dataloader:
+        assert data._tuple_shape == (4, 1, 32, 32)
+        assert label._tuple_shape == (4,)
+
+
+@pytest.mark.skipif(
+    platform.system() == "Windows",
+    reason="dataloader does not support parallel on windows",
+)
+def test_dataloader_parallel_timeout():
+    dataset = init_dataset()
+
+    class TimeoutTransform(Transform):
+        def __init__(self):
+            pass
+
+        def apply(self, input):
+            time.sleep(10)
+            return input
+
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+        transform=TimeoutTransform(),
+        num_workers=2,
+        timeout=2,
+        preload=True,
+    )
+    with pytest.raises(RuntimeError, match=r".*timeout.*"):
+        data_iter = iter(dataloader)
+        batch_data = next(data_iter)
+
+
+@pytest.mark.skipif(
+    platform.system() == "Windows",
+    reason="dataloader does not support parallel on windows",
+)
+def test_dataloader_parallel_worker_exception():
+    dataset = init_dataset()
+
+    class FakeErrorTransform(Transform):
+        def __init__(self):
+            pass
+
+        def apply(self, input):
+            # Deliberately reference an undefined name so the worker
+            # process dies with an exception.
+            y = x + 1
+            return input
+
+    dataloader = DataLoader(
+        dataset,
+        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+        transform=FakeErrorTransform(),
+        num_workers=2,
+        preload=True,
+    )
+    with pytest.raises(RuntimeError, match=r"worker.*died"):
+        data_iter = iter(dataloader)
+        batch_data = next(data_iter)
+
+
+def _multi_instances_parallel_dataloader_worker():
+    dataset = init_dataset()
+
+    for divide_flag in [True, False]:
+        train_dataloader = DataLoader(
+            dataset,
+            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
+            num_workers=2,
+            divide=divide_flag,
+            preload=True,
+        )
+        val_dataloader = DataLoader(
+            dataset,
+            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
+            num_workers=2,
+            divide=divide_flag,
+            preload=True,
+        )
+        for idx, (data, label) in enumerate(train_dataloader):
+            assert data._tuple_shape == (4, 1, 32, 32)
+            assert label._tuple_shape == (4,)
+            if idx % 5 == 0:
+                for val_data, val_label in val_dataloader:
+                    assert val_data._tuple_shape == (10, 1, 32, 32)
+                    assert val_label._tuple_shape == (10,)
+
+
+def test_dataloader_parallel_multi_instances():
+    # set max shared memory to 100M
+    os.environ["MGE_PLASMA_MEMORY"] = "100000000"
+
+    _multi_instances_parallel_dataloader_worker()
+
+
+@pytest.mark.isolated_distributed
+def test_dataloader_parallel_multi_instances_multiprocessing():
+    gc.collect()
+    # set max shared memory to 100M
+    os.environ["MGE_PLASMA_MEMORY"] = "100000000"
+
+    import multiprocessing as mp
+
+    # mp.set_start_method("spawn")
+    processes = []
+    for i in range(4):
+        p = mp.Process(target=_multi_instances_parallel_dataloader_worker)
+        p.start()
+        processes.append(p)
+
+    for p in processes:
+        p.join()
+        assert p.exitcode == 0
+
+
+@pytest.mark.parametrize("num_workers", [0, 2])
+def test_timeout_event(num_workers):
+    def cb():
+        return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
+
+    dataset = MyStream(100, block=True)
+    sampler = StreamSampler(batch_size=4)
+
+    dataloader = DataLoader(
+        dataset,
+        sampler,
+        num_workers=num_workers,
+        timeout=2,
+        timeout_event=cb,
+        preload=True,
+    )
+    for _, data in enumerate(dataloader):
+        np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
+        np.testing.assert_equal(data[1], np.ones(shape=(4,)))
+        break
--
GitLab
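
[Note: as test_timeout_event above shows, the timeout_event hook turns a stalled fetch into a substitute batch instead of an exception. The sketch below is an illustrative callback, not part of the patch: the zero batch and the warning are example choices, and the return value must match what the dataset would yield; for a stream dataset that is a (batch_flag, (data, label)) tuple.]

    import logging

    import numpy as np

    from megengine.data.dataloader import DataLoader
    from megengine.data.sampler import StreamSampler

    def on_timeout():
        # Called when no batch arrives within `timeout` seconds.
        logging.warning("dataloader stalled; substituting a placeholder batch")
        return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))

    # `slow_dataset` stands in for any StreamDataset that may block.
    loader = DataLoader(
        slow_dataset,
        StreamSampler(batch_size=4),
        timeout=2,
        timeout_event=on_timeout,
        preload=True,
    )
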