dataloader_iter.py 33.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time
import signal
19
import numbers
20 21 22
import logging
import itertools
import threading
23
import warnings
24
import numpy as np
25
from collections import namedtuple
26 27 28 29 30
from paddle.fluid.framework import (
    _set_expected_place,
    _current_expected_place,
    set_flags,
)
31

T
tianshuo78520a 已提交
32
import queue
33

34
import paddle
C
chenjian 已提交
35
import paddle.profiler as profiler
36
from paddle.profiler.utils import in_profiler_mode
37
from .. import core, layers
姜永久 已提交
38
from ..framework import in_dygraph_mode
39 40 41 42 43
from ..multiprocess_utils import (
    _set_SIGCHLD_handler,
    MP_STATUS_CHECK_INTERVAL,
    CleanupFuncRegistrar,
)
44
from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
45
from .batch_sampler import _InfiniteIterableSampler
46
from .collate import default_collate_fn, default_convert_fn
47 48 49 50 51 52 53 54 55
from .worker import (
    ParentWatchDog,
    get_worker_info,
    _worker_loop,
    _DatasetKind,
    _IterableDatasetStopIteration,
    _WorkerException,
    _ResumeIteration,
)
56
from .flat import _flatten_batch, _restore_batch
Z
Zhang Ting 已提交
57
from paddle.profiler.timer import benchmark
58 59

__all__ = ['get_worker_info']
60

61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
# NOTE: fix `terminate called without an active exception`.
# If a for loop breaks and the program exits immediately (with no model
# layers processing) after iterating over **the first few data** in
# distributed launch mode, distributed launch calls terminate() to kill
# the main process on each device, while the reader thread is still
# iterating to fill the blocking queue caches. This may cause the thread
# error `terminate called without an active exception`, because terminate
# is a strong signal and `__del__` of the DataLoader may never be called.
# So we keep a global link to the last DataLoader iterator and call its
# `__del__` to clean up resources.
# NOTE: we cannot simply register each `__del__` with CleanupFuncRegistrar,
# because that would keep a global reference to every DataLoader instance,
# preventing GC from collecting DataLoader instances and causing a memory
# leak.
# Holds the most recently constructed _DataLoaderIterSingleProcess (it
# assigns itself here at the end of its __init__), or None.
_loader = None


def _clear_loader():
    """Release the most recently created single-process iterator, if any.

    This is registered with ``CleanupFuncRegistrar`` at module level. It is a
    best-effort cleanup: errors raised by the iterator's ``__del__`` are
    swallowed. The global reference is always rebound to ``None``, never
    deleted, so calling this function again is a harmless no-op instead of a
    NameError.
    """
    global _loader
    if _loader is None:
        return
    try:
        _loader.__del__()
    except Exception:
        # best-effort cleanup, must never propagate
        pass
    finally:
        # rebind instead of ``del`` so the module-level name stays defined
        _loader = None


# Hand _clear_loader to CleanupFuncRegistrar so it runs as part of the
# registrar's cleanup (presumably at process exit -- see multiprocess_utils).
CleanupFuncRegistrar.register(_clear_loader)

91

92
class _DataLoaderIterBase:
93 94 95 96 97 98 99 100 101 102 103 104 105 106
    """
    Iterator implement of DataLoader, will load and feed mini-batch
    data by setting in given dataloader.

    Args:
        loader(instance of DataLoader): instance of `fluid.io.DataLoader`
    """

    def __init__(self, loader):
        self._dataset = loader.dataset
        self._feed_list = loader.feed_list or []
        self._places = loader.places
        self._return_list = loader.return_list
        self._batch_sampler = loader.batch_sampler
107
        self._drop_last = loader.drop_last
108
        self._auto_collate_batch = loader.auto_collate_batch
109 110
        self._num_workers = loader.num_workers
        self._use_buffer_reader = loader.use_buffer_reader
111
        self._prefetch_factor = loader.prefetch_factor
112
        self._use_shared_memory = loader.use_shared_memory
113 114 115
        self._timeout = (
            loader.timeout if loader.timeout > 0 else MP_STATUS_CHECK_INTERVAL
        )
116
        self._worker_init_fn = loader.worker_init_fn
117
        self._dataset_kind = loader.dataset_kind
