提交 dff7719e 编写于 作者: M Megvii Engine Team

feat(mge/distributed): add preload host data with op

fix(mge/distributed): change api name with preload

fix(mge/distributed): fix recursive model in preload tensor

fix(mge/distributed): fix recursive when cache contain None

GitOrigin-RevId: 80e2a6dd7010f48a279b0d0be8ae6d5aad53f011
上级 8b94f493
......@@ -15,12 +15,15 @@ import queue
import random
import threading
import time
from typing import Callable
from typing import Callable, Union
import numpy as np
from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
......@@ -44,7 +47,7 @@ def raise_timeout_error():
class DataLoader:
r"""Provides a convenient way to iterate on a given dataset.
DataLoader combines a dataset with
:class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
make it flexible to get minibatch continually from a dataset.
......@@ -66,6 +69,8 @@ class DataLoader:
``True`` means one batch is divided into :attr:`num_workers` pieces, and
the workers will process these pieces parallelly. ``False`` means
different sub-process will process different batch. Default: False
preload: Defines whether to apply the preloading strategy of dataloader, and parallelize the copy of host2device while kernal is executed to improve the loading speed. default is seted False
the output will change from np.ndarry to dtype tensor. the support dtypes for preload are int,float,list[int,float],tuple[int,float],and another type is not supported.
"""
__initialized = False
......@@ -79,6 +84,7 @@ class DataLoader:
timeout: int = 0,
timeout_event: Callable = raise_timeout_error,
divide: bool = False,
preload: bool = False,
):
if num_workers < 0:
raise ValueError("num_workers should not be negative")
......@@ -96,6 +102,7 @@ class DataLoader:
self.timeout_event = timeout_event
self.divide = divide
self.preload = preload
if isinstance(dataset, StreamDataset):
self.sampler = sampler if sampler else StreamSampler(batch_size=1)
......@@ -145,24 +152,74 @@ class DataLoader:
self.num_workers = 0
if isinstance(self.dataset, StreamDataset):
if not self.num_workers:
return _SerialStreamDataLoaderIter(self)
return _SerialStreamDataLoaderIter(self, self.preload)
else:
return _ParallelStreamDataLoaderIter(self)
return _ParallelStreamDataLoaderIter(self, self.preload)
else:
assert isinstance(
self.dataset, Dataset
), "Can not recognize this kind of dataset: %s" % type(self.dataset)
if not self.num_workers:
return _SerialMapDataLoaderIter(self)
return _SerialMapDataLoaderIter(self, self.preload)
else:
return _ParallelMapDataLoaderIter(self)
return _ParallelMapDataLoaderIter(self, self.preload)
def __len__(self):
return len(self.sampler)
class _BaseMapDataLoaderIter:
def __init__(self, loader):
class PreLoader:
def __init__(self, preload):
if preload:
self.default_device = get_default_device()
self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
self.pre_load_device_cache = None
self.preload = preload
"""
strategy one: load from numpy data, and generate dtype tensor
"""
def _load_tensor(self, batch, cached=True):
if isinstance(batch, np.ndarray):
device = self.pre_load_device if cached else self.default_device
return Tensor(batch, device=device)
elif isinstance(batch, collections.abc.Mapping):
return {k: self._load_tensor(v, cached) for k, v in batch.items()}
elif isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple
return type(batch)(*(self._load_tensor(value, cached) for value in batch))
elif isinstance(batch, collections.abc.Sequence):
return [self._load_tensor(value, cached) for value in batch]
else:
return batch
"""
strategy two: load from cache that is already tensor just do d2d copy
"""
def _load_cache(self, data):
if isinstance(data, Tensor):
if data.device == self.default_device:
return data
return copy(data, device=self.default_device)
elif isinstance(data, collections.abc.Mapping):
return {k: self._load_cache(v) for k, v in data.items()}
elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple
return type(data)(*(self._load_cache(value) for value in data))
elif isinstance(data, collections.abc.Sequence):
return [self._load_cache(value) for value in data]
else:
return data
def _swap_out_cache(self):
out = self._load_cache(self.pre_load_device_cache)
self.pre_load_device_cache = None # clean cache
return out
class _BaseMapDataLoaderIter(PreLoader):
def __init__(self, loader, preload):
super().__init__(preload)
self.dataset = loader.dataset
self.sampler = loader.sampler
self.seed = _random_seed_generator().__next__()
......@@ -184,16 +241,35 @@ class _BaseMapDataLoaderIter:
return self
def __next__(self):
if self.preload:
cached = self.pre_load_device_cache
if cached is None: # first and last
if self.num_processed >= len(self): # last
raise StopIteration
elif self.num_processed == 0: # first
self._try_load_tensor(cached=False) # first do the h2d
out = self._swap_out_cache()
self._try_load_tensor()
return out
else:
if self.num_processed >= len(self):
raise StopIteration
minibatch = self._get_next_batch()
self.num_processed += 1
return minibatch
def _try_load_tensor(self, cached=True):
if self.num_processed >= len(self):
raise StopIteration
minibatch = self._get_next_batch()
self.num_processed += 1
return minibatch
return
else:
self.num_processed += 1
batch = self._get_next_batch()
self.pre_load_device_cache = self._load_tensor(batch, cached)
class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
def __init__(self, loader):
super(_SerialMapDataLoaderIter, self).__init__(loader)
def __init__(self, loader, preload):
super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
self.indices_iter = iter(self.sampler)
def _get_next_batch(self):
......@@ -206,8 +282,8 @@ class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
__initialized = False
def __init__(self, loader):
super(_ParallelMapDataLoaderIter, self).__init__(loader)
def __init__(self, loader, preload):
super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)
self.task_queues = [
multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
......@@ -358,8 +434,9 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
self._shutdown()
class _BaseStreamDataLoaderIter:
def __init__(self, loader):
class _BaseStreamDataLoaderIter(PreLoader):
def __init__(self, loader, preload):
super().__init__(preload)
self.dataset = loader.dataset
self.sampler = loader.sampler
self.transform = loader.transform
......@@ -388,12 +465,23 @@ class _BaseStreamDataLoaderIter:
return self
def __next__(self):
return self._get_next_batch()
if self.preload:
if self.pre_load_device_cache is None:
self._try_load_tensor(cached=False) # load in current
out = self._swap_out_cache()
self._try_load_tensor() # load in cached
return out
else:
return self._get_next_batch()
def _try_load_tensor(self, cached=True):
batch = self._get_next_batch()
self.pre_load_device_cache = self._load_tensor(batch, cached)
class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
def __init__(self, loader):
super().__init__(loader)
def __init__(self, loader, preload):
super().__init__(loader, preload)
self.dataset_iter = iter(self.dataset)
self.idx = 0
self.unused = []
......@@ -439,8 +527,8 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
__initialized = False
def __init__(self, loader):
super().__init__(loader)
def __init__(self, loader, preload):
super().__init__(loader, preload)
self.shutdown_flag = multiprocessing.Value("i", 0)
......
......@@ -29,6 +29,19 @@ __all__ = [
]
class _stream_helper:
def __init__(self):
self.stream = 1
def get_next(self):
out = self.stream
self.stream = self.stream + 1
return out
_sh = _stream_helper()
def _valid_device(inp):
if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
return True
......
......@@ -12,7 +12,7 @@ from typing import List, Optional, Tuple
from mprop import mproperty
from ..device import set_default_device, what_is_xpu
from ..device import _sh, set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server
......@@ -27,7 +27,6 @@ class StaticData:
proc_rank = None
device = None
backend = None
next_stream = None
device_type = None
machine_ranks = None
......@@ -43,6 +42,8 @@ class Group:
Args:
proc_ranks: rank list of the group, the first one is root rank.
"""
def __init__(self, proc_ranks):
......@@ -55,9 +56,7 @@ class Group:
def reset(self, proc_ranks):
self.check(proc_ranks)
self.proc_ranks = proc_ranks
self.stream = _sd.next_stream
_sd.next_stream += 1
self.is_single_machine_cache = None
self.stream = _sh.get_next()
def check(self, proc_ranks):
assert _sd is not None, "please call init_process_group first"
......@@ -160,7 +159,6 @@ def init_process_group(
_sd.proc_rank = rank
_sd.device = device
_sd.backend = backend
_sd.next_stream = 1
_sd.device_type = device_type
WORLD.reset(list(range(world_size)))
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc
import os
import platform
import time
import numpy as np
import pytest
from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
Compose,
Normalize,
PseudoTransform,
ToMode,
Transform,
)
def init_dataset():
sample_num = 100
rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
dataset = ArrayDataset(rand_data, label)
return dataset
def test_dataloader_init():
dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)
dataloader = DataLoader(dataset, preload=True)
assert isinstance(dataloader.sampler, SequentialSampler)
assert isinstance(dataloader.transform, PseudoTransform)
assert isinstance(dataloader.collator, Collator)
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=6, drop_last=False),
preload=True,
)
assert len(dataloader) == 17
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=6, drop_last=True),
preload=True,
)
assert len(dataloader) == 16
class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_foramt=False, block=False):
self.number = number
self.batch = batch
self.error_format = error_foramt
self.block = block
def __iter__(self):
for cnt in range(self.number):
if self.block:
for _ in range(10):
time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
raise StopIteration
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch=batch)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
num_workers=num_workers,
preload=True,
)
check_set = set()
for step, data in enumerate(dataloader):
if step == 10:
break
assert data[0]._tuple_shape == (4, 3, 2, 2)
assert data[1]._tuple_shape == (4,)
for i in data[1]:
assert i not in check_set
check_set.add(i)
def test_stream_dataloader_error():
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, preload=True)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset, sampler, num_workers=num_workers, timeout=2, preload=True
)
with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader)
next(data_iter)
def test_dataloader_serial():
dataset = init_dataset()
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
def test_dataloader_parallel():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"
dataset = init_dataset()
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=False,
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_dataloader_parallel_timeout():
dataset = init_dataset()
class TimeoutTransform(Transform):
def __init__(self):
pass
def apply(self, input):
time.sleep(10)
return input
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
transform=TimeoutTransform(),
num_workers=2,
timeout=2,
preload=True,
)
with pytest.raises(RuntimeError, match=r".*timeout.*"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_dataloader_parallel_worker_exception():
print("in target")
dataset = init_dataset()
class FakeErrorTransform(Transform):
def __init__(self):
pass
def apply(self, input):
y = x + 1
return input
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
transform=FakeErrorTransform(),
num_workers=2,
preload=True,
)
with pytest.raises(RuntimeError, match=r"worker.*died"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset()
for divide_flag in [True, False]:
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data._tuple_shape == (10, 1, 32, 32)
assert val_label._tuple_shape == (10,)
def test_dataloader_parallel_multi_instances():
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"
_multi_instances_parallel_dataloader_worker()
@pytest.mark.isolated_distributed
def test_dataloader_parallel_multi_instances_multiprocessing():
gc.collect()
# set max shared memory to 100M
os.environ["MGE_PLASMA_MEMORY"] = "100000000"
import multiprocessing as mp
# mp.set_start_method("spawn")
processes = []
for i in range(4):
p = mp.Process(target=_multi_instances_parallel_dataloader_worker)
p.start()
processes.append(p)
for p in processes:
p.join()
assert p.exitcode == 0
@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
num_workers=num_workers,
timeout=2,
timeout_event=cb,
preload=True,
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册