Unverified commit 89d27de9, authored by Kaipeng Deng, committed by GitHub

DataLoader support not auto collate batch (#28425)

* DataLoader support not auto collate batch. test=develop
Parent c5c273c1
@@ -36,6 +36,7 @@ from .. import core, layers
from ..framework import in_dygraph_mode
from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCHLD_handler
from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
from .batch_sampler import _InfiniteIterableSampler
__all__ = ['get_worker_info']
@@ -100,11 +101,13 @@ class _DatasetKind(object):
ITER = 1
@staticmethod
def create_fetcher(kind, dataset, collate_fn, drop_last):
def create_fetcher(kind, dataset, auto_collate_batch, collate_fn, drop_last):
if kind == _DatasetKind.MAP:
return _MapDatasetFetcher(dataset, collate_fn, drop_last)
return _MapDatasetFetcher(dataset, auto_collate_batch,
collate_fn, drop_last)
elif kind == _DatasetKind.ITER:
return _IterableDatasetFetcher(dataset, collate_fn, drop_last)
return _IterableDatasetFetcher(dataset, auto_collate_batch,
collate_fn, drop_last)
else:
raise NotImplementedError("unknown Dataset kind {}".format(kind))
@@ -221,8 +224,7 @@ class _DataLoaderIterBase(object):
self._places = loader.places
self._return_list = loader.return_list
self._batch_sampler = loader.batch_sampler
self._sampler_iter = iter(loader.batch_sampler)
self._collate_fn = loader.collate_fn or default_collate_fn
self._auto_collate_batch = loader.auto_collate_batch
self._num_workers = loader.num_workers
self._use_buffer_reader = loader.use_buffer_reader
self._use_shared_memory = loader.use_shared_memory
@@ -231,6 +233,16 @@
self._dataset_kind = loader.dataset_kind
self._pin_memory = loader.pin_memory
if self._auto_collate_batch:
self._sampler_iter = iter(loader.batch_sampler)
self._collate_fn = loader.collate_fn or default_collate_fn
else:
if self._dataset_kind == _DatasetKind.MAP:
self._sampler_iter = iter(list(range(len(self._dataset))))
else:
self._sampler_iter = iter(_InfiniteIterableSampler(self._dataset, 1))
self._collate_fn = loader.collate_fn
# LoDTensorBlockingQueue instance for create_py_reader and a thread
# to put mini-batch data to self._blocking_queue, mini-batch data
# will be get from:
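
A note on the sampler selection added above: when automatic batching is disabled, the iterator still needs an index stream to drive each fetch. Below is a minimal sketch of the assumed behavior (stand-in code, not the Paddle internals): a map-style dataset is walked one index at a time, while an iterable dataset gets an endless stream of dummy one-element index lists, mirroring _InfiniteIterableSampler(dataset, 1); fetching then stops only when the dataset iterator itself is exhausted.

    def sampler_iter_sketch(dataset, is_map_style):
        # map-style: one sample per step, indices 0, 1, ..., len(dataset) - 1
        if is_map_style:
            return iter(range(len(dataset)))

        # iterable-style: indices carry no meaning; yield a dummy one-element
        # list forever (the placeholder value is irrelevant) -- the fetcher
        # raises StopIteration once the underlying dataset iterator runs out
        def _infinite_dummy_indices():
            while True:
                yield [None]

        return _infinite_dummy_indices()
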
@@ -257,7 +269,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
super(_DataLoaderIterSingleProcess, self).__init__(loader)
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._collate_fn, True)
self._dataset_kind, self._dataset, self._auto_collate_batch,
self._collate_fn, True)
# NOTE: len(self._places) batch data compose as an output
# iteration, set blocking_queue can cache 2 iteration datas
@@ -367,7 +380,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
# NOTE(chenweihang): _worker_loop must be top level method to be pickled
def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
collate_fn, init_fn, worker_id, num_workers,
auto_collate_batch, collate_fn, init_fn, worker_id, num_workers,
use_shared_memory):
try:
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
@@ -388,7 +401,7 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
if init_fn is not None:
init_fn(worker_id)
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
collate_fn, True)
auto_collate_batch, collate_fn, True)
except:
init_exception = Exception("init_fn failed in worker {}: " \
"{}".format(worker_id, sys.exc_info()))
@@ -511,8 +524,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
target=_worker_loop,
args=(self._dataset, self._dataset_kind, indices_queue,
self._data_queue, self._workers_done_event,
self._collate_fn, self._worker_init_fn, i,
self._num_workers, self._use_shared_memory))
self._auto_collate_batch, self._collate_fn,
self._worker_init_fn, i, self._num_workers,
self._use_shared_memory))
worker.daemon = True
worker.start()
self._workers.append(worker)
@@ -14,8 +14,9 @@
class _DatasetFetcher(object):
def __init__(self, dataset, collate_fn, drop_last):
def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
self.dataset = dataset
self.auto_collate_batch = auto_collate_batch
self.collate_fn = collate_fn
self.drop_last = drop_last
@@ -25,12 +26,14 @@ class _DatasetFetcher(object):
class _IterableDatasetFetcher(_DatasetFetcher):
def __init__(self, dataset, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, collate_fn,
drop_last)
def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collate_batch,
collate_fn, drop_last)
self.dataset_iter = iter(dataset)
def fetch(self, batch_indices):
if self.auto_collate_batch:
data = []
for _ in batch_indices:
try:
@@ -40,14 +43,24 @@ class _IterableDatasetFetcher(_DatasetFetcher):
if len(data) == 0 or (self.drop_last and
len(data) < len(batch_indices)):
raise StopIteration
else:
data = next(self.dataset_iter)
return self.collate_fn(data)
if self.collate_fn:
data = self.collate_fn(data)
return data
class _MapDatasetFetcher(_DatasetFetcher):
def __init__(self, dataset, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, collate_fn, drop_last)
def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, collate_fn, drop_last)
def fetch(self, batch_indices):
if self.auto_collate_batch:
data = [self.dataset[idx] for idx in batch_indices]
return self.collate_fn(data)
else:
data = self.dataset[batch_indices]
if self.collate_fn:
data = self.collate_fn(data)
return data
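
To make the two fetch paths above concrete, here is a simplified, self-contained sketch of the map-style fetcher semantics (illustrative only, not the real _MapDatasetFetcher): with auto_collate_batch enabled, batch_indices is a list of sample indices to gather and collate; with it disabled, batch_indices is a single index and the dataset item is expected to already be a whole batch.

    class MapFetcherSketch(object):
        def __init__(self, dataset, auto_collate_batch, collate_fn):
            self.dataset = dataset
            self.auto_collate_batch = auto_collate_batch
            self.collate_fn = collate_fn

        def fetch(self, batch_indices):
            if self.auto_collate_batch:
                # gather one sample per index, then collate into a batch;
                # upstream, collate_fn falls back to default_collate_fn
                data = [self.dataset[idx] for idx in batch_indices]
                return self.collate_fn(data)
            # batch_indices is a single index here; the item is already
            # a batch, so collate_fn (if any) is applied to it directly
            data = self.dataset[batch_indices]
            if self.collate_fn:
                data = self.collate_fn(data)
            return data
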
@@ -163,6 +163,21 @@ class DataLoader(object):
For :code:`batch_sampler` please see :code:`paddle.io.BatchSampler`
**Disable automatic batching**
In certain cases, such as some NLP tasks, users need to handle batching
manually in the dataset instead of relying on automatic batching. For
these cases, automatic batching is disabled when both :attr:`batch_size`
and :attr:`batch_sampler` are set to None; each sample fetched from
:attr:`dataset` should then be already-batched data, and will be
processed with the function defined by :attr:`collate_fn` or
:attr:`default_collate_fn`.
.. note::
When automatic batching is disabled, :attr:`default_collate_fn` does
nothing to the data from the dataset.
Args:
dataset(Dataset): the dataset to load data from, should be an
instance of subclass of :code:`paddle.io.Dataset` or
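
Taken together with the "Disable automatic batching" note above, usage looks roughly like this (a minimal sketch against the paddle.io API this patch targets; PreBatchedDataset and its constants are hypothetical):

    import numpy as np
    from paddle.io import Dataset, DataLoader

    BATCH_SIZE, IMAGE_SIZE, CLASS_NUM, SAMPLE_NUM = 16, 784, 10, 160

    class PreBatchedDataset(Dataset):
        # each __getitem__ returns a full batch, as required when
        # automatic batching is disabled
        def __getitem__(self, idx):
            images = np.random.random(
                [BATCH_SIZE, IMAGE_SIZE]).astype('float32')
            labels = np.random.randint(
                0, CLASS_NUM - 1, (BATCH_SIZE, 1)).astype('int64')
            return images, labels

        def __len__(self):
            return SAMPLE_NUM // BATCH_SIZE

    # batch_size=None with no batch_sampler disables automatic batching
    loader = DataLoader(PreBatchedDataset(), batch_size=None)
    for images, labels in loader:
        pass  # each iteration yields one pre-batched sample
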
@@ -185,7 +200,7 @@
batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
to generate batch indices to draw samples from :attr:`dataset`
and combine a batch. Default None.
batch_size(int): sample number in a mini-batch, a substitution
batch_size(int|None): sample number in a mini-batch, a substitution
parameter for :attr:`batch_sampler`, if :attr:`batch_sampler`
is not set, a default `paddle.io.BatchSampler` will be used
and initialize by :attr:`batch_size`, :attr:`shuffle` and
@@ -358,10 +373,15 @@ class DataLoader(object):
"batch_size/shuffle/drop_last should not be set when " \
"batch_sampler is given"
self.batch_sampler = batch_sampler
self.batch_size = None
elif batch_size is None:
self.batch_sampler = None
self.batch_size = None
else:
assert batch_size is not None and batch_size > 0, \
"batch_size should be a positive value when " \
assert batch_size > 0, \
"batch_size should be None or a positive value when " \
"batch_sampler is not given"
self.batch_size = batch_size
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(dataset,
batch_size)
@@ -372,12 +392,20 @@
shuffle=shuffle,
drop_last=drop_last)
self.auto_collate_batch = self.batch_sampler is not None
self.pin_memory = False
if in_dygraph_mode():
self.pin_memory = True if use_pinned_memory(
) is None else use_pinned_memory()
def __len__(self):
if self.dataset_kind == _DatasetKind.ITER:
raise ValueError("length of IterableDataset not supported")
else:
if self.batch_size is None:
return len(self.dataset)
else:
return len(self.batch_sampler)
def __iter__(self):
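
In other words, with batching disabled the loader length tracks the dataset directly; with batching enabled it comes from the batch sampler. Using the hypothetical PreBatchedDataset sketched earlier:

    loader = DataLoader(PreBatchedDataset(), batch_size=None)
    # one batch per dataset item when automatic batching is disabled
    assert len(loader) == SAMPLE_NUM // BATCH_SIZE
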
@@ -27,7 +27,7 @@ from paddle.io import Dataset, BatchSampler, DataLoader
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.base import to_variable
from test_multiprocess_dataloader_static import RandomDataset, prepare_places
from test_multiprocess_dataloader_static import RandomDataset, RandomBatchedDataset, prepare_places
from test_multiprocess_dataloader_static import EPOCH_NUM, BATCH_SIZE, IMAGE_SIZE, SAMPLE_NUM, CLASS_NUM
@@ -122,5 +122,48 @@ class TestDygraphDataLoader(unittest.TestCase):
self.assertLess(diff, 1e-2)
class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
def run_main(self, num_workers, places):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
with fluid.dygraph.guard(places[0]):
fc_net = SimpleFCNet()
optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters())
dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
num_workers=num_workers,
batch_size=None,
drop_last=True)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
step_list = []
loss_list = []
start_t = time.time()
for _ in six.moves.range(EPOCH_NUM):
step = 0
for image, label in dataloader():
out = fc_net(image)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.reduce_mean(loss)
avg_loss.backward()
optimizer.minimize(avg_loss)
fc_net.clear_gradients()
loss_list.append(np.mean(avg_loss.numpy()))
step += 1
step_list.append(step)
end_t = time.time()
ret = {
"time": end_t - start_t,
"step": step_list,
"loss": np.array(loss_list)
}
print("time cost", ret['time'], 'step_list', ret['step'])
return ret
if __name__ == '__main__':
unittest.main()
@@ -188,7 +188,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
indices_queue.put(None)
_worker_loop(loader._dataset, 0, indices_queue,
loader._data_queue, loader._workers_done_event,
_collate_fn, _init_fn, 0, 1,
True, _collate_fn, _init_fn, 0, 1,
loader._use_shared_memory)
self.assertTrue(False)
except AssertionError:
@@ -232,7 +232,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
loader._workers_done_event.set()
_worker_loop(loader._dataset, 0, indices_queue,
loader._data_queue, loader._workers_done_event,
_collate_fn, _init_fn, 0, 1,
True, _collate_fn, _init_fn, 0, 1,
loader._use_shared_memory)
self.assertTrue(True)
except AssertionError:
@@ -27,7 +27,7 @@ from paddle.io import Dataset, BatchSampler, DataLoader
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.base import to_variable
from test_multiprocess_dataloader_iterable_dataset_static import RandomDataset, prepare_places
from test_multiprocess_dataloader_iterable_dataset_static import RandomDataset, RandomBatchedDataset, prepare_places
from test_multiprocess_dataloader_iterable_dataset_static import EPOCH_NUM, BATCH_SIZE, IMAGE_SIZE, SAMPLE_NUM, CLASS_NUM
@@ -119,5 +119,46 @@
0]
class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
def run_main(self, num_workers, places):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
with fluid.dygraph.guard(places[0]):
fc_net = SimpleFCNet()
optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters())
dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
num_workers=num_workers,
batch_size=None,
drop_last=True)
step_list = []
loss_list = []
start_t = time.time()
for _ in six.moves.range(EPOCH_NUM):
step = 0
for image, label in dataloader():
out = fc_net(image)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.reduce_mean(loss)
avg_loss.backward()
optimizer.minimize(avg_loss)
fc_net.clear_gradients()
loss_list.append(np.mean(avg_loss.numpy()))
step += 1
step_list.append(step)
end_t = time.time()
ret = {
"time": end_t - start_t,
"step": step_list,
"loss": np.array(loss_list)
}
print("time cost", ret['time'], 'step_list', ret['step'])
return ret
if __name__ == '__main__':
unittest.main()
@@ -167,5 +167,80 @@ class TestStaticDataLoader(unittest.TestCase):
0]
class RandomBatchedDataset(IterableDataset):
def __init__(self, sample_num, class_num):
self.sample_num = sample_num // BATCH_SIZE
self.class_num = class_num
def __iter__(self):
for i in range(self.sample_num):
np.random.seed(i)
images = []
labels = []
for _ in range(BATCH_SIZE):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, self.class_num - 1,
(1, )).astype('int64')
images.append(image)
labels.append(label)
yield np.stack(images, axis=0), np.stack(labels, axis=0)
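
As a quick illustration, each step of RandomBatchedDataset above yields one whole batch (the shapes follow from the np.stack calls):

    dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
    images, labels = next(iter(dataset))
    assert images.shape == (BATCH_SIZE, IMAGE_SIZE)
    assert labels.shape == (BATCH_SIZE, 1)
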
class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
def run_main(self, num_workers, places):
scope = fluid.Scope()
with fluid.scope_guard(scope):
startup_prog, main_prog, image, label, loss = simple_fc_net_static()
dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
feed_list=[image, label],
places=places,
num_workers=num_workers,
batch_size=None,
drop_last=True)
exe = fluid.Executor(place=places[0])
exe.run(startup_prog)
prog = fluid.CompiledProgram(main_prog)
if len(places) > 1:
prog = prog.with_data_parallel(
loss_name=loss.name, places=places)
step_list = []
loss_list = []
start_t = time.time()
for i in six.moves.range(EPOCH_NUM):
step = 0
for d in dataloader:
assert len(d) == len(places), "{} != {}".format(
len(d), len(places))
for i, item in enumerate(d):
image = item['image']
label = item['label']
assert image.shape() == [BATCH_SIZE, IMAGE_SIZE]
assert label.shape() == [BATCH_SIZE, 1]
assert image._place()._equals(places[i])
assert label._place()._equals(places[i])
L, = exe.run(program=prog,
feed=d,
fetch_list=[loss],
use_program_cache=True)
loss_list.append(np.mean(L))
step += 1
step_list.append(step)
end_t = time.time()
ret = {
"time": end_t - start_t,
"step": step_list,
"loss": np.array(loss_list)
}
print("time cost", ret['time'], 'step_list', ret['step'])
return ret
if __name__ == '__main__':
unittest.main()
@@ -215,5 +215,82 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
assert isinstance(d[1], list)
class RandomBatchedDataset(Dataset):
def __init__(self, sample_num, class_num):
self.sample_num = int(sample_num / BATCH_SIZE)
self.class_num = class_num
def __getitem__(self, idx):
np.random.seed(idx)
images = []
labels = []
for _ in range(BATCH_SIZE):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64')
images.append(image)
labels.append(label)
return np.stack(images, axis=0), np.stack(labels, axis=0)
def __len__(self):
return self.sample_num
class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
def run_main(self, num_workers, places):
scope = fluid.Scope()
with fluid.scope_guard(scope):
startup_prog, main_prog, image, label, loss = simple_fc_net_static()
dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
feed_list=[image, label],
places=places,
num_workers=num_workers,
batch_size=None,
drop_last=True)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
exe = fluid.Executor(place=places[0])
exe.run(startup_prog)
prog = fluid.CompiledProgram(main_prog)
if len(places) > 1:
prog = prog.with_data_parallel(
loss_name=loss.name, places=places)
step_list = []
loss_list = []
start_t = time.time()
for _ in six.moves.range(EPOCH_NUM):
step = 0
for d in dataloader:
assert len(d) == len(places), "{} != {}".format(
len(d), len(places))
for i, item in enumerate(d):
image = item['image']
label = item['label']
assert image.shape() == [BATCH_SIZE, IMAGE_SIZE]
assert label.shape() == [BATCH_SIZE, 1]
assert image._place()._equals(places[i])
assert label._place()._equals(places[i])
L, = exe.run(program=prog,
feed=d,
fetch_list=[loss],
use_program_cache=True)
loss_list.append(np.mean(L))
step += 1
step_list.append(step)
end_t = time.time()
ret = {
"time": end_t - start_t,
"step": step_list,
"loss": np.array(loss_list)
}
print("time cost", ret['time'], 'step_list', ret['step'])
return ret
if __name__ == '__main__':
unittest.main()