118
        self._pin_memory = loader.pin_memory
119

K
Kaipeng Deng 已提交
120
        self._sampler_iter = iter(self._index_sampler)
121 122 123
        if self._auto_collate_batch:
            self._collate_fn = loader.collate_fn or default_collate_fn
        else:
124
            self._collate_fn = loader.collate_fn or default_convert_fn
125

126 127 128 129 130 131 132 133 134
        # LoDTensorBlockingQueue instance for create_py_reader and a thread
        # to put mini-batch data to self._blocking_queue, mini-batch data
        # will be get from:
        # 1. multi-process mode: get data from workers' result queue
        # 2. single-process mode: read mini-batch data in main process
        self._blocking_queue = None
        self._thread = None
        self._thread_done_event = threading.Event()

K
Kaipeng Deng 已提交
135 136 137 138 139 140 141 142 143 144
    @property
    def _index_sampler(self):
        if self._auto_collate_batch:
            return self._batch_sampler
        else:
            if self._dataset_kind == _DatasetKind.MAP:
                return list(range(len(self._dataset)))
            else:
                return _InfiniteIterableSampler(self._dataset, 1)

145 146 147 148 149 150
    def __iter__(self):
        return self

    def __len__(self):
        return len(self._batch_sampler)

151 152 153 154 155 156 157 158 159 160
    def _exit_thread_expectedly(self):
        self._thread_done_event.set()
        if self._blocking_queue:
            self._blocking_queue.close()

    def _exit_thread_unexpectedly(self):
        self._thread_done_event.set()
        if self._blocking_queue:
            self._blocking_queue.kill()

161 162 163 164 165 166 167 168

