提交 edc92ccf 编写于 作者: M Megvii Engine Team

perf(imperative/data): improve dataloader performance

GitOrigin-RevId: 7d8d52aaeb47e7ec6c3efa282ff9014a4b7d1f01
上级 896b0193
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import collections.abc import collections.abc
import math import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import count
from typing import Any, Generator, Iterator, List, Union from typing import Any, Generator, Iterator, List, Union
import numpy as np import numpy as np
...@@ -126,13 +127,15 @@ class MapSampler(Sampler): ...@@ -126,13 +127,15 @@ class MapSampler(Sampler):
if self.world_size > 1: if self.world_size > 1:
indices = self.scatter(indices) indices = self.scatter(indices)
step, length = self.batch_size, len(indices) batch = []
batch_index = [indices[i : i + step] for i in range(0, length, step)] for idx in indices:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if self.drop_last and len(batch_index[-1]) < self.batch_size: if len(batch) > 0 and not self.drop_last:
batch_index.pop() yield batch
return iter(batch_index)
class StreamSampler(Sampler): class StreamSampler(Sampler):
...@@ -151,10 +154,18 @@ class StreamSampler(Sampler): ...@@ -151,10 +154,18 @@ class StreamSampler(Sampler):
self.batch_size = batch_size self.batch_size = batch_size
def __iter__(self): def __iter__(self):
return self return self.batch()
def __next__(self): def batch(self):
return iter(range(self.batch_size)) batch = []
for idx in self.sample():
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
def sample(self):
return count(start=0)
class SequentialSampler(MapSampler): class SequentialSampler(MapSampler):
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import os import os
import platform import platform
import time import time
...@@ -7,7 +15,7 @@ import numpy as np ...@@ -7,7 +15,7 @@ import numpy as np
import pytest import pytest
from megengine.data.collator import Collator from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.dataset import ArrayDataset, StreamDataset from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import ( from megengine.data.transform import (
...@@ -29,14 +37,10 @@ def init_dataset(): ...@@ -29,14 +37,10 @@ def init_dataset():
def test_dataloader_init(): def test_dataloader_init():
dataset = init_dataset() dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError): with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1) dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError): with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1) dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)
dataloader = DataLoader(dataset) dataloader = DataLoader(dataset)
assert isinstance(dataloader.sampler, SequentialSampler) assert isinstance(dataloader.sampler, SequentialSampler)
...@@ -54,10 +58,8 @@ def test_dataloader_init(): ...@@ -54,10 +58,8 @@ def test_dataloader_init():
class MyStream(StreamDataset): class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_foramt=False, block=False): def __init__(self, number, block=False):
self.number = number self.number = number
self.batch = batch
self.error_format = error_foramt
self.block = block self.block = block
def __iter__(self): def __iter__(self):
...@@ -65,22 +67,14 @@ class MyStream(StreamDataset): ...@@ -65,22 +67,14 @@ class MyStream(StreamDataset):
if self.block: if self.block:
for _ in range(10): for _ in range(10):
time.sleep(1) time.sleep(1)
if self.batch: data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8") yield (data, cnt)
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
raise StopIteration raise StopIteration
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers): def test_stream_dataloader(num_workers):
dataset = MyStream(100, batch=batch) dataset = MyStream(100)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
...@@ -90,7 +84,6 @@ def test_stream_dataloader(batch, num_workers): ...@@ -90,7 +84,6 @@ def test_stream_dataloader(batch, num_workers):
) )
check_set = set() check_set = set()
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
if step == 10: if step == 10:
break break
...@@ -101,18 +94,9 @@ def test_stream_dataloader(batch, num_workers): ...@@ -101,18 +94,9 @@ def test_stream_dataloader(batch, num_workers):
check_set.add(i) check_set.add(i)
def test_stream_dataloader_error():
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers): def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True) dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2) dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2)
...@@ -140,17 +124,6 @@ def test_dataloader_parallel(): ...@@ -140,17 +124,6 @@ def test_dataloader_parallel():
dataset, dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2, num_workers=2,
divide=False,
)
for (data, label) in dataloader:
assert data.shape == (4, 1, 32, 32)
assert label.shape == (4,)
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
) )
for (data, label) in dataloader: for (data, label) in dataloader:
assert data.shape == (4, 1, 32, 32) assert data.shape == (4, 1, 32, 32)
...@@ -205,7 +178,7 @@ def test_dataloader_parallel_worker_exception(): ...@@ -205,7 +178,7 @@ def test_dataloader_parallel_worker_exception():
transform=FakeErrorTransform(), transform=FakeErrorTransform(),
num_workers=2, num_workers=2,
) )
with pytest.raises(RuntimeError, match=r"worker.*died"): with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader) data_iter = iter(dataloader)
batch_data = next(data_iter) batch_data = next(data_iter)
...@@ -213,26 +186,23 @@ def test_dataloader_parallel_worker_exception(): ...@@ -213,26 +186,23 @@ def test_dataloader_parallel_worker_exception():
def _multi_instances_parallel_dataloader_worker(): def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset() dataset = init_dataset()
for divide_flag in [True, False]: train_dataloader = DataLoader(
train_dataloader = DataLoader( dataset,
dataset, sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), num_workers=2,
num_workers=2, )
divide=divide_flag, val_dataloader = DataLoader(
) dataset,
val_dataloader = DataLoader( sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
dataset, num_workers=2,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False), )
num_workers=2, for idx, (data, label) in enumerate(train_dataloader):
divide=divide_flag, assert data.shape == (4, 1, 32, 32)
) assert label.shape == (4,)
for idx, (data, label) in enumerate(train_dataloader): if idx % 5 == 0:
assert data.shape == (4, 1, 32, 32) for val_data, val_label in val_dataloader:
assert label.shape == (4,) assert val_data.shape == (10, 1, 32, 32)
if idx % 5 == 0: assert val_label.shape == (10,)
for val_data, val_label in val_dataloader:
assert val_data.shape == (10, 1, 32, 32)
assert val_label.shape == (10,)
def test_dataloader_parallel_multi_instances(): def test_dataloader_parallel_multi_instances():
...@@ -261,18 +231,81 @@ def test_dataloader_parallel_multi_instances_multiprocessing(): ...@@ -261,18 +231,81 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
assert p.exitcode == 0 assert p.exitcode == 0
def partition(ls, size):
    """Split *ls* into consecutive chunks of at most *size* elements.

    The final chunk is shorter when ``len(ls)`` is not a multiple of *size*;
    an empty input yields an empty list.
    """
    chunks = []
    start = 0
    while start < len(ls):
        chunks.append(ls[start : start + size])
        start += size
    return chunks
