# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import multiprocessing
import os
import platform
import subprocess
import time

import numpy as np
import pytest

from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
    Compose,
    Normalize,
    PseudoTransform,
    ToMode,
    Transform,
)


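# Build an ArrayDataset of 100 random 1x32x32 uint8 images with integer
# labels in [0, 10).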
def init_dataset():
    sample_num = 100
    rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
    label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
    dataset = ArrayDataset(rand_data, label)
    return dataset


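# Check argument validation, the default sampler/transform/collator, and
# the batch count with and without drop_last.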
def test_dataloader_init():
    dataset = init_dataset()
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=-1)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, timeout=-1)

    dataloader = DataLoader(dataset)
    assert isinstance(dataloader.sampler, SequentialSampler)
    assert isinstance(dataloader.transform, PseudoTransform)
    assert isinstance(dataloader.collator, Collator)

    dataloader = DataLoader(
        dataset, sampler=RandomSampler(dataset, batch_size=6, drop_last=False)
    )
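    # ceil(100 / 6) == 17 batches when the last incomplete batch is kept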
    assert len(dataloader) == 17
    dataloader = DataLoader(
        dataset, sampler=RandomSampler(dataset, batch_size=6, drop_last=True)
    )
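    # floor(100 / 6) == 16 batches when the last incomplete batch is dropped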
    assert len(dataloader) == 16


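# A minimal StreamDataset: yields `number` random 2x2x3 uint8 samples paired
# with a running counter; block=True sleeps before each sample so timeout
# tests can trigger.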
class MyStream(StreamDataset):
    def __init__(self, number, block=False):
        self.number = number
        self.block = block

    def __iter__(self):
        for cnt in range(self.number):
            if self.block:
                for _ in range(10):
                    time.sleep(1)
            data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
            yield (data, cnt)
        return


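# The transform sends SIGSEGV (kill -11) to its own worker process, so the
# DataLoader must report that the worker exited unexpectedly.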
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
@pytest.mark.skipif(
    multiprocessing.get_start_method() != "fork",
    reason="the runtime error is only raised when fork",
)
def test_dataloader_worker_signal_exception():
    dataset = init_dataset()

    class FakeErrorTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            pid = os.getpid()
            subprocess.run(["kill", "-11", str(pid)])
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=FakeErrorTransform(),
        num_workers=2,
    )
    with pytest.raises(RuntimeError, match=r"DataLoader worker.* exited unexpectedly"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)


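# Transforms that deliberately raise inside a worker process, used to check
# that worker exceptions are surfaced as RuntimeError in the main process.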
class IndexErrorTransform(Transform):
    def __init__(self):
        self.array = [0, 1, 2]

    def apply(self, input):
        error_item = self.array[3]  # out-of-range access raises IndexError
        return input


class TypeErrorTransform(Transform):
    def __init__(self):
        self.adda = 1
        self.addb = "2"

    def apply(self, input):
        error_item = self.adda + self.addb  # int + str raises TypeError
        return input


@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
@pytest.mark.parametrize("transform", [IndexErrorTransform(), TypeErrorTransform()])
def test_dataloader_worker_baseerror(transform):
    dataset = init_dataset()

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=transform,
        num_workers=2,
    )
    with pytest.raises(RuntimeError, match=r"Caught .*Error in DataLoader worker"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)


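# Pull 10 batches from a stream dataset and check the transformed CHW batch
# shapes plus the uniqueness of the per-sample counters.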
@pytest.mark.skipif(
    np.__version__ >= "1.20.0",
    reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(num_workers):
    dataset = MyStream(100)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset,
        sampler,
        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
        num_workers=num_workers,
    )

    check_set = set()
    for step, data in enumerate(dataloader):
        if step == 10:
            break
        assert data[0].shape == (4, 3, 2, 2)
        assert data[1].shape == (4,)
        for i in data[1]:
            assert i not in check_set
            check_set.add(i)


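# MyStream blocks ~10s per sample, so a 2s timeout must raise RuntimeError.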
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
    dataset = MyStream(100, block=True)
    sampler = StreamSampler(batch_size=4)

    dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2)
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        next(data_iter)


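# num_workers defaults to 0, so batches are produced in the main process.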
def test_dataloader_serial():
    dataset = init_dataset()
    dataloader = DataLoader(
        dataset, sampler=RandomSampler(dataset, batch_size=4, drop_last=False)
    )
    for (data, label) in dataloader:
        assert data.shape == (4, 1, 32, 32)
        assert label.shape == (4,)


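# Two worker processes exchange batches through plasma shared memory.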
@pytest.mark.skipif(
    np.__version__ >= "1.20.0",
    reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel():
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    dataset = init_dataset()
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        num_workers=2,
    )
    for (data, label) in dataloader:
        assert data.shape == (4, 1, 32, 32)
        assert label.shape == (4,)


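# Each transform call sleeps 10s while the loader timeout is 2s, so fetching
# a batch must raise a timeout RuntimeError.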
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
@pytest.mark.skipif(
    multiprocessing.get_start_method() != "fork",
    reason="the runtime error is only raised when fork",
)
def test_dataloader_parallel_timeout():
    dataset = init_dataset()

    class TimeoutTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            time.sleep(10)
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=TimeoutTransform(),
        num_workers=2,
        timeout=2,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)


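# A transform that raises in a worker should be reported as a worker that
# exited unexpectedly.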
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
@pytest.mark.skipif(
    multiprocessing.get_start_method() != "fork",
    reason="the runtime error is only raised when fork",
)
def test_dataloader_parallel_worker_exception():
    dataset = init_dataset()

    class FakeErrorTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            raise RuntimeError("test raise error")
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=FakeErrorTransform(),
        num_workers=2,
    )
    with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)


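# Interleave a training loader (batch size 4) and a validation loader (batch
# size 10) over the same dataset to verify that two parallel dataloaders can
# coexist.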
def _multi_instances_parallel_dataloader_worker():
    dataset = init_dataset()

    train_dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        num_workers=2,
    )
    val_dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
        num_workers=2,
    )
    for idx, (data, label) in enumerate(train_dataloader):
        assert data.shape == (4, 1, 32, 32)
        assert label.shape == (4,)
        if idx % 5 == 0:
            for val_data, val_label in val_dataloader:
                assert val_data.shape == (10, 1, 32, 32)
                assert val_label.shape == (10,)


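# Run the train/val dataloader pair in the current process.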
@pytest.mark.skipif(
    np.__version__ >= "1.20.0",
    reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
def test_dataloader_parallel_multi_instances():
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    _multi_instances_parallel_dataloader_worker()


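# Launch four processes, each running its own pair of parallel dataloaders,
# and require every one of them to exit cleanly.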
@pytest.mark.isolated_distributed
def test_dataloader_parallel_multi_instances_multiprocessing():
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    import multiprocessing as mp

    # mp.set_start_method("spawn")
    processes = []
    for i in range(4):
        p = mp.Process(target=_multi_instances_parallel_dataloader_worker)
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        assert p.exitcode == 0


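# Split ls into consecutive chunks of at most `size` elements.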
def partition(ls, size):
    return [ls[i : i + size] for i in range(0, len(ls), size)]


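# A StreamDataset prepared for parallel_stream mode: each worker uses
# get_worker_info() to select its own shard of the data and counters.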
class MyPreStream(StreamDataset):
    def __init__(self, number, block=False):
        self.number = [i for i in range(number)]
        self.block = block
        self.data = []
        for i in range(100):
            self.data.append(np.random.randint(0, 256, (2, 2, 3), dtype="uint8"))

    def __iter__(self):
        worker_info = get_worker_info()
        per_worker = int(math.ceil((len(self.data)) / float(worker_info.worker)))
        pre_data = iter(partition(self.data, per_worker)[worker_info.idx])
        pre_cnt = partition(self.number, per_worker)[worker_info.idx]
        for cnt in pre_cnt:
            if self.block:
                for _ in range(10):
                    time.sleep(1)
            yield (next(pre_data), cnt)
        return


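# With parallel_stream=True each worker iterates only its own shard, so the
# 10 batches must still have unique counters and the expected CHW shapes.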
@pytest.mark.skipif(
    np.__version__ >= "1.20.0",
    reason="pyarrow is incompatible with numpy vserion 1.20.0",
)
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
def test_prestream_dataloader_multiprocessing():
    dataset = MyPreStream(100)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset,
        sampler,
        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
        num_workers=2,
        parallel_stream=True,
    )

    check_set = set()

    for step, data in enumerate(dataloader):
        if step == 10:
            break
        assert data[0].shape == (4, 3, 2, 2)
        assert data[1].shape == (4,)
        for i in data[1]:
            assert i not in check_set
            check_set.add(i)


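# A raising transform in parallel_stream mode must likewise surface as a
# worker that exited unexpectedly.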
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader do not support parallel on windows",
)
@pytest.mark.skipif(
    multiprocessing.get_start_method() != "fork",
    reason="the runtime error is only raised when fork",
)
def test_predataloader_parallel_worker_exception():
    dataset = MyPreStream(100)

    class FakeErrorTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            raise RuntimeError("test raise error")
            return input

    dataloader = DataLoader(
        dataset,
        sampler=StreamSampler(batch_size=4),
        transform=FakeErrorTransform(),
        num_workers=2,
        parallel_stream=True,
    )
    with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)
        print(batch_data.shape)