class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
    """
    Single-process implementation of DataLoaderIter.

    A daemon thread in the main process fetches mini-batches from
    ``loader.dataset`` and pushes them into a LoDTensorBlockingQueue, from
    which ``__next__`` reads them.
    """

    def __init__(self, loader):
        super().__init__(loader)

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind,
            self._dataset,
            self._auto_collate_batch,
            self._collate_fn,
            self._drop_last,
        )

        # NOTE: _structure_infos records the nested structure of every batch
        # so it can be restored after the flattened tensors are read back
        # from the blocking queue. Only one producer thread exists in
        # single-process mode, so batches arrive in order and a FIFO list is
        # enough; no send/receive index needs to be tracked.
        self._structure_infos = []

        # NOTE: len(self._places) batches compose one output iteration, so
        # the blocking queue caches at most ``self._prefetch_factor``
        # iterations of data.
        self._blocking_queue_capacity = self._prefetch_factor * len(
            self._places
        )

        # Assigned before _init_thread() so that __del__/_try_shutdown_all
        # can still release the blocking queue if thread setup raises.
        self._shutdown = False
        self._init_thread()

        # global link used by the module-level _clear_loader
        global _loader
        _loader = self

    def _init_thread(self):
        """Create the blocking queue and py_reader, then start the reader thread."""
        self._var_names = [v.name for v in self._feed_list]
        self._shapes = [v.shape for v in self._feed_list]
        self._dtypes = [v.dtype for v in self._feed_list]
        self._need_check_feed = [
            v.desc.need_check_feed() for v in self._feed_list
        ]
        # if only 1 place, do not need to keep order
        self._blocking_queue = core.init_lod_tensor_blocking_queue(
            core.Variable(),
            self._blocking_queue_capacity,
            len(self._places) > 1,
        )
        self._reader = core.create_py_reader(
            self._blocking_queue,
            self._var_names,
            self._shapes,
            self._dtypes,
            self._need_check_feed,
            self._places,
            self._use_buffer_reader,
            True,
            self._pin_memory,
        )

        self._thread = threading.Thread(
            target=self._thread_loop, args=(_current_expected_place(),)
        )
        self._thread.daemon = True
        self._thread.start()

    def _thread_loop(self, legacy_expected_place):
        """Reader-thread body: fetch, flatten, pack and push batches.

        Stops when the sampler is exhausted, the fetcher returns None, or
        ``_thread_done_event`` is set. Errors while packing are re-raised
        after killing the blocking queue.
        """
        # NOTE(zhiqiu): Set the expected place for new thread as the same as father thread,
        # and it will call platform::SetDeviceId() in c++ internally.
        # If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0,
        # Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda
        # APIs in this thread.
        core.set_current_thread_name("Dataloader_" + str(id(self)))
        _set_expected_place(legacy_expected_place)

        while not self._thread_done_event.is_set():
            try:
                indices = next(self._sampler_iter)
                # read data from dataset in mini-batch
                batch = self._dataset_fetcher.fetch(
                    indices, self._thread_done_event
                )
            except StopIteration:
                self._exit_thread_expectedly()
                return

            if batch is None or self._thread_done_event.is_set():
                break

            # flatten batch and record its structure for __next__
            batch, structure = _flatten_batch(batch)
            self._structure_infos.append(structure)

            if self._thread_done_event.is_set():
                break

            try:
                # pack as LoDTensorArray
                array = core.LoDTensorArray()
                for slot in batch:
                    if isinstance(slot, (paddle.Tensor, core.eager.Tensor)):
                        slot = slot.value().get_tensor()
                    elif not isinstance(slot, core.LoDTensor):
                        tmp = core.LoDTensor()
                        tmp.set(slot, core.CPUPlace())
                        slot = tmp

                    array.append(slot)

                if self._thread_done_event.is_set():
                    break

                try:
                    self._blocking_queue.push(array)
                except Exception:
                    # the queue was closed/killed (or already reset to None
                    # by _try_shutdown_all): stop producing quietly
                    self._exit_thread_expectedly()

            except Exception:
                self._exit_thread_unexpectedly()
                raise

        self._exit_thread_expectedly()

    def __next__(self):
        """Return the next batch; on exhaustion shut everything down and re-raise StopIteration."""
        # track the event itself, so end() is called iff begin() was called
        # even if profiler mode changes while reading
        trace_event = None
        if in_profiler_mode():
            trace_event = profiler.RecordEvent(
                name="_DataLoaderIterSingleProcess",
                event_type=profiler.TracerEventType.Dataloader,
            )
            trace_event.begin()
        try:
            benchmark().check_if_need_record(self)
            benchmark().before_reader()
            if in_dygraph_mode():
                data = core.eager.read_next_tensor_list(
                    self._reader.read_next_list()[0]
                )
                data = _restore_batch(data, self._structure_infos.pop(0))
            else:
                # in static graph mode
                if self._return_list:
                    data = [
                        item._move_to_list()
                        for item in self._reader.read_next_list()
                    ]
                    structs = [
                        self._structure_infos.pop(0)
                        for _ in range(len(self._places))
                    ]
                    data = [_restore_batch(d, s) for d, s in zip(data, structs)]
                    # static graph organizes data on multiple devices as a
                    # list; with a single place, unwrap it to be compatible
                    # with dygraph mode
                    if len(self._places) == 1:
                        data = data[0]
                else:
                    data = self._reader.read_next()
            benchmark().after_reader()

            return data
        except StopIteration:
            self._reader.shutdown()
            self._try_shutdown_all()
            raise
        finally:
            if trace_event is not None:
                trace_event.end()

    def _shutdown_thread(self):
        """Stop the reader thread and drop the reference to it."""
        if self._thread:
            self._thread_done_event.set()
            # NOTE: give _thread up to 3 seconds to exit on its own; if it is
            # still alive afterwards, block in join() until it finishes
            # (skipped when called from that very thread, where join() would
            # deadlock).
            for _ in range(3):
                if self._thread.is_alive():
                    time.sleep(1)
                else:
                    break
            else:
                if self._thread is not threading.current_thread():
                    self._thread.join()

            self._thread = None

    def _try_shutdown_all(self):
        """Close the blocking queue and stop the reader thread; runs at most once."""
        if not self._shutdown:
            try:
                # _blocking_queue in keep-order mode holds sub-threads, so
                # their resources must be released on unexpected exit too
                if self._blocking_queue:
                    self._blocking_queue.close()
                    self._blocking_queue = None
                # NOTE: the blocking queue must be closed first, because a
                # blocked queue read may hang so that _thread_done_event can
                # never be checked
                self._shutdown_thread()
            finally:
                self._shutdown = True

    def __del__(self):
        # __init__ may have failed before ``_shutdown`` was assigned (e.g. in
        # the base-class constructor); nothing has been allocated then.
        if hasattr(self, "_shutdown"):
            self._try_shutdown_all()