class MyPreStream(StreamDataset):
    """Stream dataset that pre-generates samples and shards them per worker.

    Each DataLoader worker iterates only its own partition of the data, so
    the union of all workers covers every sample exactly once.
    """

    def __init__(self, number, block=False):
        # Sequential ids 0..number-1, used as the label of each sample.
        self.number = [i for i in range(number)]
        # When True, __iter__ sleeps ~10s per sample to exercise timeouts.
        self.block = block
        # NOTE(review): sample count is fixed at 100 regardless of `number`;
        # callers are expected to pass number == 100 — confirm before reuse.
        self.data = []
        for i in range(100):
            self.data.append(np.random.randint(0, 256, (2, 2, 3), dtype="uint8"))

    def __iter__(self):
        worker_info = get_worker_info()
        # Split evenly across workers; the last shard may be shorter.
        per_worker = int(math.ceil((len(self.data)) / float(worker_info.worker)))
        pre_data = iter(partition(self.data, per_worker)[worker_info.idx])
        pre_cnt = partition(self.number, per_worker)[worker_info.idx]
        for cnt in pre_cnt:
            if self.block:
                for _ in range(10):
                    time.sleep(1)
            yield (next(pre_data), cnt)
        # PEP 479: raising StopIteration inside a generator is converted to
        # RuntimeError on Python 3.7+; returning ends the stream cleanly.
        return
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
def test_prestream_dataloader_multiprocessing():
    """Parallel stream dataloader yields transformed, de-duplicated batches."""
    dataset = MyPreStream(100)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset,
        sampler,
        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
        num_workers=2,
        parallel_stream=True,
    )
    seen_labels = set()
    for step, batch in enumerate(dataloader):
        if step == 10:
            break
        # Batch is (images, labels); ToMode("CHW") puts channels first.
        assert batch[0].shape == (4, 3, 2, 2)
        assert batch[1].shape == (4,)
        # Workers shard the stream, so no label may appear twice.
        for label in batch[1]:
            assert label not in seen_labels
            seen_labels.add(label)
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
def test_predataloader_parallel_worker_exception():
    """A transform error inside a parallel-stream worker must surface as RuntimeError."""
    dataset = MyPreStream(100)

    class FakeErrorTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            # Always fails, simulating a crash inside a worker process.
            # (Removed unreachable `return input` that followed the raise.)
            raise RuntimeError("test raise error")

    dataloader = DataLoader(
        dataset,
        sampler=StreamSampler(batch_size=4),
        transform=FakeErrorTransform(),
        num_workers=2,
        parallel_stream=True,
    )
    # The worker dies on the first batch and the main process reports it;
    # next() raises, so no statement after it would ever run.
    with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
        data_iter = iter(dataloader)
        next(data_iter)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc import gc
