提交 edc92ccf 编写于 作者: M Megvii Engine Team

perf(imperative/data): improve dataloader preformance

GitOrigin-RevId: 7d8d52aaeb47e7ec6c3efa282ff9014a4b7d1f01
上级 896b0193
......@@ -2,6 +2,7 @@
import collections.abc
import math
from abc import ABC, abstractmethod
from itertools import count
from typing import Any, Generator, Iterator, List, Union
import numpy as np
......@@ -126,13 +127,15 @@ class MapSampler(Sampler):
if self.world_size > 1:
indices = self.scatter(indices)
step, length = self.batch_size, len(indices)
batch_index = [indices[i : i + step] for i in range(0, length, step)]
batch = []
for idx in indices:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if self.drop_last and len(batch_index[-1]) < self.batch_size:
batch_index.pop()
return iter(batch_index)
if len(batch) > 0 and not self.drop_last:
yield batch
class StreamSampler(Sampler):
......@@ -151,10 +154,18 @@ class StreamSampler(Sampler):
self.batch_size = batch_size
def __iter__(self):
return self
return self.batch()
def __next__(self):
return iter(range(self.batch_size))
def batch(self):
batch = []
for idx in self.sample():
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
def sample(self):
return count(start=0)
class SequentialSampler(MapSampler):
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import os
import platform
import time
......@@ -7,7 +15,7 @@ import numpy as np
import pytest
from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
......@@ -29,14 +37,10 @@ def init_dataset():
def test_dataloader_init():
dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)
dataloader = DataLoader(dataset)
assert isinstance(dataloader.sampler, SequentialSampler)
......@@ -54,10 +58,8 @@ def test_dataloader_init():
class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_foramt=False, block=False):
def __init__(self, number, block=False):
self.number = number
self.batch = batch
self.error_format = error_foramt
self.block = block
def __iter__(self):
......@@ -65,22 +67,14 @@ class MyStream(StreamDataset):
if self.block:
for _ in range(10):
time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
yield (data, cnt)
raise StopIteration
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch=batch)
def test_stream_dataloader(num_workers):
dataset = MyStream(100)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
......@@ -90,7 +84,6 @@ def test_stream_dataloader(batch, num_workers):
)
check_set = set()
for step, data in enumerate(dataloader):
if step == 10:
break
......@@ -101,18 +94,9 @@ def test_stream_dataloader(batch, num_workers):
check_set.add(i)
def test_stream_dataloader_error():
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True)
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2)
......@@ -140,17 +124,6 @@ def test_dataloader_parallel():
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=False,
)
for (data, label) in dataloader:
assert data.shape == (4, 1, 32, 32)
assert label.shape == (4,)
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
)
for (data, label) in dataloader:
assert data.shape == (4, 1, 32, 32)
......@@ -205,7 +178,7 @@ def test_dataloader_parallel_worker_exception():
transform=FakeErrorTransform(),
num_workers=2,
)
with pytest.raises(RuntimeError, match=r"worker.*died"):
with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
......@@ -213,26 +186,23 @@ def test_dataloader_parallel_worker_exception():
def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset()
for divide_flag in [True, False]:
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=divide_flag,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
divide=divide_flag,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data.shape == (4, 1, 32, 32)
assert label.shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data.shape == (10, 1, 32, 32)
assert val_label.shape == (10,)
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data.shape == (4, 1, 32, 32)
assert label.shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data.shape == (10, 1, 32, 32)
assert val_label.shape == (10,)
def test_dataloader_parallel_multi_instances():
......@@ -261,18 +231,81 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
assert p.exitcode == 0
@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
def partition(ls, size):
return [ls[i : i + size] for i in range(0, len(ls), size)]
dataset = MyStream(100, block=True)
class MyPreStream(StreamDataset):
def __init__(self, number, block=False):
self.number = [i for i in range(number)]
self.block = block
self.data = []
for i in range(100):
self.data.append(np.random.randint(0, 256, (2, 2, 3), dtype="uint8"))
def __iter__(self):
worker_info = get_worker_info()
per_worker = int(math.ceil((len(self.data)) / float(worker_info.worker)))
pre_data = iter(partition(self.data, per_worker)[worker_info.idx])
pre_cnt = partition(self.number, per_worker)[worker_info.idx]
for cnt in pre_cnt:
if self.block:
for _ in range(10):
time.sleep(1)
yield (next(pre_data), cnt)
raise StopIteration
@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_prestream_dataloader_multiprocessing():
dataset = MyPreStream(100)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
num_workers=2,
parallel_stream=True,
)
check_set = set()
for step, data in enumerate(dataloader):
if step == 10:
break
assert data[0].shape == (4, 3, 2, 2)
assert data[1].shape == (4,)
for i in data[1]:
assert i not in check_set
check_set.add(i)
@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_predataloader_parallel_worker_exception():
dataset = MyPreStream(100)
class FakeErrorTransform(Transform):
def __init__(self):
pass
def apply(self, input):
raise RuntimeError("test raise error")
return input
dataloader = DataLoader(
dataset, sampler, num_workers=num_workers, timeout=2, timeout_event=cb
dataset,
sampler=StreamSampler(batch_size=4),
transform=FakeErrorTransform(),
num_workers=2,
parallel_stream=True,
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break
with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
print(batch_data.shape)
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc
import math
import os
import platform
import time
......@@ -8,7 +16,7 @@ import numpy as np
import pytest
from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
......@@ -30,14 +38,10 @@ def init_dataset():
def test_dataloader_init():
dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)
dataloader = DataLoader(dataset, preload=True)
assert isinstance(dataloader.sampler, SequentialSampler)
......@@ -59,10 +63,8 @@ def test_dataloader_init():
class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_foramt=False, block=False):
def __init__(self, number, block=False):
self.number = number
self.batch = batch
self.error_format = error_foramt
self.block = block
def __iter__(self):
......@@ -70,22 +72,14 @@ class MyStream(StreamDataset):
if self.block:
for _ in range(10):
time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
yield (data, cnt)
raise StopIteration
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch=batch)
def test_stream_dataloader(num_workers):
dataset = MyStream(100)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
......@@ -107,18 +101,9 @@ def test_stream_dataloader(batch, num_workers):
check_set.add(i)
def test_stream_dataloader_error():
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, preload=True)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True)
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
......@@ -150,18 +135,6 @@ def test_dataloader_parallel():
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=False,
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
preload=True,
)
for (data, label) in dataloader:
......@@ -219,7 +192,7 @@ def test_dataloader_parallel_worker_exception():
num_workers=2,
preload=True,
)
with pytest.raises(RuntimeError, match=r"worker.*died"):
with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
......@@ -227,28 +200,25 @@ def test_dataloader_parallel_worker_exception():
def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset()
for divide_flag in [True, False]:
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data._tuple_shape == (10, 1, 32, 32)
assert val_label._tuple_shape == (10,)
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
preload=True,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
preload=True,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data._tuple_shape == (10, 1, 32, 32)
assert val_label._tuple_shape == (10,)
def test_dataloader_parallel_multi_instances():
......@@ -276,25 +246,3 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
for p in processes:
p.join()
assert p.exitcode == 0
@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
num_workers=num_workers,
timeout=2,
timeout_event=cb,
preload=True,
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册