class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
    def __init__(self, loader):
        """Set up indices bookkeeping, spawn workers, prefetch indices and start the reader thread."""
        super().__init__(loader)

        self._persistent_workers = loader._persistent_workers
        self._resume_worker_cnt = 0

        assert (
            self._num_workers > 0
        ), "Multi-process DataLoader invalid num_workers({})".format(
            self._num_workers
        )

        # subprocess workers' result queue
        self._data_queue = None

        # data fetched from _data_queue is reordered by _rcvd_idx: batches
        # whose index differs from _rcvd_idx are cached in _task_infos until
        # their turn comes, to keep the data order
        self._send_idx = 0
        self._rcvd_idx = 0
        self._batches_outstanding = 0
        self._task_infos = {}
        self._structure_infos = []

        # At first, _outstanding_capacity batches of indices are in flight,
        # and the blocking queue capacity is also _outstanding_capacity. This
        # makes sure each indices queue holds at least "_prefetch_factor"
        # indices, and that outstanding batches cache output data for at
        # least "_prefetch_factor" iterations (note that len(_places)
        # batches are composed into one iteration output).
        self._outstanding_capacity = self._prefetch_factor * max(
            self._num_workers, len(self._places)
        )

        # see _try_put_indices
        self._thread_lock = threading.Lock()

        self._base_seed = np.random.randint(low=0, high=sys.maxsize)

        # Note(zhangbo): shm_buffer_size is used for MemoryMapAllocationPool.
        # MemoryMapAllocationPool is used to cache and reuse shm, thus reducing munmap in dataloader.
        # For more details, please see: paddle/fluid/memory/allocation/mmap_allocator.h
        # os.environ values are always str, so only string spellings can match.
        if os.environ.get('FLAGS_use_shm_cache', False) in ('1', 'True', 'true'):
            try:
                # NOTE(review): the rationale for the (2 + 1) factor is not
                # documented here
                self._worker_shm_buffer_size = (2 + 1) * len(self._dataset[0])
            except Exception:
                self._worker_shm_buffer_size = 0
                warnings.warn(
                    "Setting the shm cache buffer size to 0, equivalent to not using the shm cache policy."
                )
        else:
            self._worker_shm_buffer_size = 0
        self._main_thread_shm_buffer_size = (
            self._worker_shm_buffer_size * 2 * self._num_workers
        )

        # init workers and indices queues, then put _outstanding_capacity
        # batches of indices into the indices queues
        self._init_workers()
        for _ in range(self._outstanding_capacity):
            self._try_put_indices()

        self._init_thread()
        self._shutdown = False

    def _init_workers(self):
        """Spawn ``self._num_workers`` daemon worker processes.

        Every worker gets a dedicated indices queue, and all of them share
        ``self._data_queue`` for results. The worker pids are then reported
        via ``core._set_process_pids`` and the SIGCHLD handler is installed.
        """
        import paddle.incubate.multiprocessing as multiprocessing

        # worker processes, their availability flags and their indices queues
        self._workers = []
        self._worker_status = []
        self._indices_queues = []
        self._workers_idx_cycle = itertools.cycle(range(self._num_workers))

        # result queue shared by all workers
        self._data_queue = multiprocessing.Queue()

        # stop events for the workers and for the reader thread (the thread
        # event is only needed in multi-processing mode)
        self._workers_done_event = multiprocessing.Event()
        self._thread_done_event = threading.Event()

        for worker_id in range(self._num_workers):
            worker_indices_queue = multiprocessing.Queue()
            self._indices_queues.append(worker_indices_queue)
            loop_args = (
                self._dataset,
                self._dataset_kind,
                worker_indices_queue,
                self._data_queue,
                self._workers_done_event,
                self._auto_collate_batch,
                self._collate_fn,
                self._drop_last,
                self._worker_init_fn,
                worker_id,
                self._num_workers,
                self._use_shared_memory,
                self._base_seed,
                self._worker_shm_buffer_size,
            )
            proc = multiprocessing.Process(target=_worker_loop, args=loop_args)
            proc.daemon = True
            proc.start()
            self._workers.append(proc)
            self._worker_status.append(True)

        worker_pids = tuple(proc.pid for proc in self._workers)
        core._set_process_pids(id(self), worker_pids)
        _set_SIGCHLD_handler()

    def _clear_and_remove_data_queue(self):
        """Drain and close the workers' result queue, if it was created.

        Items are discarded with ``get_nowait`` until it raises (normally
        ``queue.Empty``). Then ``cancel_join_thread`` and ``close`` are called
        on the queue.
        """
        if self._data_queue is None:
            return
        try:
            while True:
                self._data_queue.get_nowait()
        except Exception:
            # queue.Empty once drained, or any error from a broken queue:
            # either way stop draining and close it
            self._data_queue.cancel_join_thread()
            self._data_queue.close()

    def _init_thread(self):
        """Create the blocking queue and py_reader, then start the reader thread.

        The blocking queue holds up to ``self._outstanding_capacity``
        batches. The memory-map allocation pool is sized with
        ``self._main_thread_shm_buffer_size``.
        """
        feeds = self._feed_list
        self._var_names = [var.name for var in feeds]
        self._shapes = [var.shape for var in feeds]
        self._dtypes = [var.dtype for var in feeds]
        self._need_check_feed = [var.desc.need_check_feed() for var in feeds]

        # order only has to be kept when there is more than one place
        keep_order = len(self._places) > 1
        self._blocking_queue = core.init_lod_tensor_blocking_queue(
            core.Variable(), self._outstanding_capacity, keep_order
        )
        core._set_max_memory_map_allocation_pool_size(
            self._main_thread_shm_buffer_size
        )
        self._reader = core.create_py_reader(
            self._blocking_queue,
            self._var_names,
            self._shapes,
            self._dtypes,
            self._need_check_feed,
            self._places,
            self._use_buffer_reader,
            True,
            self._pin_memory,
        )

        # fresh stop event for the reader thread (only needed in
        # multi-processing mode)
        self._thread_done_event = threading.Event()
        reader_thread = threading.Thread(
            target=self._thread_loop, args=(_current_expected_place(),)
        )
        reader_thread.daemon = True
        self._thread = reader_thread
        reader_thread.start()

    def _reset(self):
        """Resume iteration for a new epoch without restarting workers or the reader thread."""
        # 1. Resume workers and clear worker caches: put a _ResumeIteration
        #    flag into every worker's indices queue.
        with self._thread_lock:
            self._resume_worker_cnt = self._num_workers
            for worker_id in range(self._num_workers):
                self._indices_queues[worker_id].put(_ResumeIteration())
                self._batches_outstanding += 1
        # Every flag is checked (and the counter decremented) in _thread_loop,
        # so simply wait here.
        # NOTE(review): in the visible code only _thread_loop decrements the
        # counter; if that thread has already exited, this loop never ends.
        while self._resume_worker_cnt > 0:
            time.sleep(0.5)

        # 2. Clear the blocking_queue caches. To avoid restarting the thread,
        #    drain the existing queue instead of recreating it; the data read
        #    here is discarded.
        while self._blocking_queue.size() >= len(self._places):
            if in_dygraph_mode():
                core.eager.read_next_tensor_list(
                    self._reader.read_next_list()[0]
                )
            elif self._return_list:
                self._reader.read_next_list()
            else:
                self._reader.read_next()

        # 3. reset all states
        self._send_idx = 0
        self._rcvd_idx = 0
        self._batches_outstanding = 0
        self._task_infos = {}
        self._structure_infos = []

        # set all worker status available
        self._worker_status = [True] * self._num_workers

        # 4. reset _sampler_iter and put _outstanding_capacity batches of
        #    prefetch indices to start the next epoch
        self._sampler_iter = iter(self._index_sampler)
        for _ in range(self._outstanding_capacity):
            self._try_put_indices()

    def _shutdown_worker(self, worker_id, shutdown=False):
        """Send the ``None`` exit sentinel to one worker and mark it unavailable.

        Workers already marked unavailable are skipped, unless persistent
        workers are in use and ``shutdown`` is True.
        """
        if not self._worker_status[worker_id] and not (
            self._persistent_workers and shutdown
        ):
            return
        self._indices_queues[worker_id].put(None)
        self._worker_status[worker_id] = False

    def _try_shutdown_all(self, timeout=None):
        """Stop the reader thread and all workers, releasing their queues.

        Runs at most once; later calls are no-ops.

        Args:
            timeout: passed to ``join`` for each worker process (None waits
                without a time limit).
        """
        if self._shutdown:
            return
        try:
            self._exit_thread_expectedly()
            self._clear_and_remove_data_queue()

            # _workers_done_event must be set before None is put into the
            # indices queues: workers exit on reading None from them
            self._workers_done_event.set()
            for i in range(self._num_workers):
                self._shutdown_worker(i, shutdown=True)

            for w in self._workers:
                w.join(timeout)
            for q in self._indices_queues:
                q.cancel_join_thread()
                q.close()
        finally:
            core._erase_process_pids(id(self))
            self._shutdown = True

    def _thread_loop(self, legacy_expected_place):
        """Reader-thread body: move in-order batches from workers to the blocking queue.

        ``_get_data`` returning None ends the loop (expected exit).
        _ResumeIteration flags decrement ``_resume_worker_cnt``. Any error
        while packing or pushing kills the blocking queue and is re-raised.
        """
        # NOTE(zhiqiu): Set the expected place for new thread as the same as father thread,
        # and it will call platform::SetDeviceId() in c++ internally.
        # If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0,
        # Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda
        # APIs in this thread.
        core.set_current_thread_name("Dataloader_" + str(id(self)))
        _set_expected_place(legacy_expected_place)

        while not self._thread_done_event.is_set():
            batch = self._get_data()
            if self._thread_done_event.is_set():
                # the loop condition ends the thread
                continue

            if batch is None:
                # data drained
                self._exit_thread_expectedly()
                continue

            if isinstance(batch, _ResumeIteration):
                # one worker acknowledged a resume flag put by _reset
                assert self._resume_worker_cnt > 0
                self._resume_worker_cnt -= 1
                continue

            try:
                # pack as LoDTensorArray
                array = core.LoDTensorArray()
                if self._use_shared_memory:
                    for tensor in batch:
                        array.append(tensor)
                else:
                    # LoDTensor not in shared memory is not serializable and
                    # cannot be created in workers, so convert here
                    for slot in batch:
                        if isinstance(
                            slot, (paddle.Tensor, core.eager.Tensor)
                        ):
                            slot = slot.value().get_tensor()
                        elif not isinstance(slot, core.LoDTensor):
                            converted = core.LoDTensor()
                            converted.set(slot, core.CPUPlace())
                            slot = converted
                        array.append(slot)

                pushed = self._blocking_queue.push(array)
                if not pushed:
                    self._blocking_queue.close()
            except Exception:
                self._exit_thread_unexpectedly()
                raise
            finally:
                self._rcvd_idx += 1

    def _get_data(self):
        """Return the next batch in ``_rcvd_idx`` order.

        Returns:
            The batch (its structure is appended to ``_structure_infos``),
            a ``_ResumeIteration`` instance when a worker acknowledges a
            resume flag, or None when data is drained or
            ``_thread_done_event`` is set.

        Raises:
            RuntimeError: if some workers exited unexpectedly.
            Whatever a worker reports through ``_WorkerException``, or an
            unexpected error from reading ``_data_queue``.
        """
        while not self._thread_done_event.is_set():
            # For IterableDataset, batch indices are generated infinitely
            # for each worker to raise StopIteration, but a StopIteration
            # raising process discards one batch of indices that is counted
            # in _send_idx but will never increase _rcvd_idx. So check here
            # whether the worker is still alive, to skip the discarded batch
            # indices and increase _rcvd_idx.
            if self._dataset_kind == _DatasetKind.ITER:
                while self._rcvd_idx < self._send_idx:
                    info = self._task_infos[self._rcvd_idx]
                    if len(info) == 3 or self._worker_status[info[0]]:
                        break
                    del self._task_infos[self._rcvd_idx]
                    self._rcvd_idx += 1
                    self._batches_outstanding -= 1
                else:
                    # NOTE: when _rcvd_idx catches up with _send_idx, either
                    #       1. all outstanding batches have been loaded and
                    #          stored in _blocking_queue, or
                    #       2. all data is drained.
                    #       Let _thread block on _data_queue.get so it does
                    #       not occupy CPU time needed by the model.
                    # NOTE: in persistent workers mode, do not check for
                    #       drained data here; go on to read _data_queue to
                    #       receive the _ResumeIteration flags.
                    if not self._persistent_workers:
                        # NOTE: _rcvd_idx and _send_idx only count batches
                        #       among workers; even when those are drained
                        #       there may still be data in the blocking queue
                        if self._batches_outstanding < len(self._places):
                            return None

            if (
                self._rcvd_idx in self._task_infos
                and len(self._task_infos[self._rcvd_idx]) == 3
            ):
                info = self._task_infos.pop(self._rcvd_idx)
                self._structure_infos.append(info[2])
                return info[1]

            try:
                # [ avoid hang ]: main process may blocking at _reader.read_next when
                # KeyboardInterrupt, we do following tradeoff:
                # 1. get data with timeout, MP_STATUS_CHECK_INTERVAL(5s) as timeout
                #    default, if KeyboardInterrupt blocking, failed workers will be
                #    checked and raise RuntimeError to quit DataLoader in timeout
                #    exception handling.
                # 2. if get data timeout and check workers all alive, continue to
                #    get data again
                data = self._data_queue.get(timeout=self._timeout)
            except Exception as e:
                # check if thread done event set when waiting data
                if self._thread_done_event.is_set():
                    continue

                # check failed workers
                failed_workers = []
                for i, w in enumerate(self._workers):
                    if self._worker_status[i] and not w.is_alive():
                        failed_workers.append(w)
                        self._shutdown_worker(i)
                if failed_workers:
                    self._exit_thread_unexpectedly()
                    pids = ', '.join(str(w.pid) for w in failed_workers)
                    raise RuntimeError(
                        "DataLoader {} workers exit unexpectedly, "
                        "pids: {}".format(len(failed_workers), pids)
                    )

                # get(timeout) will call _poll(timeout) and may raise IOError;
                # on timeout keep getting data from the queue
                if isinstance(e, (queue.Empty, IOError)):
                    continue

                self._exit_thread_unexpectedly()
                logging.error(
                    "DataLoader reader thread failed(%s) to read data from "
                    "workers' result queue.",
                    e,
                )
                raise
            else:
                if self._dataset_kind == _DatasetKind.ITER and isinstance(
                    data, _IterableDatasetStopIteration
                ):
                    # A worker hit StopIteration: shut it down. The batch of
                    # indices that triggered StopIteration is discarded, so
                    # the outstanding batch number is decreased, and another
                    # batch of indices is put for the workers that may still
                    # be working.
                    if self._persistent_workers:
                        self._worker_status[data.worker_id] = False
                    else:
                        self._shutdown_worker(data.worker_id)
                        self._batches_outstanding -= 1
                    self._try_put_indices()
                    continue

                idx, batch, structure = data

                if (
                    isinstance(idx, _ResumeIteration)
                    and batch is None
                    and structure is None
                ):
                    return idx

                if isinstance(batch, _WorkerException):
                    self._exit_thread_unexpectedly()
                    batch.reraise()

                if idx == self._rcvd_idx:
                    del self._task_infos[idx]
                    self._structure_infos.append(structure)
                    return batch
                else:
                    # out of order: cache until _rcvd_idx reaches idx
                    self._task_infos[idx] += (batch, structure)
                    continue

    def _try_put_indices(self):
