diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py
index d32a543eb495fa7c1bdcf9124eb38dcd5f0d556c..ee30484ae9a0fb1d56973280f22069bb40906806 100644
--- a/python/paddle/fluid/dataloader/dataloader_iter.py
+++ b/python/paddle/fluid/dataloader/dataloader_iter.py
@@ -36,6 +36,7 @@ from .. import core, layers
 from ..framework import in_dygraph_mode
 from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCHLD_handler
 from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
+from .batch_sampler import _InfiniteIterableSampler
 
 __all__ = ['get_worker_info']
 
@@ -100,11 +101,13 @@ class _DatasetKind(object):
     ITER = 1
 
     @staticmethod
-    def create_fetcher(kind, dataset, collate_fn, drop_last):
+    def create_fetcher(kind, dataset, auto_collate_batch, collate_fn, drop_last):
         if kind == _DatasetKind.MAP:
-            return _MapDatasetFetcher(dataset, collate_fn, drop_last)
+            return _MapDatasetFetcher(dataset, auto_collate_batch,
+                                      collate_fn, drop_last)
         elif kind == _DatasetKind.ITER:
-            return _IterableDatasetFetcher(dataset, collate_fn, drop_last)
+            return _IterableDatasetFetcher(dataset, auto_collate_batch,
+                                           collate_fn, drop_last)
         else:
             raise NotImplementedError("unknown Dataset kind {}".format(kind))
 
@@ -221,8 +224,7 @@ class _DataLoaderIterBase(object):
         self._places = loader.places
         self._return_list = loader.return_list
         self._batch_sampler = loader.batch_sampler
-        self._sampler_iter = iter(loader.batch_sampler)
-        self._collate_fn = loader.collate_fn or default_collate_fn
+        self._auto_collate_batch = loader.auto_collate_batch
         self._num_workers = loader.num_workers
         self._use_buffer_reader = loader.use_buffer_reader
         self._use_shared_memory = loader.use_shared_memory
@@ -231,6 +233,16 @@ class _DataLoaderIterBase(object):
         self._dataset_kind = loader.dataset_kind
         self._pin_memory = loader.pin_memory
 
+        if self._auto_collate_batch:
+            self._sampler_iter = iter(loader.batch_sampler)
+            self._collate_fn = loader.collate_fn or default_collate_fn
+        else:
+            if self._dataset_kind == _DatasetKind.MAP:
+                self._sampler_iter = iter(list(range(len(self._dataset))))
+            else:
+                self._sampler_iter = iter(_InfiniteIterableSampler(self._dataset, 1))
+            self._collate_fn = loader.collate_fn
+
         # LoDTensorBlockingQueue instance for create_py_reader and a thread
         # to put mini-batch data to self._blocking_queue, mini-batch data
         # will be get from:
@@ -257,7 +269,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
         super(_DataLoaderIterSingleProcess, self).__init__(loader)
 
         self._dataset_fetcher = _DatasetKind.create_fetcher(
-            self._dataset_kind, self._dataset, self._collate_fn, True)
+            self._dataset_kind, self._dataset, self._auto_collate_batch,
+            self._collate_fn, True)
 
         # NOTE: len(self._places) batch data compose as an output
         # iteration, set blocking_queue can cache 2 iteration datas
@@ -367,7 +380,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
 
 # NOTE(chenweihang): _worker_loop must be top level method to be pickled
 def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
-                 collate_fn, init_fn, worker_id, num_workers,
+                 auto_collate_batch, collate_fn, init_fn, worker_id, num_workers,
                  use_shared_memory):
     try:
         # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
@@ -388,7 +401,7 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
         if init_fn is not None:
             init_fn(worker_id)
         fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
-                                              collate_fn, True)
+                                              auto_collate_batch, collate_fn, True)
     except:
         init_exception = Exception("init_fn failed in worker {}: " \
                                    "{}".format(worker_id, sys.exc_info()))
@@ -511,8 +524,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                 target=_worker_loop,
                 args=(self._dataset, self._dataset_kind, indices_queue,
                       self._data_queue, self._workers_done_event,
-                      self._collate_fn, self._worker_init_fn, i,
-                      self._num_workers, self._use_shared_memory))
+                      self._auto_collate_batch, self._collate_fn,
+                      self._worker_init_fn, i, self._num_workers,
+                      self._use_shared_memory))
             worker.daemon = True
             worker.start()
             self._workers.append(worker)
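The sampler selection in `_DataLoaderIterBase.__init__` above is compact, so here is a minimal standalone sketch of the same dispatch outside Paddle. `FakeInfiniteSampler` and `make_sampler_iter` are hypothetical names introduced only for illustration:

```python
# Hypothetical, self-contained sketch of the sampler selection performed in
# _DataLoaderIterBase.__init__ above. When auto_collate_batch is False, a
# map-style dataset is walked one index at a time (each sample is already a
# batch), and an iterable dataset gets an endless stream of dummy 1-element
# index lists, mirroring _InfiniteIterableSampler(self._dataset, 1).

class FakeInfiniteSampler(object):
    """Stand-in for Paddle's _InfiniteIterableSampler."""

    def __init__(self, dataset, batch_size):
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            yield [None] * self.batch_size


def make_sampler_iter(dataset, auto_collate_batch, batch_sampler, is_map_style):
    if auto_collate_batch:
        # normal path: the batch sampler yields one index list per mini-batch
        return iter(batch_sampler)
    if is_map_style:
        # manual batching: one index per pre-batched sample
        return iter(list(range(len(dataset))))
    # iterable datasets have no indices; the dummy lists only pace fetching
    return iter(FakeInfiniteSampler(dataset, 1))
```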
diff --git a/python/paddle/fluid/dataloader/fetcher.py b/python/paddle/fluid/dataloader/fetcher.py
index 001b8b931da233557c5d0b11af283a6e631788ae..9382a7042237043dba8b41518d23a14ce4a43049 100644
--- a/python/paddle/fluid/dataloader/fetcher.py
+++ b/python/paddle/fluid/dataloader/fetcher.py
@@ -14,8 +14,9 @@
 
 
 class _DatasetFetcher(object):
-    def __init__(self, dataset, collate_fn, drop_last):
+    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
         self.dataset = dataset
+        self.auto_collate_batch = auto_collate_batch
         self.collate_fn = collate_fn
         self.drop_last = drop_last
 
@@ -25,29 +26,41 @@ class _DatasetFetcher(object):
 
 
 class _IterableDatasetFetcher(_DatasetFetcher):
-    def __init__(self, dataset, collate_fn, drop_last):
-        super(_IterableDatasetFetcher, self).__init__(dataset, collate_fn,
-                                                      drop_last)
+    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
+        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collate_batch,
+                                                      collate_fn, drop_last)
         self.dataset_iter = iter(dataset)
 
     def fetch(self, batch_indices):
-        data = []
-        for _ in batch_indices:
-            try:
-                data.append(next(self.dataset_iter))
-            except StopIteration:
-                break
-        if len(data) == 0 or (self.drop_last and
-                              len(data) < len(batch_indices)):
-            raise StopIteration
-        return self.collate_fn(data)
+        if self.auto_collate_batch:
+            data = []
+            for _ in batch_indices:
+                try:
+                    data.append(next(self.dataset_iter))
+                except StopIteration:
+                    break
+            if len(data) == 0 or (self.drop_last and
+                                  len(data) < len(batch_indices)):
+                raise StopIteration
+        else:
+            data = next(self.dataset_iter)
+
+        if self.collate_fn:
+            data = self.collate_fn(data)
+        return data
 
 
 class _MapDatasetFetcher(_DatasetFetcher):
-    def __init__(self, dataset, collate_fn, drop_last):
-        super(_MapDatasetFetcher, self).__init__(dataset, collate_fn, drop_last)
+    def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
+        super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, collate_fn, drop_last)
 
     def fetch(self, batch_indices):
-        data = [self.dataset[idx] for idx in batch_indices]
-        return self.collate_fn(data)
+        if self.auto_collate_batch:
+            data = [self.dataset[idx] for idx in batch_indices]
+        else:
+            data = self.dataset[batch_indices]
+
+        if self.collate_fn:
+            data = self.collate_fn(data)
+        return data
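To make the two fetch paths concrete, here is a small pure-Python demo (no Paddle required). The `fetch` helper is an illustrative re-implementation of `_MapDatasetFetcher.fetch` from the hunk above, not the class itself:

```python
# Illustrative re-implementation of _MapDatasetFetcher.fetch. With
# auto_collate_batch=True, batch_indices is a list of sample indices and
# collate_fn receives the gathered list; with False, batch_indices is a
# single index whose element is already a full batch.

samples = [([0.1, 0.2], 0), ([0.3, 0.4], 1)]


def fetch(dataset, batch_indices, auto_collate_batch, collate_fn=None):
    if auto_collate_batch:
        data = [dataset[idx] for idx in batch_indices]
    else:
        data = dataset[batch_indices]
    if collate_fn:
        data = collate_fn(data)
    return data


# auto-collate: two samples gathered, then collated into one batch
batch = fetch(samples, [0, 1], True, collate_fn=lambda d: list(zip(*d)))
assert batch == [([0.1, 0.2], [0.3, 0.4]), (0, 1)]

# manual batching: element 0 is passed through untouched (collate_fn is None)
prebatched = fetch(samples, 0, False)
assert prebatched == ([0.1, 0.2], 0)
```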
diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py
index 0e7fd35f5842e6432b1248bfd8afa351ffe33406..4a50b3bc0c7dc581febdaf12ab70c663c0263377 100644
--- a/python/paddle/fluid/reader.py
+++ b/python/paddle/fluid/reader.py
@@ -163,6 +163,21 @@ class DataLoader(object):
     For :code:`batch_sampler` please see :code:`paddle.io.BatchSampler`
 
+    **Disable automatic batching**
+
+    In certain cases, such as some NLP tasks, users need to handle batching
+    manually in the dataset instead of using automatic batching. For these
+    cases, automatic batching is disabled if both :attr:`batch_size` and
+    :attr:`batch_sampler` are set to None; each sample got from
+    :attr:`dataset` should already be a batch, and will be processed by the
+    function defined by :attr:`collate_fn` or :attr:`default_collate_fn`.
+
+
+    .. note::
+        When automatic batching is disabled, :attr:`default_collate_fn` will
+        do nothing to the data from the dataset.
+
+
     Args:
         dataset(Dataset): the dataset to load data from, should be an
             instance of subclass of :code:`paddle.io.Dataset` or
@@ -185,7 +200,7 @@ class DataLoader(object):
         batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
             to generate batch indices to draw samples from :attr:`dataset`
             and combine a batch. Default None.
-        batch_size(int): sample number in a mini-batch, a substitution
+        batch_size(int|None): sample number in a mini-batch, a substitution
             parameter for :attr:`batch_sampler`, if :attr:`batch_sampler`
             is not set, a default `paddle.io.BatchSampler` will be used
             and initialize by :attr:`batch_size`, :attr:`shuffle` and
@@ -358,27 +373,40 @@ class DataLoader(object):
                 "batch_size/shuffle/drop_last should not be set when " \
                 "batch_sampler is given"
             self.batch_sampler = batch_sampler
+            self.batch_size = None
+        elif batch_size is None:
+            self.batch_sampler = None
+            self.batch_size = None
         else:
-            assert batch_size is not None and batch_size > 0, \
-                "batch_size should be a positive value when " \
+            assert batch_size > 0, \
+                "batch_size should be None or a positive value when " \
                 "batch_sampler is not given"
+            self.batch_size = batch_size
             if isinstance(dataset, IterableDataset):
                 self.batch_sampler = _InfiniteIterableSampler(dataset,
                                                               batch_size)
             else:
                 self.batch_sampler = BatchSampler(
                     dataset=dataset,
                     batch_size=batch_size,
                     shuffle=shuffle,
                     drop_last=drop_last)
 
+        self.auto_collate_batch = self.batch_sampler is not None
+
         self.pin_memory = False
         if in_dygraph_mode():
             self.pin_memory = True if use_pinned_memory(
             ) is None else use_pinned_memory()
 
     def __len__(self):
-        return len(self.batch_sampler)
+        if self.dataset_kind == _DatasetKind.ITER:
+            raise ValueError("length of IterableDataset not supported")
+        else:
+            if self.batch_size is None:
+                return len(self.dataset)
+            else:
+                return len(self.batch_sampler)
 
     def __iter__(self):
         if self.num_workers == 0:
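A hedged usage sketch of the documented behavior, mirroring the tests below: with `batch_size=None` and no `batch_sampler`, each `__getitem__` must return a whole mini-batch, and `len(dataloader)` falls back to `len(dataset)`. The dataset name and shapes are illustrative, not part of the patch:

```python
# Usage sketch (assumes dygraph mode, as in the tests below).
import numpy as np
import paddle.fluid as fluid
from paddle.io import Dataset, DataLoader

BATCH_SIZE, NUM_BATCHES = 4, 10


class PreBatchedDataset(Dataset):
    def __getitem__(self, idx):
        # each index yields one whole mini-batch, not a single sample
        images = np.random.random([BATCH_SIZE, 784]).astype('float32')
        labels = np.random.randint(0, 9, (BATCH_SIZE, 1)).astype('int64')
        return images, labels

    def __len__(self):
        return NUM_BATCHES


with fluid.dygraph.guard():
    loader = DataLoader(PreBatchedDataset(), batch_size=None, drop_last=True)
    assert len(loader) == NUM_BATCHES  # number of pre-built batches
    for images, labels in loader():
        pass  # images/labels arrive pre-batched, untouched by collation
```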
diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py
index 1bb720673e4f33ac7a866cdf73a885741ee08e7e..c89354adf751c67b5bd5214e980789bf0aa62abf 100644
--- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py
+++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py
@@ -27,7 +27,7 @@ from paddle.io import Dataset, BatchSampler, DataLoader
 from paddle.fluid.dygraph.nn import Linear
 from paddle.fluid.dygraph.base import to_variable
 
-from test_multiprocess_dataloader_static import RandomDataset, prepare_places
+from test_multiprocess_dataloader_static import RandomDataset, RandomBatchedDataset, prepare_places
 from test_multiprocess_dataloader_static import EPOCH_NUM, BATCH_SIZE, IMAGE_SIZE, SAMPLE_NUM, CLASS_NUM
 
 
@@ -122,5 +122,48 @@ class TestDygraphDataLoader(unittest.TestCase):
             self.assertLess(diff, 1e-2)
 
 
+class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
+    def run_main(self, num_workers, places):
+        fluid.default_startup_program().random_seed = 1
+        fluid.default_main_program().random_seed = 1
+        with fluid.dygraph.guard(places[0]):
+            fc_net = SimpleFCNet()
+            optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters())
+
+            dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
+            dataloader = DataLoader(
+                dataset,
+                num_workers=num_workers,
+                batch_size=None,
+                drop_last=True)
+            assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
+
+            step_list = []
+            loss_list = []
+            start_t = time.time()
+            for _ in six.moves.range(EPOCH_NUM):
+                step = 0
+                for image, label in dataloader():
+                    out = fc_net(image)
+                    loss = fluid.layers.cross_entropy(out, label)
+                    avg_loss = fluid.layers.reduce_mean(loss)
+                    avg_loss.backward()
+                    optimizer.minimize(avg_loss)
+                    fc_net.clear_gradients()
+
+                    loss_list.append(np.mean(avg_loss.numpy()))
+                    step += 1
+                step_list.append(step)
+
+            end_t = time.time()
+            ret = {
+                "time": end_t - start_t,
+                "step": step_list,
+                "loss": np.array(loss_list)
+            }
+            print("time cost", ret['time'], 'step_list', ret['step'])
+            return ret
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
index 6fd14b40bc9108b6075a0ac1f40cbefd79b8f0d9..74fe359cd7d597def5dd62e973abf001af05b81c 100644
--- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
+++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_exception.py
@@ -188,7 +188,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
             indices_queue.put(None)
             _worker_loop(loader._dataset, 0, indices_queue,
                          loader._data_queue, loader._workers_done_event,
-                         _collate_fn, _init_fn, 0, 1,
+                         True, _collate_fn, _init_fn, 0, 1,
                          loader._use_shared_memory)
             self.assertTrue(False)
         except AssertionError:
@@ -232,7 +232,7 @@ class TestDataLoaderWorkerLoop(unittest.TestCase):
             loader._workers_done_event.set()
             _worker_loop(loader._dataset, 0, indices_queue,
                          loader._data_queue, loader._workers_done_event,
-                         _collate_fn, _init_fn, 0, 1,
+                         True, _collate_fn, _init_fn, 0, 1,
                          loader._use_shared_memory)
             self.assertTrue(True)
         except AssertionError:
diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
index af332d8e43209251c4d3751f2689266f0fcd6c1e..0533a0d09fa0de2e71360dcc92d3c3db52427f83 100644
--- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
+++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py
@@ -27,7 +27,7 @@ from paddle.io import Dataset, BatchSampler, DataLoader
 from paddle.fluid.dygraph.nn import Linear
 from paddle.fluid.dygraph.base import to_variable
 
-from test_multiprocess_dataloader_iterable_dataset_static import RandomDataset, prepare_places
+from test_multiprocess_dataloader_iterable_dataset_static import RandomDataset, RandomBatchedDataset, prepare_places
 from test_multiprocess_dataloader_iterable_dataset_static import EPOCH_NUM, BATCH_SIZE, IMAGE_SIZE, SAMPLE_NUM, CLASS_NUM
 
 
@@ -119,5 +119,46 @@ class TestDygraphDataLoader(unittest.TestCase):
                 0]
 
 
+class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
+    def run_main(self, num_workers, places):
+        fluid.default_startup_program().random_seed = 1
+        fluid.default_main_program().random_seed = 1
+        with fluid.dygraph.guard(places[0]):
+            fc_net = SimpleFCNet()
+            optimizer = fluid.optimizer.Adam(parameter_list=fc_net.parameters())
+
+            dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
+            dataloader = DataLoader(
+                dataset,
+                num_workers=num_workers,
+                batch_size=None,
+                drop_last=True)
+
+            step_list = []
+            loss_list = []
+            start_t = time.time()
+            for _ in six.moves.range(EPOCH_NUM):
+                step = 0
+                for image, label in dataloader():
+                    out = fc_net(image)
+                    loss = fluid.layers.cross_entropy(out, label)
+                    avg_loss = fluid.layers.reduce_mean(loss)
+                    avg_loss.backward()
+                    optimizer.minimize(avg_loss)
+                    fc_net.clear_gradients()
+
+                    loss_list.append(np.mean(avg_loss.numpy()))
+                    step += 1
+                step_list.append(step)
+
+            end_t = time.time()
+            ret = {
+                "time": end_t - start_t,
+                "step": step_list,
+                "loss": np.array(loss_list)
+            }
+            print("time cost", ret['time'], 'step_list', ret['step'])
+            return ret
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
index e64e11d156ec74a375c161926ce3671e83f2352a..4615bf85ce69f45cf6d852cdb17e5dff1d533979 100644
--- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
+++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py
@@ -167,5 +167,80 @@ class TestStaticDataLoader(unittest.TestCase):
                 0]
 
 
+class RandomBatchedDataset(IterableDataset):
+    def __init__(self, sample_num, class_num):
+        self.sample_num = sample_num // BATCH_SIZE
+        self.class_num = class_num
+
+    def __iter__(self):
+        for i in range(self.sample_num):
+            np.random.seed(i)
+            images = []
+            labels = []
+            for _ in range(BATCH_SIZE):
+                image = np.random.random([IMAGE_SIZE]).astype('float32')
+                label = np.random.randint(0, self.class_num - 1,
+                                          (1, )).astype('int64')
+                images.append(image)
+                labels.append(label)
+            yield np.stack(images, axis=0), np.stack(labels, axis=0)
+
+
+class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
+    def run_main(self, num_workers, places):
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            startup_prog, main_prog, image, label, loss = simple_fc_net_static()
+
+            dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
+            dataloader = DataLoader(
+                dataset,
+                feed_list=[image, label],
+                places=places,
+                num_workers=num_workers,
+                batch_size=None,
+                drop_last=True)
+
+            exe = fluid.Executor(place=places[0])
+            exe.run(startup_prog)
+
+            prog = fluid.CompiledProgram(main_prog)
+            if len(places) > 1:
+                prog = prog.with_data_parallel(
+                    loss_name=loss.name, places=places)
+
+            step_list = []
+            loss_list = []
+            start_t = time.time()
+            for i in six.moves.range(EPOCH_NUM):
+                step = 0
+                for d in dataloader:
+                    assert len(d) == len(places), "{} != {}".format(
+                        len(d), len(places))
+                    for i, item in enumerate(d):
+                        image = item['image']
+                        label = item['label']
+                        assert image.shape() == [BATCH_SIZE, IMAGE_SIZE]
+                        assert label.shape() == [BATCH_SIZE, 1]
+                        assert image._place()._equals(places[i])
+                        assert label._place()._equals(places[i])
+                    L, = exe.run(program=prog,
+                                 feed=d,
+                                 fetch_list=[loss],
+                                 use_program_cache=True)
+                    loss_list.append(np.mean(L))
+                    step += 1
+                step_list.append(step)
+
+            end_t = time.time()
+            ret = {
+                "time": end_t - start_t,
+                "step": step_list,
+                "loss": np.array(loss_list)
+            }
+            print("time cost", ret['time'], 'step_list', ret['step'])
+            return ret
+
+
 if __name__ == '__main__':
     unittest.main()
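Note that, unlike the map-style test in test_multiprocess_dataloader_static.py below, this test cannot assert `len(dataloader)`: per the new `DataLoader.__len__`, length is undefined for an `IterableDataset`. A small hedged sketch of the expected behavior (the `Stream` class is a throwaway example, not part of the patch):

```python
# Expected behavior per the reader.py change above.
from paddle.io import IterableDataset, DataLoader


class Stream(IterableDataset):
    def __iter__(self):
        return iter([])  # empty stream, just for illustration


loader = DataLoader(Stream(), batch_size=None)
try:
    len(loader)
except ValueError as e:
    print(e)  # "length of IterableDataset not supported"
```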
diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
index c01e2e75b8195c0b2f2c46a6d18969055a68f977..5ec907c290b946dc79ee6f865c42fd8094afbb35 100644
--- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
+++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
@@ -215,5 +215,82 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
             assert isinstance(d[1], list)
 
 
+class RandomBatchedDataset(Dataset):
+    def __init__(self, sample_num, class_num):
+        self.sample_num = int(sample_num / BATCH_SIZE)
+        self.class_num = class_num
+
+    def __getitem__(self, idx):
+        np.random.seed(idx)
+        images = []
+        labels = []
+        for _ in range(BATCH_SIZE):
+            image = np.random.random([IMAGE_SIZE]).astype('float32')
+            label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64')
+            images.append(image)
+            labels.append(label)
+        return np.stack(images, axis=0), np.stack(labels, axis=0)
+
+    def __len__(self):
+        return self.sample_num
+
+
+class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
+    def run_main(self, num_workers, places):
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            startup_prog, main_prog, image, label, loss = simple_fc_net_static()
+
+            dataset = RandomBatchedDataset(SAMPLE_NUM, CLASS_NUM)
+            dataloader = DataLoader(
+                dataset,
+                feed_list=[image, label],
+                places=places,
+                num_workers=num_workers,
+                batch_size=None,
+                drop_last=True)
+            assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
+
+            exe = fluid.Executor(place=places[0])
+            exe.run(startup_prog)
+
+            prog = fluid.CompiledProgram(main_prog)
+            if len(places) > 1:
+                prog = prog.with_data_parallel(
+                    loss_name=loss.name, places=places)
+
+            step_list = []
+            loss_list = []
+            start_t = time.time()
+            for _ in six.moves.range(EPOCH_NUM):
+                step = 0
+                for d in dataloader:
+                    assert len(d) == len(places), "{} != {}".format(
+                        len(d), len(places))
+                    for i, item in enumerate(d):
+                        image = item['image']
+                        label = item['label']
+                        assert image.shape() == [BATCH_SIZE, IMAGE_SIZE]
+                        assert label.shape() == [BATCH_SIZE, 1]
+                        assert image._place()._equals(places[i])
+                        assert label._place()._equals(places[i])
+                    L, = exe.run(program=prog,
+                                 feed=d,
+                                 fetch_list=[loss],
+                                 use_program_cache=True)
+                    loss_list.append(np.mean(L))
+                    step += 1
+                step_list.append(step)
+
+            end_t = time.time()
+            ret = {
+                "time": end_t - start_t,
+                "step": step_list,
+                "loss": np.array(loss_list)
+            }
+            print("time cost", ret['time'], 'step_list', ret['step'])
+            return ret
+
+
 if __name__ == '__main__':
     unittest.main()