未验证 提交 dbc88bb9 编写于 作者: K Kaipeng Deng 提交者: GitHub

Add iterable dataset support for multiprocess DataLoader (#25558)

* add IterableDataset support in multiprocess DataLoader. test=develop
上级 54003b87
...@@ -20,5 +20,9 @@ from .dataset import * ...@@ -20,5 +20,9 @@ from .dataset import *
from . import batch_sampler from . import batch_sampler
from .batch_sampler import * from .batch_sampler import *
from . import dataloader_iter
from .dataloader_iter import *
__all__ = dataset.__all__ \ __all__ = dataset.__all__ \
+ batch_sampler.__all__ + batch_sampler.__all__ \
+ dataloader_iter.__all__
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from .dataset import Dataset from .dataset import Dataset, IterableDataset
__all__ = ["BatchSampler"] __all__ = ["BatchSampler"]
...@@ -106,12 +106,18 @@ class BatchSampler(object): ...@@ -106,12 +106,18 @@ class BatchSampler(object):
assert isinstance(indices, list) or isinstance(indices, tuple), \ assert isinstance(indices, list) or isinstance(indices, tuple), \
"indices should be a list or tuple, but got {}".format(type(indices)) "indices should be a list or tuple, but got {}".format(type(indices))
self.indices = indices self.indices = indices
self.sampler_iter = None
else: else:
assert isinstance(dataset, Dataset), \ if isinstance(dataset, IterableDataset):
"dataset should be an instance of paddle.io.Dataset" self.sampler_iter = iter(
assert indices is None, \ _InfiniteIterableSampler(dataset, batch_size))
"should not set both dataset and indices" else:
self.indices = list(range(len(dataset))) self.sampler_iter = None
assert isinstance(dataset, Dataset), \
"dataset should be an instance of paddle.io.Dataset"
assert indices is None, \
"should not set both dataset and indices"
self.indices = list(range(len(dataset)))
assert isinstance(batch_size, int) and batch_size > 0, \ assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer, but got {}".format(batch_size) "batch_size should be a positive integer, but got {}".format(batch_size)
...@@ -124,6 +130,9 @@ class BatchSampler(object): ...@@ -124,6 +130,9 @@ class BatchSampler(object):
self.drop_last = drop_last self.drop_last = drop_last
def __iter__(self): def __iter__(self):
if self.sampler_iter:
yield next(self.sampler_iter)
if self.shuffle: if self.shuffle:
np.random.shuffle(self.indices) np.random.shuffle(self.indices)
_iter = iter(self.indices) _iter = iter(self.indices)
...@@ -138,6 +147,22 @@ class BatchSampler(object): ...@@ -138,6 +147,22 @@ class BatchSampler(object):
yield batch_indices yield batch_indices
def __len__(self): def __len__(self):
if self.sampler_iter:
raise RuntimeError("'{}' should not be called for IterableDataset".
format('__len__'))
num_samples = len(self.indices) num_samples = len(self.indices)
num_samples += int(not self.drop_last) * (self.batch_size - 1) num_samples += int(not self.drop_last) * (self.batch_size - 1)
return num_samples // self.batch_size return num_samples // self.batch_size
class _InfiniteIterableSampler(object):
def __init__(self, dataset, batch_size=1):
assert isinstance(
dataset, IterableDataset
), "dataset should be an instance of paddle.io.IterableDataset"
self.dataset = dataset
self.batch_size = batch_size
def __iter__(self):
while True:
yield [None] * self.batch_size
...@@ -22,6 +22,7 @@ import itertools ...@@ -22,6 +22,7 @@ import itertools
import threading import threading
import numpy as np import numpy as np
import multiprocessing import multiprocessing
from collections import namedtuple
# NOTE: queue has a different name in python2 and python3 # NOTE: queue has a different name in python2 and python3
if six.PY2: if six.PY2:
...@@ -32,11 +33,17 @@ else: ...@@ -32,11 +33,17 @@ else:
from .. import core from .. import core
from ..framework import in_dygraph_mode from ..framework import in_dygraph_mode
from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCHLD_handler from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCHLD_handler
from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
__all__ = ['get_worker_info']
# multi-process worker check indices queue interval, avoid # multi-process worker check indices queue interval, avoid
# hanging in subprocess data loading # hanging in subprocess data loading
MP_INDICES_CHECK_INTERVAL = 5 MP_INDICES_CHECK_INTERVAL = 5
_IterableDatasetStopIteration = namedtuple('_IterableDatasetStopIteration',
['worker_id'])
def default_collate_fn(batch): def default_collate_fn(batch):
""" """
...@@ -75,6 +82,20 @@ def default_collate_fn(batch): ...@@ -75,6 +82,20 @@ def default_collate_fn(batch):
return [np.stack(slot, axis=0) for slot in slots] return [np.stack(slot, axis=0) for slot in slots]
class _DatasetKind(object):
MAP = 0
ITER = 1
@staticmethod
def create_fetcher(kind, dataset, collate_fn, drop_last):
if kind == _DatasetKind.MAP:
return _MapDatasetFetcher(dataset, collate_fn, drop_last)
elif kind == _DatasetKind.ITER:
return _IterableDatasetFetcher(dataset, collate_fn, drop_last)
else:
raise NotImplementedError("unknown Dataset kind {}".format(kind))
class ParentWatchDog(object): class ParentWatchDog(object):
def __init__(self): def __init__(self):
self._parent_pid = os.getppid() self._parent_pid = os.getppid()
...@@ -86,6 +107,92 @@ class ParentWatchDog(object): ...@@ -86,6 +107,92 @@ class ParentWatchDog(object):
return self._parent_alive return self._parent_alive
# worker information for each workers, used for splitting data copy
# for IteratorDataset in worker processes.
_worker_info = None
def get_worker_info():
"""
Get DataLoader worker process information function, this function is
used to split data copy in worker process for IterableDataset
(see :code:`paddle.io.IterableDataset`), worker information contains
following fields:
:attr:`num_workers`: total worker process number, see `paddle.io.DataLoader`
:attr:`id`: the worker processs id, count from 0 to :attr:`num_workers - 1`
:attr:`dataset`: the dataset object in this worker process
Returns:
WorkerInfo: an instance of WorkerInfo which contains fields above.
.. note::
For mode usage and exampls, please see :code:`paddle.io.IterableDataset`
Example:
.. code-block:: python
import math
import numpy as np
import paddle.fluid as fluid
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
print(list(dataloader))
# outputs: [2, 5, 3, 6, 4, 7]
"""
return _worker_info
class WorkerInfo(object):
__initialized = False
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
self.__initialized = True
def __setattr__(self, key, val):
if self.__initialized:
raise RuntimeError("Cannot assign attributes to {} objects".format(
self.__class__.__name__))
return super(WorkerInfo, self).__setattr__(key, val)
class _DataLoaderIterBase(object): class _DataLoaderIterBase(object):
""" """
Iterator implement of DataLoader, will load and feed mini-batch Iterator implement of DataLoader, will load and feed mini-batch
...@@ -108,6 +215,7 @@ class _DataLoaderIterBase(object): ...@@ -108,6 +215,7 @@ class _DataLoaderIterBase(object):
self._use_shared_memory = loader.use_shared_memory self._use_shared_memory = loader.use_shared_memory
self._timeout = loader.timeout if loader.timeout > 0 else MP_INDICES_CHECK_INTERVAL self._timeout = loader.timeout if loader.timeout > 0 else MP_INDICES_CHECK_INTERVAL
self._worker_init_fn = loader.worker_init_fn self._worker_init_fn = loader.worker_init_fn
self._dataset_kind = loader.dataset_kind
self._pin_memory = loader.pin_memory self._pin_memory = loader.pin_memory
# LoDTensorBlockingQueue instance for create_py_reader and a thread # LoDTensorBlockingQueue instance for create_py_reader and a thread
...@@ -135,6 +243,9 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -135,6 +243,9 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
def __init__(self, loader): def __init__(self, loader):
super(_DataLoaderIterSingleProcess, self).__init__(loader) super(_DataLoaderIterSingleProcess, self).__init__(loader)
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._collate_fn, True)
# NOTE: len(self._places) batch data compose as an output # NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache 2 iteration datas # iteration, set blocking_queue can cache 2 iteration datas
# at most here # at most here
...@@ -166,9 +277,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -166,9 +277,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
try: try:
for indices in self._sampler_iter: for indices in self._sampler_iter:
# read data from dataset in mini-batch # read data from dataset in mini-batch
batch = [self._dataset[i] for i in indices] batch = self._dataset_fetcher.fetch(indices)
if self._collate_fn is not None:
batch = self._collate_fn(batch)
# pack as LoDTensorArray # pack as LoDTensorArray
array = core.LoDTensorArray() array = core.LoDTensorArray()
...@@ -186,6 +295,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -186,6 +295,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
self._blocking_queue.close() self._blocking_queue.close()
self._thread = None self._thread = None
except StopIteration:
self._blocking_queue.close()
except Exception: except Exception:
self._blocking_queue.kill() self._blocking_queue.kill()
self._thread = None self._thread = None
...@@ -233,11 +344,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -233,11 +344,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# data get from _data_queue will be reordered by _rcvd_idx # data get from _data_queue will be reordered by _rcvd_idx
# for data order keeping, data index not equal _rcvd_idx # for data order keeping, data index not equal _rcvd_idx
# will be cached in _reorder_dict # will be cached in _task_infos
self._send_idx = 0 self._send_idx = 0
self._rcvd_idx = 0 self._rcvd_idx = 0
self._batches_outstanding = 0 self._batches_outstanding = 0
self._reorder_dict = {} self._task_infos = {}
# indices outstand as _outstanding_capacity at first, and # indices outstand as _outstanding_capacity at first, and
# blocking_queue capacity is also _outstanding_capacity. # blocking_queue capacity is also _outstanding_capacity.
...@@ -248,14 +359,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -248,14 +359,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._outstanding_capacity = 2 * max(self._num_workers, self._outstanding_capacity = 2 * max(self._num_workers,
len(self._places)) len(self._places))
# init workers and indices queues and put 2 indices in each indices queue
self._init_workers() self._init_workers()
self._init_thread()
self._shutdown = False
for _ in range(self._outstanding_capacity): for _ in range(self._outstanding_capacity):
self._try_put_indices() self._try_put_indices()
self._init_thread()
self._shutdown = False
def _init_workers(self): def _init_workers(self):
# multiprocess worker and indice queue list initial as empty # multiprocess worker and indice queue list initial as empty
self._workers = [] self._workers = []
...@@ -276,9 +387,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -276,9 +387,10 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._indices_queues.append(indices_queue) self._indices_queues.append(indices_queue)
worker = multiprocessing.Process( worker = multiprocessing.Process(
target=self._worker_loop, target=self._worker_loop,
args=(self._dataset, indices_queue, self._data_queue, args=(self._dataset, self._dataset_kind, indices_queue,
self._workers_done_event, self._collate_fn, self._data_queue, self._workers_done_event,
self._worker_init_fn, i)) self._collate_fn, self._worker_init_fn, i,
self._num_workers))
worker.daemon = True worker.daemon = True
worker.start() worker.start()
self._workers.append(worker) self._workers.append(worker)
...@@ -353,8 +465,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -353,8 +465,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._blocking_queue.kill() self._blocking_queue.kill()
logging.error("DataLoader reader thread raised an exception!") logging.error("DataLoader reader thread raised an exception!")
def _worker_loop(self, dataset, indices_queue, out_queue, done_event, def _worker_loop(self, dataset, dataset_kind, indices_queue, out_queue,
collate_fn, init_fn, worker_id): done_event, collate_fn, init_fn, worker_id, num_workers):
try: try:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly, # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
# some shared memory objects may have been applied for but have not yet # some shared memory objects may have been applied for but have not yet
...@@ -365,14 +477,21 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -365,14 +477,21 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# set signal handler # set signal handler
core._set_process_signal_handler() core._set_process_signal_handler()
global _worker_info
_worker_info = WorkerInfo(
id=worker_id, num_workers=num_workers, dataset=dataset)
init_exception = None init_exception = None
if init_fn is not None: try:
try: if init_fn is not None:
init_fn(worker_id) init_fn(worker_id)
except: fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
init_exception = Exception("init_fn failed in worker {}: " \ collate_fn, True)
"{}".format(worker_id, sys.exc_info())) except:
init_exception = Exception("init_fn failed in worker {}: " \
"{}".format(worker_id, sys.exc_info()))
iterator_drained = False
parent_watch_dog = ParentWatchDog() parent_watch_dog = ParentWatchDog()
while parent_watch_dog.is_alive(): while parent_watch_dog.is_alive():
...@@ -383,12 +502,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -383,12 +502,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# None as poison piil, so worker event should be set # None as poison piil, so worker event should be set
if data is None: if data is None:
assert done_event.is_set( assert done_event.is_set() or iterator_drained, \
), "get None when worker done_event set" "get None when worker done_event set"
break break
# If worker done event is set but get still get data in # If worker done event is set but get still get data in
# indices_queue, remaining data should be get and skipped. # indices_queue, remaining data should be get and skipped.
if done_event.is_set(): if done_event.is_set() or iterator_drained:
continue continue
idx, indices = data idx, indices = data
...@@ -397,11 +516,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -397,11 +516,15 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
batch = init_exception batch = init_exception
init_exception = None init_exception = None
else: else:
batch = [dataset[i] for i in indices] batch = fetcher.fetch(indices)
if self._collate_fn is not None:
batch = self._collate_fn(batch)
except Exception as e: except Exception as e:
out_queue.put((idx, e)) if isinstance(
e,
StopIteration) and dataset_kind == _DatasetKind.ITER:
out_queue.put(_IterableDatasetStopIteration(worker_id))
iterator_drained = True
else:
out_queue.put((idx, e))
else: else:
if self._use_shared_memory: if self._use_shared_memory:
tensor_list = core._convert_to_tensor_list(batch) tensor_list = core._convert_to_tensor_list(batch)
...@@ -438,7 +561,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -438,7 +561,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
# serializable, cannot be create in workers # serializable, cannot be create in workers
for slot in batch: for slot in batch:
if not isinstance(slot, core.LoDTensor): if not isinstance(slot, core.LoDTensor):
# self._check_input_array(slot)
tmp = core.LoDTensor() tmp = core.LoDTensor()
tmp.set(slot, core.CPUPlace()) tmp.set(slot, core.CPUPlace())
slot = tmp slot = tmp
...@@ -453,10 +575,31 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -453,10 +575,31 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
self._rcvd_idx += 1 self._rcvd_idx += 1
def _get_data(self): def _get_data(self):
if self._rcvd_idx in self._reorder_dict.keys():
return self._reorder_dict.pop(self._rcvd_idx)
while not self._thread_done_event.is_set(): while not self._thread_done_event.is_set():
# For IterableDataset, batch indices is generated infinitely
# for each worker to raise StopIteration, but a StopIteration
# raising process will discard a batch indices which is count
# in _send_idx but will not increase _rcvd_idx, so we check
# whether the worker is still alive here to skip the discarded
# batch indices and increase _rcvd_idx
while self._rcvd_idx < self._send_idx:
info = self._task_infos[self._rcvd_idx]
if len(info) == 2 or self._worker_status[info[0]]:
break
del self._task_infos[self._rcvd_idx]
self._rcvd_idx += 1
self._batches_outstanding -= 1
else:
# NOTE: _rcvd_idx and _send_idx only record batches among
# workers, if batches among workers drained, there
# may also be data in blocking queue
if self._batches_outstanding < len(self._places):
return None
continue
if len(self._task_infos[self._rcvd_idx]) == 2:
return self._task_infos.pop(self._rcvd_idx)[1]
try: try:
# [ avoid hang ]: main process may blocking at _reader.read_next when # [ avoid hang ]: main process may blocking at _reader.read_next when
# KeyboardInterrupt, we do following tradeoff: # KeyboardInterrupt, we do following tradeoff:
...@@ -494,23 +637,43 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -494,23 +637,43 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
"workers' result queue.".format(e)) "workers' result queue.".format(e))
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
else: else:
if self._dataset_kind == _DatasetKind.ITER and isinstance(
data, _IterableDatasetStopIteration):
# if a worker get StopIteraion, we shutdown this worker,
# note that this batch indices to trigger StopIteration
# is discard, outstanding batch number should be decrease
# and another indices should be put for other workers
# may still working.
self._shutdown_worker(data.worker_id)
self._batches_outstanding -= 1
self._try_put_indices()
continue
idx, batch = data idx, batch = data
if idx == self._rcvd_idx: if idx == self._rcvd_idx:
del self._task_infos[idx]
return batch return batch
else: else:
self._reorder_dict[idx] = batch self._task_infos[idx] += (batch, )
continue continue
def _try_put_indices(self): def _try_put_indices(self):
assert self._send_idx - self._rcvd_idx <= self._outstanding_capacity, \ assert self._batches_outstanding <= self._outstanding_capacity, \
"too many indices have been put to queue" "too many indices have been put to queue"
try: try:
indices = next(self._sampler_iter) indices = next(self._sampler_iter)
except StopIteration: except StopIteration:
return return
worker_idx = next(self._workers_idx_cycle) for i in range(self._num_workers):
worker_idx = next(self._workers_idx_cycle)
if self._worker_status[worker_idx]:
break
else:
return
self._indices_queues[worker_idx].put((self._send_idx, indices)) self._indices_queues[worker_idx].put((self._send_idx, indices))
self._task_infos[self._send_idx] = (worker_idx, )
self._batches_outstanding += 1 self._batches_outstanding += 1
self._send_idx += 1 self._send_idx += 1
......
...@@ -16,12 +16,12 @@ from __future__ import print_function ...@@ -16,12 +16,12 @@ from __future__ import print_function
import paddle.dataset.common import paddle.dataset.common
__all__ = ["Dataset"] __all__ = ["Dataset", "IterableDataset"]
class Dataset(object): class Dataset(object):
""" """
An abstract class to encapsulates methods and behaviors of datasets. An abstract class to encapsulate methods and behaviors of datasets.
All datasets in map-style(dataset samples can be get by a given key) All datasets in map-style(dataset samples can be get by a given key)
should be a subclass of `paddle.io.Dataset`. All subclasses should should be a subclass of `paddle.io.Dataset`. All subclasses should
...@@ -71,3 +71,154 @@ class Dataset(object): ...@@ -71,3 +71,154 @@ class Dataset(object):
def __len__(self): def __len__(self):
raise NotImplementedError("'{}' not implement in class "\ raise NotImplementedError("'{}' not implement in class "\
"{}".format('__len__', self.__class__.__name__)) "{}".format('__len__', self.__class__.__name__))
class IterableDataset(Dataset):
"""
An abstract class to encapsulate methods and behaviors of iterable datasets.
All datasets in iterable-style (can only get sample one by one sequentially, like
a Python iterator) should be a subclass of `paddle.io.IterableDataset`. All subclasses should
implement following methods:
:code:`__iter__`: yield sample sequentially. This method is required by reading dataset sample in :code:`paddle.io.DataLoader`.
.. note::
do not implement :code:`__getitem__` and :code:`__len__` in IterableDataset, should not be called either.
see :code:`paddle.io.DataLoader`.
Examples:
.. code-block:: python
import numpy as np
from paddle.io import Dataset
# define a random dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __iter__(self):
for i in range(self.num_samples):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
yield image, label
dataset = RandomDataset(10)
for img, lbl in dataset:
print(img, lbl)
When :attr:`num_workers > 0`, each worker has a different copy of the dataset object and
will yield whole dataset samples, which means samples in dataset will be repeated in
:attr:`num_workers` times. If it is required for each sample to yield only once, there
are two methods to configure different copy in each worker process to avoid duplicate data
among workers as follows. In both the methods, worker information that can be getted in
a worker process by `paddle.io.get_worker_info` will be needed.
Example 1: splitting data copy in each worker in :code:`__iter__`
.. code-block:: python
import math
import numpy as np
import paddle.fluid as fluid
from paddle.io import IterableDataset, DataLoader, get_worker_info
class SplitedIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
dataset = SplitedIterableDataset(start=2, end=9)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
print(list(dataloader))
# outputs: [2, 5, 3, 6, 4, 7]
Example 2: splitting data copy in each worker by :code:`worker_init_fn`
.. code-block:: python
import math
import numpy as np
import paddle.fluid as fluid
from paddle.io import IterableDataset, DataLoader, get_worker_info
class RangeIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
for i in range(self.start, self.end):
yield np.array([i])
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
dataset = RangeIterableDataset(start=2, end=9)
def worker_init_fn(worker_id):
worker_info = get_worker_info()
dataset = worker_info.dataset
start = dataset.start
end = dataset.end
num_per_worker = int(
math.ceil((end - start) / float(worker_info.num_workers)))
worker_id = worker_info.id
dataset.start = start + worker_id * num_per_worker
dataset.end = min(dataset.start + num_per_worker, end)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True,
worker_init_fn=worker_init_fn)
print(list(dataloader))
# outputs: [2, 5, 3, 6, 4, 7]
"""
def __init__(self):
pass
def __iter__(self):
raise NotImplementedError("'{}' not implement in class "\
"{}".format('__iter__', self.__class__.__name__))
def __getitem__(self, idx):
raise RuntimeError("'{}' should not be called for IterableDataset" \
"{}".format('__getitem__', self.__class__.__name__))
def __len__(self):
raise RuntimeError("'{}' should not be called for IterableDataset" \
"{}".format('__len__', self.__class__.__name__))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class _DatasetFetcher(object):
def __init__(self, dataset, collate_fn, drop_last):
self.dataset = dataset
self.collate_fn = collate_fn
self.drop_last = drop_last
def fetch(self, batch_indices):
raise NotImplementedError("'fetch' not implement for class {}".format(
self.__class__.__name__))
class _IterableDatasetFetcher(_DatasetFetcher):
def __init__(self, dataset, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, collate_fn,
drop_last)
self.dataset_iter = iter(dataset)
def fetch(self, batch_indices):
data = []
for _ in batch_indices:
try:
data.append(next(self.dataset_iter))
except StopIteration:
break
if len(data) == 0 or (self.drop_last and
len(data) < len(batch_indices)):
raise StopIteration
return self.collate_fn(data)
class _MapDatasetFetcher(_DatasetFetcher):
def __init__(self, dataset, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, collate_fn, drop_last)
def fetch(self, batch_indices):
data = [self.dataset[idx] for idx in batch_indices]
return self.collate_fn(data)
...@@ -22,8 +22,9 @@ from .framework import Program, Variable, program_guard, default_main_program, d ...@@ -22,8 +22,9 @@ from .framework import Program, Variable, program_guard, default_main_program, d
from .executor import global_scope from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider from .data_feeder import DataFeeder, BatchedTensorProvider
from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler
from .dataloader import BatchSampler, Dataset from .dataloader import BatchSampler, Dataset, IterableDataset
from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess, default_collate_fn from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoaderIterMultiProcess, _DatasetKind, default_collate_fn
from .dataloader.batch_sampler import _InfiniteIterableSampler
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
from .unique_name import UniqueNameGenerator from .unique_name import UniqueNameGenerator
import logging import logging
...@@ -136,8 +137,9 @@ class DataLoader(object): ...@@ -136,8 +137,9 @@ class DataLoader(object):
Args: Args:
dataset(Dataset): the dataset to load data from, should be an dataset(Dataset): the dataset to load data from, should be an
instance of subclass of :code:`paddle.io.Dataset`. instance of subclass of :code:`paddle.io.Dataset` or
feed_list (list(Variable)|tuple(Variable)): feed variable list. :code:`paddle.io.IterableDataset`.
feed_list (list(Tensor)|tuple(Tensor)): feed variable list.
The variables should be created by :code:`fluid.data()`. The variables should be created by :code:`fluid.data()`.
:attr:`feed_list` must be set if :attr:`return_list` is :attr:`feed_list` must be set if :attr:`return_list` is
False. Default None. False. Default None.
...@@ -295,6 +297,10 @@ class DataLoader(object): ...@@ -295,6 +297,10 @@ class DataLoader(object):
# ------------------------------------------------------- # -------------------------------------------------------
.. note::
For reading iterable dataset with multiprocess Dataloader,
please see :code:`paddle.io.IterableDataset`
""" """
def __init__(self, def __init__(self,
...@@ -348,6 +354,18 @@ class DataLoader(object): ...@@ -348,6 +354,18 @@ class DataLoader(object):
assert timeout >= 0, "timeout should be a non-negative value" assert timeout >= 0, "timeout should be a non-negative value"
self.timeout = timeout self.timeout = timeout
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
if shuffle:
raise ValueError(
"IterableDataset not support shuffle, but got shuffle={}".
format(shuffle))
if batch_sampler is not None:
raise ValueError(
"IterableDataset expect unspecified batch_sampler")
else:
self.dataset_kind = _DatasetKind.MAP
if batch_sampler is not None: if batch_sampler is not None:
assert isinstance(batch_sampler, BatchSampler), \ assert isinstance(batch_sampler, BatchSampler), \
"batch_sampler should be None or subclass instance " \ "batch_sampler should be None or subclass instance " \
...@@ -360,11 +378,15 @@ class DataLoader(object): ...@@ -360,11 +378,15 @@ class DataLoader(object):
assert batch_size is not None and batch_size > 0, \ assert batch_size is not None and batch_size > 0, \
"batch_size should be a positive value when " \ "batch_size should be a positive value when " \
"batch_sampler is not given" "batch_sampler is not given"
self.batch_sampler = BatchSampler( if isinstance(dataset, IterableDataset):
dataset=dataset, self.batch_sampler = _InfiniteIterableSampler(dataset,
batch_size=batch_size, batch_size)
shuffle=shuffle, else:
drop_last=drop_last) self.batch_sampler = BatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
self.pin_memory = False self.pin_memory = False
if in_dygraph_mode(): if in_dygraph_mode():
......
...@@ -278,6 +278,7 @@ if (APPLE OR WIN32) ...@@ -278,6 +278,7 @@ if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_static) list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_static)
list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_dynamic) list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_dynamic)
list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_exception) list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_exception)
list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_iterable_dataset)
endif() endif()
if(NOT WITH_GPU OR WIN32 OR APPLE) if(NOT WITH_GPU OR WIN32 OR APPLE)
...@@ -496,4 +497,6 @@ if(NOT WIN32 AND NOT APPLE) ...@@ -496,4 +497,6 @@ if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_multiprocess_dataloader_static PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_multiprocess_dataloader_static PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_multiprocess_dataloader_dynamic PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_multiprocess_dataloader_dynamic PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_multiprocess_dataloader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_multiprocess_dataloader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_multiprocess_dataloader_iterable_dataset_static PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_multiprocess_dataloader_iterable_dataset_dynamic PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
endif() endif()
...@@ -24,7 +24,7 @@ import numpy as np ...@@ -24,7 +24,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.io import Dataset, BatchSampler, DataLoader from paddle.io import Dataset, IterableDataset, BatchSampler, DataLoader
from paddle.fluid.dygraph.nn import Linear from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
...@@ -108,6 +108,48 @@ class TestDataLoaderAssert(unittest.TestCase): ...@@ -108,6 +108,48 @@ class TestDataLoaderAssert(unittest.TestCase):
self.assertTrue(False) self.assertTrue(False)
class TestDatasetRuntimeError(unittest.TestCase):
def test_main(self):
dataset = Dataset()
# __getitem__ not implement
try:
d = dataset[0]
self.assertTrue(False)
except NotImplementedError:
pass
# __len__ not implement
try:
l = len(dataset)
self.assertTrue(False)
except NotImplementedError:
pass
dataset = IterableDataset()
# __iter__ not implement
try:
d = iter(dataset)
self.assertTrue(False)
except NotImplementedError:
pass
# __getitem__ runtime error
try:
d = dataset[0]
self.assertTrue(False)
except RuntimeError:
pass
# __len__ runtime error
try:
l = len(dataset)
self.assertTrue(False)
except RuntimeError:
pass
# CI Converage cannot record stub in subprocess, # CI Converage cannot record stub in subprocess,
# HACK a _worker_loop in main process call here # HACK a _worker_loop in main process call here
@unittest.skipIf(not core.is_compiled_with_cuda(), @unittest.skipIf(not core.is_compiled_with_cuda(),
...@@ -144,12 +186,15 @@ class TestDataLoaderWorkerLoop(unittest.TestCase): ...@@ -144,12 +186,15 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
indices_queue.put([i, i + 10]) indices_queue.put([i, i + 10])
indices_queue.put(None) indices_queue.put(None)
loader._worker_loop( loader._worker_loop(
loader._dataset, indices_queue, loader._data_queue, loader._dataset, 0, indices_queue, loader._data_queue,
loader._workers_done_event, _collate_fn, _init_fn, 0) loader._workers_done_event, _collate_fn, _init_fn, 0, 1)
self.assertTrue(False) self.assertTrue(False)
except AssertionError: except AssertionError:
pass pass
except Exception: except Exception as e:
print("Exception", e)
import sys
sys.stdout.flush()
self.assertTrue(False) self.assertTrue(False)
def run_with_worker_done(self, use_shared_memory=True): def run_with_worker_done(self, use_shared_memory=True):
...@@ -184,8 +229,8 @@ class TestDataLoaderWorkerLoop(unittest.TestCase): ...@@ -184,8 +229,8 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
indices_queue.put(None) indices_queue.put(None)
loader._workers_done_event.set() loader._workers_done_event.set()
loader._worker_loop( loader._worker_loop(
loader._dataset, indices_queue, loader._data_queue, loader._dataset, 0, indices_queue, loader._data_queue,
loader._workers_done_event, _collate_fn, _init_fn, 0) loader._workers_done_event, _collate_fn, _init_fn, 0, 1)
self.assertTrue(True) self.assertTrue(True)
except AssertionError: except AssertionError:
pass pass
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import os
import sys
import six
import time
import unittest
import multiprocessing
import numpy as np
import paddle.fluid as fluid
from paddle.io import Dataset, BatchSampler, DataLoader
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.base import to_variable
from test_multiprocess_dataloader_iterable_dataset_static import RandomDataset, prepare_places
from test_multiprocess_dataloader_iterable_dataset_static import EPOCH_NUM, BATCH_SIZE, IMAGE_SIZE, SAMPLE_NUM, CLASS_NUM
class SimpleFCNet(fluid.dygraph.Layer):
def __init__(self):
super(SimpleFCNet, self).__init__()
param_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.8))
bias_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5))
self._fcs = []
in_channel = IMAGE_SIZE
for hidden_size in [10, 20, 30]:
self._fcs.append(
Linear(
in_channel,
hidden_size,
act='tanh',
param_attr=param_attr,
bias_attr=bias_attr))
in_channel = hidden_size
self._fcs.append(
Linear(
in_channel,
CLASS_NUM,
act='softmax',
param_attr=param_attr,
bias_attr=bias_attr))
def forward(self, image):
out = image
for fc in self._fcs:
out = fc(out)
return out
class TestDygraphDataLoader(unittest.TestCase):
def run_main(self, num_workers, places):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
with fluid.dygraph.guard(places[0]):
fc_net = SimpleFCNet()
optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters())
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
places=places,
num_workers=num_workers,
batch_size=BATCH_SIZE,
drop_last=True)
step_list = []
loss_list = []
start_t = time.time()
for _ in six.moves.range(EPOCH_NUM):
step = 0
for image, label in dataloader():
out = fc_net(image)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.reduce_mean(loss)
avg_loss.backward()
optimizer.minimize(avg_loss)
fc_net.clear_gradients()
loss_list.append(np.mean(avg_loss.numpy()))
step += 1
step_list.append(step)
end_t = time.time()
ret = {
"time": end_t - start_t,
"step": step_list,
"loss": np.array(loss_list)
}
print("time cost", ret['time'], 'step_list', ret['step'])
return ret
def test_main(self):
# dynamic graph do not run with_data_parallel
for p in prepare_places(False):
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers)
sys.stdout.flush()
ret = self.run_main(num_workers=num_workers, places=p)
results.append(ret)
assert results[0]['loss'].shape[0] * 2 == results[1]['loss'].shape[
0]
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import math
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.io import IterableDataset, BatchSampler, DataLoader, get_worker_info
class RangeIterableDatasetSplit(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
iter_start = self.start
iter_end = self.end
else:
per_worker = int(
math.ceil((self.end - self.start) / float(
worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
for i in range(iter_start, iter_end):
yield np.array([i])
class TestDynamicDataLoaderIterSplit(unittest.TestCase):
def test_main(self):
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
dataset = RangeIterableDatasetSplit(0, 10)
dataloader = DataLoader(
dataset,
places=place,
num_workers=2,
batch_size=1,
drop_last=True)
rets = []
for d in dataloader:
rets.append(d[0].numpy()[0][0])
assert tuple(sorted(rets)) == tuple(range(0, 10))
class RangeIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
for i in range(self.start, self.end):
yield np.array([i])
class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase):
def test_main(self):
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
dataset = RangeIterableDataset(0, 10)
def worker_spliter(worker_id):
worker_info = get_worker_info()
dataset = worker_info.dataset
start = dataset.start
end = dataset.end
num_per_worker = int(
math.ceil((end - start) / float(worker_info.num_workers)))
worker_id = worker_info.id
dataset.start = start + worker_id * num_per_worker
dataset.end = min(dataset.start + num_per_worker, end)
dataloader = DataLoader(
dataset,
places=place,
num_workers=1,
batch_size=1,
drop_last=True,
worker_init_fn=worker_spliter)
rets = []
for d in dataloader:
rets.append(d[0].numpy()[0][0])
assert tuple(sorted(rets)) == tuple(range(0, 10))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import os
import sys
import six
import time
import unittest
import multiprocessing
import numpy as np
import paddle.fluid as fluid
from paddle.io import IterableDataset, BatchSampler, DataLoader, get_worker_info
EPOCH_NUM = 2
BATCH_SIZE = 8
IMAGE_SIZE = 32
SAMPLE_NUM = 80
CLASS_NUM = 10
class RandomDataset(IterableDataset):
def __init__(self, sample_num, class_num):
self.sample_num = sample_num
self.class_num = class_num
def __iter__(self):
for i in range(self.sample_num):
np.random.seed(i)
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, self.class_num - 1,
(1, )).astype('int64')
yield image, label
def simple_fc_net_static():
startup_prog = fluid.Program()
main_prog = fluid.Program()
startup_prog.random_seed = 1
main_prog.random_seed = 1
with fluid.unique_name.guard():
with fluid.program_guard(main_prog, startup_prog):
image = fluid.data(
name='image', shape=[None, IMAGE_SIZE], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
hidden = image
param_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.8))
bias_attr = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5))
for hidden_size in [10, 20, 30]:
hidden = fluid.layers.fc(hidden,
size=hidden_size,
act='tanh',
param_attr=param_attr,
bias_attr=bias_attr)
predict_label = fluid.layers.fc(hidden,
size=CLASS_NUM,
act='softmax',
param_attr=param_attr,
bias_attr=bias_attr)
loss = fluid.layers.reduce_mean(
fluid.layers.cross_entropy(
input=predict_label, label=label))
optimizer = fluid.optimizer.Adam()
optimizer.minimize(loss)
return startup_prog, main_prog, image, label, loss
def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
places = []
if with_cpu:
places.append([fluid.CPUPlace()])
if with_data_parallel:
places.append([fluid.CPUPlace()] * 2)
if with_gpu and fluid.core.is_compiled_with_cuda():
tmp = fluid.cuda_places()[:2]
assert len(tmp) > 0, "no gpu detected"
if with_data_parallel:
places.append(tmp)
places.append([tmp[0]])
return places
class TestStaticDataLoader(unittest.TestCase):
def run_main(self, num_workers, places):
scope = fluid.Scope()
with fluid.scope_guard(scope):
startup_prog, main_prog, image, label, loss = simple_fc_net_static()
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
feed_list=[image, label],
places=places,
num_workers=num_workers,
batch_size=BATCH_SIZE,
drop_last=True)
# assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
exe = fluid.Executor(place=places[0])
exe.run(startup_prog)
prog = fluid.CompiledProgram(main_prog)
if len(places) > 1:
prog = prog.with_data_parallel(
loss_name=loss.name, places=places)
step_list = []
loss_list = []
start_t = time.time()
for i in six.moves.range(EPOCH_NUM):
step = 0
for d in dataloader:
assert len(d) == len(places), "{} != {}".format(
len(d), len(places))
for i, item in enumerate(d):
image = item['image']
label = item['label']
assert image.shape() == [BATCH_SIZE, IMAGE_SIZE]
assert label.shape() == [BATCH_SIZE, 1]
assert image._place()._equals(places[i])
assert label._place()._equals(places[i])
L, = exe.run(program=prog,
feed=d,
fetch_list=[loss],
use_program_cache=True)
loss_list.append(np.mean(L))
step += 1
step_list.append(step)
end_t = time.time()
ret = {
"time": end_t - start_t,
"step": step_list,
"loss": np.array(loss_list)
}
print("time cost", ret['time'], 'step_list', ret['step'])
return ret
def test_main(self):
for p in prepare_places(True):
results = []
for num_workers in [0, 2]:
print(self.__class__.__name__, p, num_workers)
sys.stdout.flush()
ret = self.run_main(num_workers=num_workers, places=p)
results.append(ret)
assert results[0]['loss'].shape[0] * 2 == results[1]['loss'].shape[
0]
if __name__ == '__main__':
unittest.main()
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
# TODO: define all functions about input & output in this directory # TODO: define all functions about input & output in this directory
__all__ = [ __all__ = [
'Dataset', 'Dataset',
'IterableDataset',
'BatchSampler', 'BatchSampler',
# 'Transform', # 'Transform',
'DataLoader', 'DataLoader',
'get_worker_info',
'load', 'load',
'save', 'save',
'load_program_state', 'load_program_state',
...@@ -36,7 +38,7 @@ __all__ = [ ...@@ -36,7 +38,7 @@ __all__ = [
] ]
from ..fluid.io import DataLoader from ..fluid.io import DataLoader
from ..fluid.dataloader import Dataset, BatchSampler from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info
from ..fluid.io import load, save, load_program_state, set_program_state, \ from ..fluid.io import load, save, load_program_state, set_program_state, \
load_inference_model, save_inference_model, batch load_inference_model, save_inference_model, batch
from ..reader import shuffle, buffered, cache, chain, firstn, compose, map_readers, xmap_readers from ..reader import shuffle, buffered, cache, chain, firstn, compose, map_readers, xmap_readers
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册