774 775 776
        assert (
            self._batches_outstanding <= self._outstanding_capacity
        ), "too many indices have been put to queue"
777 778 779 780 781 782 783 784 785 786 787 788 789 790
        # In multi-process mode for IterableDataset, _try_put_indices will
        # be called both in main process(for our implement has blocking queue,
        # and blocking queue read is in main process) and thread, which may
        # cause error following error
        #   1. "ValueError: generator already executing" in next(self._sampler_iter)
        #   2. re-enter in increase _send_idx
        # add a lock for threading save, for _try_put_indices is only a slight
        # function which is not in data reading pipeline, this lock almost no
        # influence on performance
        with self._thread_lock:
            try:
                indices = next(self._sampler_iter)
            except StopIteration:
                return
791

792 793 794 795 796 797
            for i in range(self._num_workers):
                worker_idx = next(self._workers_idx_cycle)
                if self._worker_status[worker_idx]:
                    break
            else:
                return
798

799
            self._indices_queues[worker_idx].put((self._send_idx, indices))
800
            self._task_infos[self._send_idx] = (worker_idx,)
801 802
            self._batches_outstanding += 1
            self._send_idx += 1
803 804 805 806

    def __del__(self):
        """Finalizer: run ``_try_shutdown_all()`` when this iterator is collected.

        Delegates with no arguments to ``_try_shutdown_all`` (defined
        elsewhere in this class). NOTE(review): presumably this tears down the
        worker processes and queues -- confirm against ``_try_shutdown_all``.
        """
        self._try_shutdown_all()