import math
import os import os
import platform import platform
import time import time
...@@ -8,7 +16,7 @@ import numpy as np ...@@ -8,7 +16,7 @@ import numpy as np
import pytest import pytest
from megengine.data.collator import Collator from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.dataset import ArrayDataset, StreamDataset from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import ( from megengine.data.transform import (
...@@ -30,14 +38,10 @@ def init_dataset(): ...@@ -30,14 +38,10 @@ def init_dataset():
def test_dataloader_init(): def test_dataloader_init():
dataset = init_dataset() dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError): with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1) dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError): with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1) dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)
dataloader = DataLoader(dataset, preload=True) dataloader = DataLoader(dataset, preload=True)
assert isinstance(dataloader.sampler, SequentialSampler) assert isinstance(dataloader.sampler, SequentialSampler)
...@@ -59,10 +63,8 @@ def test_dataloader_init(): ...@@ -59,10 +63,8 @@ def test_dataloader_init():
class MyStream(StreamDataset): class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_foramt=False, block=False): def __init__(self, number, block=False):
self.number = number self.number = number
self.batch = batch
self.error_format = error_foramt
self.block = block self.block = block
def __iter__(self): def __iter__(self):
...@@ -70,22 +72,14 @@ class MyStream(StreamDataset): ...@@ -70,22 +72,14 @@ class MyStream(StreamDataset):
if self.block: if self.block:
for _ in range(10): for _ in range(10):
time.sleep(1) time.sleep(1)
if self.batch: data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8") yield (data, cnt)
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
raise StopIteration raise StopIteration
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers): def test_stream_dataloader(num_workers):
dataset = MyStream(100, batch=batch) dataset = MyStream(100)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
...@@ -107,18 +101,9 @@ def test_stream_dataloader(batch, num_workers): ...@@ -107,18 +101,9 @@ def test_stream_dataloader(batch, num_workers):
check_set.add(i) check_set.add(i)
def test_stream_dataloader_error():
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, preload=True)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers): def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True) dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)
dataloader = DataLoader( dataloader = DataLoader(
...@@ -150,18 +135,6 @@ def test_dataloader_parallel(): ...@@ -150,18 +135,6 @@ def test_dataloader_parallel():
dataset, dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2, num_workers=2,
divide=False,
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
preload=True, preload=True,
) )
for (data, label) in dataloader: for (data, label) in dataloader:
...@@ -219,7 +192,7 @@ def test_dataloader_parallel_worker_exception(): ...@@ -219,7 +192,7 @@ def test_dataloader_parallel_worker_exception():
num_workers=2, num_workers=2,
preload=True, preload=True,
) )
with pytest.raises(RuntimeError, match=r"worker.*died"): with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader) data_iter = iter(dataloader)
batch_data = next(data_iter) batch_data = next(data_iter)
...@@ -227,28 +200,25 @@ def test_dataloader_parallel_worker_exception(): ...@@ -227,28 +200,25 @@ def test_dataloader_parallel_worker_exception():
def _multi_instances_parallel_dataloader_worker(): def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset() dataset = init_dataset()
for divide_flag in [True, False]: train_dataloader = DataLoader(
train_dataloader = DataLoader( dataset,
dataset, sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), num_workers=2,
num_workers=2, preload=True,
divide=divide_flag, )
preload=True, val_dataloader = DataLoader(
) dataset,
val_dataloader = DataLoader( sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
dataset, num_workers=2,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False), preload=True,
num_workers=2, )
divide=divide_flag, for idx, (data, label) in enumerate(train_dataloader):
preload=True, assert data._tuple_shape == (4, 1, 32, 32)
) assert label._tuple_shape == (4,)
for idx, (data, label) in enumerate(train_dataloader): if idx % 5 == 0:
assert data._tuple_shape == (4, 1, 32, 32) for val_data, val_label in val_dataloader:
assert label._tuple_shape == (4,) assert val_data._tuple_shape == (10, 1, 32, 32)
if idx % 5 == 0: assert val_label._tuple_shape == (10,)
for val_data, val_label in val_dataloader:
assert val_data._tuple_shape == (10, 1, 32, 32)
assert val_label._tuple_shape == (10,)
def test_dataloader_parallel_multi_instances(): def test_dataloader_parallel_multi_instances():
...@@ -276,25 +246,3 @@ def test_dataloader_parallel_multi_instances_multiprocessing(): ...@@ -276,25 +246,3 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
for p in processes: for p in processes:
p.join() p.join()
assert p.exitcode == 0 assert p.exitcode == 0
@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
num_workers=num_workers,
timeout=2,
timeout_event=cb,
preload=True,
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册