807 808 809
    def _shutdown_on_exit(self):
        """Shutdown hook that calls ``_try_shutdown_all(1)``.

        NOTE(review): presumably registered as an exit-time cleanup (this
        module imports ``CleanupFuncRegistrar``), with the ``1`` bounding how
        long the shutdown may wait -- confirm against the signature of
        ``_try_shutdown_all``.
        """
        self._try_shutdown_all(1)

810
    def __next__(self):
        """Return the next batch produced by the worker processes.

        In dygraph mode the tensor list read for the first place is restored
        to its original nested structure. In static-graph mode with
        ``return_list`` set, one restored batch per place is returned as a
        list (unwrapped when there is a single place); otherwise the value of
        ``self._reader.read_next()`` is returned unchanged.

        Raises:
            StopIteration: when fewer batches are outstanding than there are
                places and workers are persistent, or when the underlying
                reader is drained. Without persistent workers the reader and
                all workers are shut down before the exception propagates.
        """
        # Query profiler mode only once. Re-checking it in ``finally`` (as
        # this method used to) hits an unbound ``trace_event`` -- a NameError
        # that masks the real result/exception -- if profiling is switched on
        # mid-call, and leaves a begun event un-ended if it is switched off.
        trace_event = None
        if in_profiler_mode():
            trace_event = profiler.RecordEvent(
                name="_DataLoaderIterMultiProcess",
                event_type=profiler.TracerEventType.Dataloader,
            )
            trace_event.begin()
        try:
            benchmark().check_if_need_record(self)
            benchmark().before_reader()
            # ``_batches_outstanding`` counts batches whose indices were sent
            # by ``_try_put_indices`` but not yet output; while data is not
            # drained it should equal ``_outstanding_capacity``. Below the
            # number of places there is not enough data for another output:
            # persistent workers simply end this epoch; otherwise set
            # ``_thread_done_event`` and close the blocking queue so the
            # reader raises StopIteration, and workers / indices queues are
            # torn down in the StopIteration handler below.
            if self._batches_outstanding < len(self._places):
                if self._persistent_workers:
                    raise StopIteration
                self._thread_done_event.set()
                self._blocking_queue.close()

            if in_dygraph_mode():
                data = core.eager.read_next_tensor_list(
                    self._reader.read_next_list()[0]
                )
                data = _restore_batch(data, self._structure_infos.pop(0))
            elif self._return_list:
                data = [
                    item._move_to_list()
                    for item in self._reader.read_next_list()
                ]
                structs = [
                    self._structure_infos.pop(0)
                    for _ in range(len(self._places))
                ]
                data = [_restore_batch(d, s) for d, s in zip(data, structs)]
                # Static graph organizes per-device data as a list; with a
                # single place, unwrap it so the result matches dygraph mode.
                if len(self._places) == 1:
                    data = data[0]
            else:
                # NOTE(review): this branch does not pop ``_structure_infos``,
                # which ``_get_data`` appends to for every received batch --
                # confirm the list is not left growing in this mode.
                data = self._reader.read_next()
            self._on_output_batch()
            benchmark().after_reader()
            return data
        except StopIteration:
            if not self._persistent_workers:
                self._reader.shutdown()
                self._try_shutdown_all()
            raise
        finally:
            if trace_event is not None:
                trace_event.end()

    def _on_output_batch(self):
        for _ in range(len(self._places)):
            self._batches_outstanding -= 1
            self._try_put_indices()