From 0436efd6a303611a18c0dee921fd46440016226f Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 23 Sep 2019 08:35:09 +0800 Subject: [PATCH] Unify DataLoader APIs (#19305) * unify DataLoader APIs, test=develop * integrate iterable CPU Dataset, test=develop add GPU dataset supporting, test=develop * add unittests for dataset, test=develop * add more docs to dataloader apis, test=develop, test=document_preview * refine doc, test=develop * refine doc again, test=develop * increase coverage, test=develop --- paddle/fluid/API.spec | 9 +- paddle/fluid/pybind/data_set_py.cc | 153 ++- paddle/fluid/pybind/reader_py.cc | 13 +- python/paddle/fluid/reader.py | 872 ++++++++++++------ .../fluid/tests/unittests/CMakeLists.txt | 6 + .../fluid/tests/unittests/simple_nets.py | 13 +- .../fluid/tests/unittests/test_dataset.py | 96 +- .../unittests/test_dataset_dataloader.py | 221 +++++ .../test_decoupled_py_reader_data_check.py | 84 +- .../unittests/test_generator_dataloader.py | 196 ++++ ...eader.py => test_py_reader_return_list.py} | 11 +- 11 files changed, 1334 insertions(+), 340 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_dataset_dataloader.py create mode 100644 python/paddle/fluid/tests/unittests/test_generator_dataloader.py rename python/paddle/fluid/tests/unittests/{test_pyreader.py => test_py_reader_return_list.py} (87%) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 687e9b01456..9607d0026d0 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -68,13 +68,18 @@ paddle.fluid.io.load_persistables (ArgSpec(args=['executor', 'dirname', 'main_pr paddle.fluid.io.save_inference_model (ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment', 'program_only'], varargs=None, keywords=None, defaults=(None, None, None, True, False)), ('document', 'fc82bfd137a9b1ab8ebd1651bd35b6e5')) 
paddle.fluid.io.load_inference_model (ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '2f54d7c206b62f8c10f4f9d78c731cfd')) paddle.fluid.io.batch (ArgSpec(args=['reader', 'batch_size', 'drop_last'], varargs=None, keywords=None, defaults=(False,)), ('document', 'cf2869b408b39cadadd95206b4e03b39')) -paddle.fluid.io.PyReader ('paddle.fluid.reader.PyReader', ('document', 'e37efae53f3935b32aec37eda9f3d906')) +paddle.fluid.io.PyReader ('paddle.fluid.reader.PyReader', ('document', 'b03399246f69cd6fc03b43e87af8bd4e')) paddle.fluid.io.PyReader.__init__ (ArgSpec(args=['self', 'feed_list', 'capacity', 'use_double_buffer', 'iterable', 'return_list'], varargs=None, keywords=None, defaults=(None, None, True, True, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.io.PyReader.decorate_batch_generator (ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,)), ('document', '4364e836e3cb8ab5e68e411b763c50c7')) paddle.fluid.io.PyReader.decorate_sample_generator (ArgSpec(args=['self', 'sample_generator', 'batch_size', 'drop_last', 'places'], varargs=None, keywords=None, defaults=(True, None)), ('document', 'efa4c8b90fe6d99dcbda637b70351bb1')) paddle.fluid.io.PyReader.decorate_sample_list_generator (ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,)), ('document', '6c11980092720de304863de98074a64a')) +paddle.fluid.io.PyReader.next (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '08b2fd1463f3ea99d79d17303988349b')) paddle.fluid.io.PyReader.reset (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '7432197701fdaab1848063860dc0b97e')) -paddle.fluid.io.PyReader.start (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f6395fd95b025000c5c7a5be31aebc4e')) +paddle.fluid.io.PyReader.start 
(ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'a0983fb21a0a51e6a31716009fe9a9c1')) +paddle.fluid.io.DataLoader ('paddle.fluid.reader.DataLoader', ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.io.DataLoader.__init__ +paddle.fluid.io.DataLoader.from_dataset (ArgSpec(args=['dataset', 'places', 'drop_last'], varargs=None, keywords=None, defaults=(True,)), ('document', '58e8bffa033f26b00b256c8bb1daff11')) +paddle.fluid.io.DataLoader.from_generator (ArgSpec(args=['feed_list', 'capacity', 'use_double_buffer', 'iterable', 'return_list'], varargs=None, keywords=None, defaults=(None, None, True, True, False)), ('document', '8034bdb488fa18d60c4ffb0ba9658337')) paddle.fluid.io.cache (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '1676886070eb607cb608f7ba47be0d3c')) paddle.fluid.io.map_readers (ArgSpec(args=['func'], varargs='readers', keywords=None, defaults=None), ('document', '77cbadb09df588e21e5cc0819b69c87d')) paddle.fluid.io.buffered (ArgSpec(args=['reader', 'size'], varargs=None, keywords=None, defaults=None), ('document', '0d6186f109feceb99f60ec50a0a624cb')) diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 439824f768c..0125465e6e2 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -21,6 +21,8 @@ limitations under the License. 
*/ #endif #include #include +#include +#include #include #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/text_format.h" @@ -41,10 +43,147 @@ namespace pd = paddle::framework; namespace paddle { namespace pybind { -void BindDataset(py::module* m) { +class IterableDatasetWrapper { + public: + IterableDatasetWrapper(framework::Dataset *dataset, + const std::vector &slots, + const std::vector &places, + size_t batch_size, bool drop_last) + : dataset_(dataset), + slots_(slots), + places_(places), + batch_size_(batch_size), + drop_last_(drop_last) { +#if defined _WIN32 + PADDLE_THROW("Dataset is not supported on Windows"); +#elif defined __APPLE__ + PADDLE_THROW("Dataset is not supported on MAC"); +#else + size_t device_num = places_.size(); + PADDLE_ENFORCE_GT(device_num, 0, "thread_num must be larger than 0"); + PADDLE_ENFORCE_GT(slots_.size(), 0, "slot_num cannot be 0"); + scopes_.reserve(device_num); + tensors_.reserve(device_num); + for (size_t i = 0; i < device_num; ++i) { + scopes_.emplace_back(new framework::Scope()); + tensors_.emplace_back(); + for (auto &var_name : slots_) { + auto *var = scopes_.back()->Var(var_name); + auto *t = var->GetMutable(); + tensors_.back().emplace_back(t); + } + } + + is_exhaustive_.resize(device_num); + exhaustive_num_ = 0; +#endif + } + + void Start() { + PADDLE_ENFORCE_EQ(is_started_, false, "Reader has been started"); + data_feeds_ = dataset_->GetReaders(); + PADDLE_ENFORCE_EQ(data_feeds_.size(), places_.size(), + "Device number does not match reader number"); + for (size_t i = 0; i < places_.size(); ++i) { + data_feeds_[i]->AssignFeedVar(*scopes_[i]); + data_feeds_[i]->SetPlace(platform::CPUPlace()); + PADDLE_ENFORCE_EQ(data_feeds_[i]->Start(), true, "Reader start failed"); + } + is_started_ = true; + + is_exhaustive_.assign(places_.size(), false); + exhaustive_num_ = 0; + } + + std::vector> Next() { + PADDLE_ENFORCE_EQ(is_started_, true, "Reader must be started"); + size_t device_num = 
places_.size(); + + std::vector> result( + device_num); + + size_t read_num = 0; + while (read_num < device_num && exhaustive_num_ < device_num) { + for (size_t i = 0; i < data_feeds_.size(); ++i) { + if (is_exhaustive_[i]) { + continue; + } + + bool is_success = (data_feeds_[i]->Next() > 0); + if (!is_success) { + is_exhaustive_[i] = true; + ++exhaustive_num_; + continue; + } + + for (size_t j = 0; j < slots_.size(); ++j) { + if (!IsValidLoDTensor(*tensors_[i][j])) { + is_success = false; + break; + } + + if (tensors_[i][j]->place() == places_[read_num]) { + result[read_num].emplace(slots_[j], std::move(*tensors_[i][j])); + } else { + framework::TensorCopy(std::move(*tensors_[i][j]), places_[read_num], + &result[read_num][slots_[j]]); + } + } + + if (!is_success) { + is_exhaustive_[i] = true; + ++exhaustive_num_; + continue; + } + + ++read_num; + if (read_num == device_num) { + break; + } + } + } + + if (UNLIKELY(read_num != device_num)) { + is_started_ = false; + throw py::stop_iteration(); + } + + return result; + } + + private: + bool IsValidLoDTensor(const framework::LoDTensor &tensor) const { + auto &lod = tensor.lod(); + PADDLE_ENFORCE_LE(lod.size(), 1, "lod level must be not larger than 1"); + if (!drop_last_) return true; + + if (lod.empty()) { + return static_cast(tensor.dims()[0]) == batch_size_; + } else { + return lod[0].size() == batch_size_ + 1; + } + } + + private: + framework::Dataset *dataset_; + std::vector slots_; + std::vector places_; + size_t batch_size_; + bool drop_last_; + + std::vector data_feeds_; + std::vector is_exhaustive_; + size_t exhaustive_num_; + + std::vector> scopes_; + std::vector> tensors_; + bool is_started_{false}; +}; + +void BindDataset(py::module *m) { py::class_>(*m, "Dataset") - .def(py::init([](const std::string& name = "MultiSlotDataset") { + .def(py::init([](const std::string &name = "MultiSlotDataset") { return framework::DatasetFactory::CreateDataset(name); })) .def("set_filelist", 
&framework::Dataset::SetFileList, @@ -119,7 +258,13 @@ void BindDataset(py::module* m) { .def("destroy_preload_readers", &framework::Dataset::DestroyPreLoadReaders, py::call_guard()); + + py::class_(*m, "IterableDatasetWrapper") + .def(py::init &, + const std::vector &, size_t, bool>()) + .def("_start", &IterableDatasetWrapper::Start) + .def("_next", &IterableDatasetWrapper::Next); } -} // end namespace pybind -} // end namespace paddle +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/reader_py.cc b/paddle/fluid/pybind/reader_py.cc index 4c304e8626b..4009bcf2a8b 100644 --- a/paddle/fluid/pybind/reader_py.cc +++ b/paddle/fluid/pybind/reader_py.cc @@ -18,6 +18,7 @@ #include #include #include +#include "Python.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/py_reader.h" @@ -27,6 +28,14 @@ namespace paddle { namespace pybind { +namespace py = pybind11; + +static void RaiseStopIterationException() { + VLOG(2) << "Raise StopIteration Exception in Python"; + py::gil_scoped_acquire guard; + throw py::stop_iteration(); +} + class MultiDeviceFeedReader { public: using ResultDictList = @@ -69,6 +78,7 @@ class MultiDeviceFeedReader { bool success = WaitFutures(); if (!success) { + RaiseStopIterationException(); return {}; } @@ -85,6 +95,7 @@ class MultiDeviceFeedReader { ResultList ReadNextList() { bool success = WaitFutures(); if (!success) { + RaiseStopIterationException(); return {}; } @@ -144,8 +155,6 @@ class MultiDeviceFeedReader { std::vector> ret_; }; -namespace py = pybind11; - void BindReader(py::module *module) { auto &m = *module; diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index d136b72f83a..f10a0ed5548 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -19,14 +19,17 @@ import warnings import numpy as np import threading import paddle -from .framework import Program, Variable, 
program_guard, default_main_program, default_startup_program, in_dygraph_mode +from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places from .executor import global_scope from .data_feeder import DataFeeder, BatchedTensorProvider, ListTensorProvider from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer from .unique_name import UniqueNameGenerator import logging +from .dataset import DatasetBase, InMemoryDataset -__all__ = ['PyReader'] +__all__ = ['PyReader', 'DataLoader'] + +data_loader_unique_name_generator = UniqueNameGenerator() def _convert_places(places): @@ -44,151 +47,255 @@ def _convert_places(places): return ret -class PyReader(object): - """ - Create a reader object for data feeding in Python. - Data would be prefetched using Python thread and be pushed - into a queue asynchronously. Data in the queue would be extracted - automatically when `Executor.run(...)` is called. +class DataLoaderBase(object): + def __init__(self): + self._places = None - Args: - feed_list (list(Variable)|tuple(Variable)): feed variable list. - The variables should be created by :code:`fluid.layers.data()`. - it can be None under iterable mode. - capacity (int): capacity of the queue maintained in PyReader object. - use_double_buffer (bool): whether to use double_buffer_reader to - speed up data feeding. - iterable (bool): whether the created reader object is iterable. - return_list (bool): whether the return value presented as list. - Returns: - reader (Reader): the created reader object. + def __call__(self): + return self - Examples: - 1. If iterable = False, the created PyReader object is almost the - same as :code:`fluid.layers.py_reader()`. Operators would be - inserted into the program. User should call :code:`start()` - before each epoch and catch :code:`fluid.core.EOFException` - thrown by :code:`Executor.run()` when epoch ends. 
Once the - exception is caught, user should call :code:`reset()` to reset - the reader manually. + def next(self): + ''' + Get the next item in the DataLoader object. This method + should not be called by users directly. It is used for + implementing iterator protocol of Python 2.x inside + PaddlePaddle framework. + ''' + return self.__next__() + + def __iter__(self): + raise NotImplementedError() + + def __next__(self): + raise NotImplementedError() + + +class DataLoader(object): + @staticmethod + def from_generator(feed_list=None, + capacity=None, + use_double_buffer=True, + iterable=True, + return_list=False): + """ + Create a DataLoader object for loading data from Python generator. + Data would be prefetched using Python thread and be pushed + into a queue asynchronously. + + The created DataLoader object provides 3 methods to set the data source + :code:`set_sample_generator` , :code:`set_sample_list_generator` and + :code:`set_batch_generator` . Please see the following example codes + to know their usages. + + If iterable = True, the created DataLoader object is a Python generator + object, which is iterable using for-range loop. + + If iterable = False, the created DataLoader object provides + :code:`start()` and :code:`reset()` method to control the data reading + process. This mode is designed to be compatible with the + :code:`fluid.layers.py_reader` interface. Users can migrate the codes + from :code:`fluid.layers.py_reader` to :code:`fluid.io.DataLoader` + easily when using iterable=False. + + Args: + feed_list (list(Variable)|tuple(Variable)): feed variable list. + The variables should be created by :code:`fluid.layers.data()`. + capacity (int): capacity of the queue maintained in DataLoader. + The unit is batch number. Set larger capacity if your reader + is fast. + use_double_buffer (bool): whether to use double_buffer_reader. 
+ If use_double_buffer=True, the DataLoader would prefetch next + batch data asynchronously, so it would speed up data feeding + and occupies a little more CPU or GPU memory, i.e., the memory + of one batch input data. + iterable (bool): whether the created DataLoader is iterable. + return_list (bool): whether the return value on each device is + presented as a list. It is only valid when iterable=True. + If return_list=False, the return value on each device would + be a dict of str -> LoDTensor, where the key of the dict is + the name of each feeded variables. If return_list=True, the + return value on each device would be a list(LoDTensor). It is + recommended to use return_list=False in static graph mode and + use return_list=True in dygraph mode. + + Returns: + loader (DataLoader): the created DataLoader object. + + Examples: + + .. code-block:: python - .. code-block:: python + import paddle.fluid as fluid + import numpy as np - import paddle - import paddle.fluid as fluid - import numpy as np + BATCH_NUM = 10 + BATCH_SIZE = 16 + EPOCH_NUM = 4 + + CLASS_NUM = 10 + + ITERABLE = True # whether the created DataLoader object is iterable + USE_GPU = False # whether to use GPU + + DATA_FORMAT = 'batch_generator' # data format of data source user provides + + def simple_net(image, label): + fc_tmp = fluid.layers.fc(image, size=CLASS_NUM) + cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label) + loss = fluid.layers.reduce_mean(cross_entropy) + sgd = fluid.optimizer.SGD(learning_rate=1e-3) + sgd.minimize(loss) + return loss + + def get_random_images_and_labels(image_shape, label_shape): + image = np.random.random(size=image_shape).astype('float32') + label = np.random.random(size=label_shape).astype('int64') + return image, label + + # If the data generator yields one sample each time, + # use DataLoader.set_sample_generator to set the data source. 
+ def sample_generator_creator(): + def __reader__(): + for _ in range(BATCH_NUM * BATCH_SIZE): + image, label = get_random_images_and_labels([784], [1]) + yield image, label + + return __reader__ + + # If the data generator yield list of samples each time, + # use DataLoader.set_sample_list_generator to set the data source. + def sample_list_generator_creator(): + def __reader__(): + for _ in range(BATCH_NUM): + sample_list = [] + for _ in range(BATCH_SIZE): + image, label = get_random_images_and_labels([784], [1]) + sample_list.append([image, label]) + + yield sample_list + + return __reader__ + + # If the data generator yields a batch each time, + # use DataLoader.set_batch_generator to set the data source. + def batch_generator_creator(): + def __reader__(): + for _ in range(BATCH_NUM): + batch_image, batch_label = get_random_images_and_labels([BATCH_SIZE, 784], [BATCH_SIZE, 1]) + yield batch_image, batch_label - EPOCH_NUM = 3 - ITER_NUM = 5 - BATCH_SIZE = 3 + return __reader__ - def reader_creator_random_image_and_label(height, width): - def reader(): - for i in range(ITER_NUM): - fake_image = np.random.uniform(low=0, - high=255, - size=[height, width]) - fake_label = np.ones([1]) - yield fake_image, fake_label - return reader + # If DataLoader is iterable, use for loop to train the network + def train_iterable(exe, prog, loss, loader): + for _ in range(EPOCH_NUM): + for data in loader(): + exe.run(prog, feed=data, fetch_list=[loss]) - image = fluid.layers.data(name='image', shape=[784, 784], dtype='float32') - label = fluid.layers.data(name='label', shape=[1], dtype='int64') + # If DataLoader is not iterable, use start() and reset() method to control the process + def train_non_iterable(exe, prog, loss, loader): + for _ in range(EPOCH_NUM): + loader.start() # call DataLoader.start() before each epoch starts + try: + while True: + exe.run(prog, fetch_list=[loss]) + except fluid.core.EOFException: + loader.reset() # call DataLoader.reset() after catching 
EOFException + + def set_data_source(loader, places): + if DATA_FORMAT == 'sample_generator': + loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE, drop_last=True, places=places) + elif DATA_FORMAT == 'sample_list_generator': + loader.set_sample_list_generator(sample_list_generator_creator(), places=places) + elif DATA_FORMAT == 'batch_generator': + loader.set_batch_generator(batch_generator_creator(), places=places) + else: + raise ValueError('Unsupported data format') - reader = fluid.io.PyReader(feed_list=[image, label], - capacity=4, - iterable=False) + image = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') - user_defined_reader = reader_creator_random_image_and_label(784, 784) - reader.decorate_sample_list_generator( - paddle.batch(user_defined_reader, batch_size=BATCH_SIZE)) - # definition of network is omitted - executor = fluid.Executor(fluid.CUDAPlace(0)) - executor.run(fluid.default_startup_program()) - for i in range(EPOCH_NUM): - reader.start() - while True: - try: - executor.run(feed=None) - except fluid.core.EOFException: - reader.reset() - break + # Define DataLoader + loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE) - - 2. If iterable=True, the created PyReader object is decoupled with - the program. No operator would be inserted into the program. - In this case, the created reader is a Python generator, which - is iterable. User should feed the data yielded from PyReader - object into :code:`Executor.run(feed=...)`. + # Define network + loss = simple_net(image, label) - .. code-block:: python + # Set data source of DataLoader + # + # If DataLoader is iterable, places must be given and the number of places must be the same with device number. + # - If you are using GPU, call `fluid.cuda_places()` to get all GPU places. 
+ # - If you are using CPU, call `fluid.cpu_places()` to get all CPU places. + # + # If DataLoader is not iterable, places can be None. + places = fluid.cuda_places() if USE_GPU else fluid.cpu_places() + set_data_source(loader, places) - import paddle - import paddle.fluid as fluid - import numpy as np + exe = fluid.Executor(places[0]) + exe.run(fluid.default_startup_program()) - EPOCH_NUM = 3 - ITER_NUM = 5 - BATCH_SIZE = 10 + prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) - def reader_creator_random_image(height, width): - def reader(): - for i in range(ITER_NUM): - yield np.random.uniform(low=0, high=255, size=[height, width]), - return reader + if loader.iterable: + train_iterable(exe, prog, loss, loader) + else: + train_non_iterable(exe, prog, loss, loader) + + + ''' + Users can use return_list = True in dygraph mode. + ''' + with fluid.dygraph.guard(places[0]): + loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True) + set_data_source(loader, places[0]) + for image, label in loader(): + relu = fluid.layers.relu(image) + assert image.shape == [BATCH_SIZE, 784] + assert label.shape == [BATCH_SIZE, 1] + assert relu.shape == [BATCH_SIZE, 784] + """ + return GeneratorLoader(feed_list, capacity, use_double_buffer, iterable, + return_list) + + @staticmethod + def from_dataset(dataset, places, drop_last=True): + """ + Create an iterable DataLoader object for loading data from Dataset. + Dataset is only supported in Linux system currently. - image = fluid.layers.data(name='image', shape=[784, 784], dtype='float32') - reader = fluid.io.PyReader(feed_list=[image], capacity=4, iterable=True, return_list=False) + Args: + dataset (InMemoryDataset|QueueDataset): the dataset object. + places (list(CUDAPlace)|list(CPUPlace)): places where the result + data should be converted. + drop_last (bool): whether to drop the last batch whose sample + number is less than batch size. 
If drop_last = True, they + would be dropped. If drop_last = False, they would be kept. - user_defined_reader = reader_creator_random_image(784, 784) - reader.decorate_sample_list_generator( - paddle.batch(user_defined_reader, batch_size=BATCH_SIZE), - fluid.core.CUDAPlace(0)) - # definition of network is omitted - executor = fluid.Executor(fluid.CUDAPlace(0)) - executor.run(fluid.default_main_program()) + Returns: + loader (DataLoader): the created DataLoader object, which can be + treated as a Python generator. - for _ in range(EPOCH_NUM): - for data in reader(): - executor.run(feed=data) + Examples: + + .. code-block:: python + import paddle.fluid as fluid - 3. If return_list=True, the return values would be presented as list instead of dict`. + image = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') - .. code-block:: python + dataset = fluid.DatasetFactory().create_dataset("QueueDataset") + dataset.set_batch_size(32) + dataset.set_filelist(['a.txt', 'b.txt', 'c.txt']) + dataset.set_use_var([image, label]) + dataset.set_pipe_command('cat') - import paddle - import paddle.fluid as fluid - import numpy as np - - EPOCH_NUM = 3 - ITER_NUM = 5 - BATCH_SIZE = 10 - - def reader_creator_random_image(height, width): - def reader(): - for i in range(ITER_NUM): - yield np.random.uniform(low=0, high=255, size=[height, width]), - return reader - - image = fluid.layers.data(name='image', shape=[784, 784], dtype='float32') - reader = fluid.io.PyReader(feed_list=[image], capacity=4, iterable=True, return_list=True) - - user_defined_reader = reader_creator_random_image(784, 784) - reader.decorate_sample_list_generator( - paddle.batch(user_defined_reader, batch_size=BATCH_SIZE), - fluid.core.CPUPlace()) - # definition of network is omitted - executor = fluid.Executor(fluid.core.CPUPlace()) - executor.run(fluid.default_main_program()) - - for _ in range(EPOCH_NUM): - for data in reader(): - 
executor.run(feed={"image": data[0]}) - """ + loader = fluid.io.DataLoader.from_dataset(dataset, fluid.cpu_places()) + """ + return DatasetLoader(dataset, places, drop_last) - unique_name_generator = UniqueNameGenerator() +class GeneratorLoader(DataLoaderBase): def __init__(self, feed_list=None, capacity=None, @@ -196,6 +303,7 @@ class PyReader(object): iterable=True, return_list=False): self._tensor_reader = None + self._places = None self._thread = None self._feed_list = feed_list if not capacity: @@ -204,11 +312,13 @@ class PyReader(object): if in_dygraph_mode(): if not iterable: warnings.warn( - "Please NOTE: dygraph can support iterable mode only.") + "Please NOTE: dygraph can support iterable mode only. Change to iterable mode." + ) self._iterable = True if not return_list: warnings.warn( - "Please NOTE: dygraph can support return as list only.") + "Please NOTE: dygraph can support return as list only. Change to return as list." + ) self._return_list = True else: self._iterable = iterable @@ -220,12 +330,20 @@ class PyReader(object): if not self._iterable: self._init_non_iterable() - def _init_iterable(self, places): + def _wait_thread_ends(self): + # Get self._thread first to prevent a data race, because __thread_main__ + # sets self._thread to None when it finishes + thread = self._thread + if thread is not None and self._iterable: + self._queue.close() + thread.join() + + def _init_iterable(self): + self._wait_thread_ends() if in_dygraph_mode(): self._var_names = [] else: self._var_names = [v.name for v in self._feed_list] - self._places = _convert_places(places) self._queue = core.init_lod_tensor_blocking_queue(core.Variable(), self._capacity) self._reader = core.create_py_reader( @@ -245,9 +363,10 @@ class PyReader(object): shapes.append(feed_data.shape) lod_levels.append(feed_data.lod_level) - queue_name = PyReader.unique_name_generator('lod_tensor_blocking_queue') - reader_name = PyReader.unique_name_generator('create_py_reader') - double_buffer_name = 
PyReader.unique_name_generator('double_buffer') + queue_name = data_loader_unique_name_generator( + 'lod_tensor_blocking_queue') + reader_name = data_loader_unique_name_generator('create_py_reader') + double_buffer_name = data_loader_unique_name_generator('double_buffer') var = global_scope().var(queue_name) self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity) @@ -298,62 +417,325 @@ class PyReader(object): def iterable(self): return self._iterable - def __call__(self): - assert self.iterable, "PyReader is not iterable" + def __iter__(self): + assert self.iterable, "DataLoader is not iterable" assert self._tensor_reader is not None, \ - "Data source of PyReader has not set yet" - - class Iterator(object): - def __init__(self, reader): - self._reader = reader._reader - self._reset = reader._reset - self._return_list = reader._return_list - - def __iter__(self): - return self - - def __next__(self): - return self.next() - - def next(self): - if not in_dygraph_mode(): - if self._return_list: - ret = self._reader.read_next_list() - ret = ret[0] if ret is not None and len( - ret) > 0 else None - else: - ret = self._reader.read_next() - if ret: - return ret - else: - self._reset() - raise StopIteration - else: - ret = self._reader.read_next_list() - if ret and ret[0]: - return [ - dygraph.base.to_variable(np.array(v)) - for v in ret[0] - ] - else: - self._reset() - raise StopIteration + "Data source of DataLoader has not set yet" + self._init_iterable() self._start() - return Iterator(self) + return self + + def __next__(self): + try: + if not in_dygraph_mode(): + if self._return_list: + return self._reader.read_next_list() + else: + return self._reader.read_next() + else: + ret = self._reader.read_next_list()[0] + return [dygraph.base.to_variable(np.array(v)) for v in ret] + except StopIteration: + self._queue.close() + self._reset() + six.reraise(*sys.exc_info()) + + def start(self): + if not in_dygraph_mode(): + assert not self._iterable, "start() 
cannot be called when DataLoader is iterable" + self._start() + + def reset(self): + if not in_dygraph_mode(): + assert not self._iterable, "reset() cannot be called when DataLoader is iterable" + self._reset() + + def _start(self): + def __thread_main__(): + try: + for tensors in self._tensor_reader(): + array = core.LoDTensorArray() + for item in tensors: + if not isinstance(item, core.LoDTensor): + tmp = core.LoDTensor() + tmp.set(item, core.CPUPlace()) + item = tmp + + array.append(item) + + if not self._queue.push(array): + break + + self._queue.close() + self._thread = None + except Exception as ex: + self._queue.close() + self._thread = None + logging.warn('Your reader has raised an exception!') + six.reraise(*sys.exc_info()) + + self._thread = threading.Thread(target=__thread_main__) + self._thread.daemon = True + self._thread.start() def _reset(self): self._reader.reset() - self._thread.join() + thread = self._thread + if thread is not None: + thread.join() + + def set_sample_generator(self, + reader, + batch_size, + drop_last=True, + places=None): + assert batch_size > 0, "batch_size must be larger than 0" + if not in_dygraph_mode(): + has_lod = False + for f in self._feed_list: + if f.lod_level != 0: + has_lod = True + break + + if has_lod: + self.set_sample_list_generator( + paddle.batch( + reader, batch_size=batch_size, drop_last=drop_last), + places=places) + else: + reader = BatchedTensorProvider( + feed_list=self._feed_list, + place=core.CPUPlace(), + batch_size=batch_size, + generator=reader, + drop_last=drop_last) + self.set_batch_generator(reader, places=places) + else: + self.set_sample_list_generator( + paddle.batch( + reader, batch_size=batch_size, drop_last=drop_last), + places=places) + return self + + def set_sample_list_generator(self, reader, places=None): + if not in_dygraph_mode(): + with program_guard(Program(), Program()): + feeder = DataFeeder( + feed_list=self._feed_list, place=core.CPUPlace()) + paddle_reader = 
feeder.decorate_reader( + reader, multi_devices=False) + + def __tensor_reader_impl__(): + for slots in paddle_reader(): + yield [slots[var.name] for var in self._feed_list] + else: + provider = ListTensorProvider(reader, places) + + def __tensor_reader_impl__(): + for slots in provider(): + yield slots[0] + + self.set_batch_generator(__tensor_reader_impl__, places) + return self + + def set_batch_generator(self, reader, places=None): + self._tensor_reader = reader + if self._iterable: + assert places is not None, "Places cannot be None when DataLoader is iterable" + self._places = _convert_places(places) + if in_dygraph_mode(): + assert len(self._places + ) == 1, "Number of places must be 1 in dygraph mode" + else: + if places is not None: + logging.info( + 'places would be ommited when DataLoader is not iterable') + return self + + +class PyReader(DataLoaderBase): + """ + Create a reader object for data feeding in Python. + Data would be prefetched using Python thread and be pushed + into a queue asynchronously. Data in the queue would be extracted + automatically when `Executor.run(...)` is called. + + Args: + feed_list (list(Variable)|tuple(Variable)): feed variable list. + The variables should be created by :code:`fluid.layers.data()`. + capacity (int): capacity of the queue maintained in PyReader. + The unit is batch number. Set larger capacity if your reader + is fast. + use_double_buffer (bool): whether to use double_buffer_reader. + If use_double_buffer=True, PyReader would prefetch next + batch data asynchronously, so it would speed up data feeding + and occupies a little more CPU or GPU memory, i.e., the memory + of one batch input data. + iterable (bool): whether the created PyReader is iterable. + return_list (bool): whether the return value on each device is + presented as a list. It is only valid when iterable=True. 
+ If return_list=False, the return value on each device would + be a dict of str -> LoDTensor, where the key of the dict is + the name of each feeded variables. If return_list=True, the + return value on each device would be a list(LoDTensor). It is + recommended to use return_list=False in static graph mode and + use return_list=True in dygraph mode. + + Returns: + reader (Reader): the created reader object. + + Examples: + 1. If iterable = False, the created PyReader object is almost the + same as :code:`fluid.layers.py_reader()`. Operators would be + inserted into the program. User should call :code:`start()` + before each epoch and catch :code:`fluid.core.EOFException` + thrown by :code:`Executor.run()` when epoch ends. Once the + exception is caught, user should call :code:`reset()` to reset + the reader manually. + + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + EPOCH_NUM = 3 + ITER_NUM = 5 + BATCH_SIZE = 3 + + def reader_creator_random_image_and_label(height, width): + def reader(): + for i in range(ITER_NUM): + fake_image = np.random.uniform(low=0, + high=255, + size=[height, width]) + fake_label = np.ones([1]) + yield fake_image, fake_label + return reader + + image = fluid.layers.data(name='image', shape=[784, 784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + reader = fluid.io.PyReader(feed_list=[image, label], + capacity=4, + iterable=False) + + user_defined_reader = reader_creator_random_image_and_label(784, 784) + reader.decorate_sample_list_generator( + paddle.batch(user_defined_reader, batch_size=BATCH_SIZE)) + # definition of network is omitted + executor = fluid.Executor(fluid.CUDAPlace(0)) + executor.run(fluid.default_startup_program()) + for i in range(EPOCH_NUM): + reader.start() + while True: + try: + executor.run(feed=None) + except fluid.core.EOFException: + reader.reset() + break + + + 2. 
If iterable=True, the created PyReader object is decoupled with + the program. No operator would be inserted into the program. + In this case, the created reader is a Python generator, which + is iterable. User should feed the data yielded from PyReader + object into :code:`Executor.run(feed=...)`. + + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + EPOCH_NUM = 3 + ITER_NUM = 5 + BATCH_SIZE = 10 + + def reader_creator_random_image(height, width): + def reader(): + for i in range(ITER_NUM): + yield np.random.uniform(low=0, high=255, size=[height, width]), + return reader + + image = fluid.layers.data(name='image', shape=[784, 784], dtype='float32') + reader = fluid.io.PyReader(feed_list=[image], capacity=4, iterable=True, return_list=False) + + user_defined_reader = reader_creator_random_image(784, 784) + reader.decorate_sample_list_generator( + paddle.batch(user_defined_reader, batch_size=BATCH_SIZE), + fluid.core.CUDAPlace(0)) + # definition of network is omitted + executor = fluid.Executor(fluid.CUDAPlace(0)) + executor.run(fluid.default_main_program()) + + for _ in range(EPOCH_NUM): + for data in reader(): + executor.run(feed=data) + + + 3. If return_list=True, the return values would be presented as list instead of dict. + This is usually used in dygraph mode. + + .. 
code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + ITER_NUM = 5 + BATCH_SIZE = 10 + + def reader_creator_random_image(height, width): + def reader(): + for i in range(ITER_NUM): + yield np.random.uniform(low=0, high=255, size=[height, width]), \ + np.random.random_integers(low=0, high=9, size=[1]) + return reader + + place = fluid.CPUPlace() + with fluid.dygraph.guard(place): + py_reader = fluid.io.PyReader(capacity=2, return_list=True) + user_defined_reader = reader_creator_random_image(784, 784) + py_reader.decorate_sample_list_generator( + paddle.batch(user_defined_reader, batch_size=BATCH_SIZE), + place) + for image, label in py_reader(): + relu = fluid.layers.relu(image) + """ + + def __init__(self, + feed_list=None, + capacity=None, + use_double_buffer=True, + iterable=True, + return_list=False): + self._loader = DataLoader.from_generator( + feed_list, capacity, use_double_buffer, iterable, return_list) + + @property + def queue(self): + return self._loader.queue + + @property + def iterable(self): + return self._loader.iterable + + def __iter__(self): + return self._loader.__iter__() + + def __next__(self): + return self._loader.__next__() def start(self): ''' Start the data feeding thread. Can only call when the reader object is not iterable. - Example: - .. code-block:: python - + Example: + .. 
code-block:: python + import paddle import paddle.fluid as fluid import numpy as np @@ -380,10 +762,8 @@ class PyReader(object): reader.reset() break - ''' - if not in_dygraph_mode(): - assert not self._iterable, "start() cannot be called when PyReader is iterable" - self._start() + ''' + self._loader.start() def reset(self): ''' @@ -420,35 +800,7 @@ class PyReader(object): break ''' - if not in_dygraph_mode(): - assert not self._iterable, "reset() cannot be called when PyReader is iterable" - self._reset() - - def _start(self): - def __thread_main__(): - try: - for tensors in self._tensor_reader(): - array = core.LoDTensorArray() - for item in tensors: - if not isinstance(item, core.LoDTensor): - tmp = core.LoDTensor() - tmp.set(item, core.CPUPlace()) - item = tmp - - array.append(item) - - if not self._queue.push(array): - break - - self._queue.close() - except Exception as ex: - self._queue.close() - logging.warn('Your decorated reader has raised an exception!') - six.reraise(*sys.exc_info()) - - self._thread = threading.Thread(target=__thread_main__) - self._thread.daemon = True - self._thread.start() + self._loader.reset() def decorate_sample_generator(self, sample_generator, @@ -512,36 +864,8 @@ class PyReader(object): executor.run(feed=data) ''' - assert batch_size > 0, "batch_size must be larger than 0" - if not in_dygraph_mode(): - has_lod = False - for f in self._feed_list: - if f.lod_level != 0: - has_lod = True - break - - if has_lod: - self.decorate_sample_list_generator( - paddle.batch( - sample_generator, - batch_size=batch_size, - drop_last=drop_last), - places=places) - else: - reader = BatchedTensorProvider( - feed_list=self._feed_list, - place=core.CPUPlace(), - batch_size=batch_size, - generator=sample_generator, - drop_last=drop_last) - self.decorate_batch_generator(reader, places=places) - else: - self.decorate_sample_list_generator( - paddle.batch( - sample_generator, - batch_size=batch_size, - drop_last=drop_last), - places=places) + 
self._loader.set_sample_generator(sample_generator, batch_size, + drop_last, places) def decorate_sample_list_generator(self, reader, places=None): ''' @@ -596,26 +920,7 @@ class PyReader(object): executor.run(feed=data) ''' - assert self._tensor_reader is None, \ - "Cannot reset the data source of PyReader" - if not in_dygraph_mode(): - with program_guard(Program(), Program()): - feeder = DataFeeder( - feed_list=self._feed_list, place=core.CPUPlace()) - paddle_reader = feeder.decorate_reader( - reader, multi_devices=False) - - def __tensor_reader_impl__(): - for slots in paddle_reader(): - yield [slots[var.name] for var in self._feed_list] - else: - provider = ListTensorProvider(reader, places) - - def __tensor_reader_impl__(): - for slots in provider(): - yield slots[0] - - self.decorate_batch_generator(__tensor_reader_impl__, places) + self._loader.set_sample_list_generator(reader, places) def decorate_batch_generator(self, reader, places=None): ''' @@ -667,9 +972,48 @@ class PyReader(object): executor.run(feed=data) ''' - assert self._tensor_reader is None, \ - "Cannot reset the data source of PyReader" - self._tensor_reader = reader - if self._iterable: - assert places is not None, "Places cannot be None when py_reader is iterable" - self._init_iterable(places) + self._loader.set_batch_generator(reader, places) + + +class DatasetLoader(DataLoaderBase): + def __init__(self, dataset, places, drop_last): + assert isinstance(dataset, + DatasetBase), "dataset must be type of DatasetBase" + assert not in_dygraph_mode( + ), "DatasetLoader is not supported in dygraph mode yet" + + thread_num = len(places) + + assert len(dataset.filelist) >= thread_num, \ + "Filelist number of dataset {} must be not less than place number {}".format(len(dataset.filelist), thread_num) + + if dataset.thread_num != 0 and dataset.thread_num != thread_num: + logging.warn('thread_num {} which is set in Dataset is ignored'. 
+ format(dataset.thread_num)) + + dataset.set_thread(thread_num) + + if isinstance(dataset, + InMemoryDataset) and dataset.queue_num > thread_num: + logging.warn("queue_num {} which is set in Dataset is ignored". + format(dataset.queue_num)) + dataset.set_queue_num(thread_num) + + self._dataset = dataset + use_slots = [ + slot.name for slot in dataset.proto_desc.multi_slot_desc.slots + if slot.is_used + ] + + self._iterable_dataset = core.IterableDatasetWrapper( + dataset.dataset, use_slots, + _convert_places(places), dataset.proto_desc.batch_size, drop_last) + + def __iter__(self): + self._dataset._finish_to_run() + self._dataset._prepare_to_run() + self._iterable_dataset._start() + return self + + def __next__(self): + return self._iterable_dataset._next() diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index ec55fe36860..6b1971b1561 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -183,6 +183,12 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op) list(REMOVE_ITEM TEST_OPS test_basic_lstm_api) list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op) list(REMOVE_ITEM TEST_OPS test_imperative_debug_string) + +if (APPLE OR WIN32) + list(REMOVE_ITEM TEST_OPS test_dataset) + list(REMOVE_ITEM TEST_OPS test_dataset_dataloader) +endif() + # Some ops need to check results when gc is enabled # Currently, only ops that register NoNeedBufferVarsInference need to do this test set(TEST_OPS_WITH_GC diff --git a/python/paddle/fluid/tests/unittests/simple_nets.py b/python/paddle/fluid/tests/unittests/simple_nets.py index 959042a246d..f65bc7c3f43 100644 --- a/python/paddle/fluid/tests/unittests/simple_nets.py +++ b/python/paddle/fluid/tests/unittests/simple_nets.py @@ -16,10 +16,7 @@ import paddle.fluid as fluid import numpy as np -def simple_fc_net(use_feed=None): - img = fluid.layers.data(name='image', shape=[784], dtype='float32') - label = 
fluid.layers.data(name='label', shape=[1], dtype='int64') - +def simple_fc_net_with_inputs(img, label, class_num=10): hidden = img for _ in range(4): hidden = fluid.layers.fc( @@ -28,12 +25,18 @@ def simple_fc_net(use_feed=None): act='relu', bias_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(value=1.0))) - prediction = fluid.layers.fc(hidden, size=10, act='softmax') + prediction = fluid.layers.fc(hidden, size=class_num, act='softmax') loss = fluid.layers.cross_entropy(input=prediction, label=label) loss = fluid.layers.mean(loss) return loss +def simple_fc_net(use_feed=None): + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + return simple_fc_net_with_inputs(img, label, class_num=10) + + def fc_with_batchnorm(use_feed=None): img = fluid.layers.data(name='image', shape=[784], dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index 0958d33d182..27557897ba2 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -28,6 +28,11 @@ import unittest class TestDataset(unittest.TestCase): """ TestCases for Dataset. """ + def setUp(self): + self.use_data_loader = False + self.epoch_num = 10 + self.drop_last = False + def test_dataset_create(self): """ Testcase for dataset create. 
""" try: @@ -174,13 +179,20 @@ class TestDataset(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace()) exe.run(fluid.default_startup_program()) - for i in range(2): - try: - exe.train_from_dataset(fluid.default_main_program(), dataset) - except ImportError as e: - pass - except Exception as e: - self.assertTrue(False) + if self.use_data_loader: + data_loader = fluid.io.DataLoader.from_dataset(dataset, + fluid.cpu_places(), + self.drop_last) + for i in range(self.epoch_num): + for data in data_loader(): + exe.run(fluid.default_main_program(), feed=data) + else: + for i in range(self.epoch_num): + try: + exe.train_from_dataset(fluid.default_main_program(), + dataset) + except Exception as e: + self.assertTrue(False) os.remove("./test_in_memory_dataset_run_a.txt") os.remove("./test_in_memory_dataset_run_b.txt") @@ -225,13 +237,20 @@ class TestDataset(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda( ) else fluid.CUDAPlace(0)) exe.run(fluid.default_startup_program()) - for i in range(2): - try: - exe.train_from_dataset(fluid.default_main_program(), dataset) - except ImportError as e: - pass - except Exception as e: - self.assertTrue(False) + if self.use_data_loader: + data_loader = fluid.io.DataLoader.from_dataset(dataset, + fluid.cpu_places(), + self.drop_last) + for i in range(self.epoch_num): + for data in data_loader(): + exe.run(fluid.default_main_program(), feed=data) + else: + for i in range(self.epoch_num): + try: + exe.train_from_dataset(fluid.default_main_program(), + dataset) + except Exception as e: + self.assertTrue(False) dataset.set_merge_by_lineid(slots_vars) dataset.preload_into_memory() @@ -277,13 +296,20 @@ class TestDataset(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace()) exe.run(fluid.default_startup_program()) - for i in range(2): - try: - exe.train_from_dataset(fluid.default_main_program(), dataset) - except ImportError as e: - pass - except Exception as e: - self.assertTrue(False) + if 
self.use_data_loader: + data_loader = fluid.io.DataLoader.from_dataset(dataset, + fluid.cpu_places(), + self.drop_last) + for i in range(self.epoch_num): + for data in data_loader(): + exe.run(fluid.default_main_program(), feed=data) + else: + for i in range(self.epoch_num): + try: + exe.train_from_dataset(fluid.default_main_program(), + dataset) + except Exception as e: + self.assertTrue(False) os.remove("./test_queue_dataset_run_a.txt") os.remove("./test_queue_dataset_run_b.txt") @@ -324,17 +350,31 @@ class TestDataset(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace() if not core.is_compiled_with_cuda( ) else fluid.CUDAPlace(0)) exe.run(fluid.default_startup_program()) - for i in range(2): - try: - exe.train_from_dataset(fluid.default_main_program(), dataset) - except ImportError as e: - pass - except Exception as e: - self.assertTrue(False) + if self.use_data_loader: + data_loader = fluid.io.DataLoader.from_dataset(dataset, + fluid.cpu_places(), + self.drop_last) + for i in range(self.epoch_num): + for data in data_loader(): + exe.run(fluid.default_main_program(), feed=data) + else: + for i in range(self.epoch_num): + try: + exe.train_from_dataset(fluid.default_main_program(), + dataset) + except Exception as e: + self.assertTrue(False) os.remove("./test_queue_dataset_run_a.txt") os.remove("./test_queue_dataset_run_b.txt") +class TestDatasetWithDataLoader(TestDataset): + def setUp(self): + self.use_data_loader = True + self.epoch_num = 10 + self.drop_last = False + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dataset_dataloader.py b/python/paddle/fluid/tests/unittests/test_dataset_dataloader.py new file mode 100644 index 00000000000..10aefbb222b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dataset_dataloader.py @@ -0,0 +1,221 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import numpy as np +import six +import os +import unittest +from simple_nets import simple_fc_net_with_inputs + +BATCH_SIZE = 32 +BATCH_NUM = 10 +EPOCH_NUM = 4 + +IMAGE_SHAPE = [2, 3] +LABEL_SHAPE = [1] + +ALL_WRITTEN_FILES = set() + + +def get_place_string(p): + if isinstance(p, (fluid.CPUPlace or fluid.CUDAPlace)): + tmp = fluid.core.Place() + tmp.set_place(p) + p = tmp + + if p._type() == fluid.CPUPlace()._type(): + return 'CPUPlace()' + else: + return 'CUDAPlace()' + + +def remove_all_written_files(): + for filename in ALL_WRITTEN_FILES: + os.remove(filename) + + +def write_reader_data_to_file(filename, reader): + ALL_WRITTEN_FILES.add(filename) + with open(filename, 'w') as fid: + for instance_list in reader(): + for i, instance in enumerate(instance_list): + instance = np.reshape(instance, [instance.size, ]) + fid.write(str(instance.size) + ' ') + fid.write(' '.join(map(str, instance))) + fid.write(' ') + + fid.write('\n') + + +def fake_reader(batch_size=BATCH_SIZE, batch_num=BATCH_NUM): + def __reader__(): + iteration = BATCH_SIZE * BATCH_NUM + iteration = int(iteration + BATCH_SIZE / 2) + for _ in six.moves.range(iteration): + image = np.random.random(size=IMAGE_SHAPE).astype('float32') + label = np.random.random_integers( + size=LABEL_SHAPE, low=0, high=9).astype('int64') + yield image, label + + return __reader__ + + +class DatasetLoaderTestBase(unittest.TestCase): + def 
setUp(self): + self.dataset_name = "QueueDataset" + self.drop_last = False + + def tearDown(self): + return + remove_all_written_files() + + def build_network(self): + main_prog = fluid.Program() + startup_prog = fluid.Program() + with fluid.program_guard(main_prog, startup_prog): + image = fluid.layers.data( + name='image', shape=IMAGE_SHAPE, dtype='float32') + label = fluid.layers.data( + name='label', shape=LABEL_SHAPE, dtype='int64') + + simple_fc_net_with_inputs(image, label) + + return main_prog, startup_prog, [image, label] + + def check_batch_number(self, place, randomize_batch_num=False): + main_prog, startup_prog, feeds = self.build_network() + dataset = fluid.DatasetFactory().create_dataset(self.dataset_name) + dataset.set_batch_size(BATCH_SIZE) + + if isinstance(place, fluid.CPUPlace): + file_num = 10 + os.environ['CPU_NUM'] = str(file_num) + places = fluid.cpu_places() + use_cuda = False + else: + file_num = fluid.core.get_cuda_device_count() + places = fluid.cuda_places() + use_cuda = True + + filelist = [] + if file_num > 1 and randomize_batch_num: + random_delta_batch_size = np.random.random_integers( + low=-BATCH_NUM / 2, high=BATCH_NUM / 2, size=[file_num]) + random_delta_batch_size[-1] = -int( + np.sum(random_delta_batch_size[0:-1])) + else: + random_delta_batch_size = np.zeros(shape=[file_num]) + + for i in six.moves.range(file_num): + filename = 'dataset_test_{}.txt'.format(i) + filelist.append(filename) + write_reader_data_to_file( + filename, + fake_reader(batch_num=BATCH_NUM + random_delta_batch_size[i])) + + dataset.set_filelist(filelist) + dataset.set_use_var(feeds) + dataset.set_pipe_command("cat") + if self.dataset_name == 'InMemoryDataset': + dataset.load_into_memory() + + dataloader = fluid.io.DataLoader.from_dataset( + dataset=dataset, places=places, drop_last=self.drop_last) + prog = fluid.CompiledProgram(main_prog).with_data_parallel() + exe = fluid.Executor(place) + + exe.run(startup_prog) + + for _ in six.moves.range(EPOCH_NUM): + 
has_complete_batch = False + for batch_id, data in enumerate(dataloader): + self.assertEquals(len(places), len(data)) + for idx, data_on_each_device in enumerate(data): + image = data_on_each_device["image"] + label = data_on_each_device["label"] + + if self.drop_last: + batch_size = BATCH_SIZE + else: + if batch_id == BATCH_NUM: + batch_size = BATCH_SIZE / 2 + else: + batch_size = BATCH_SIZE + + self.assertEquals(image.shape()[1:], IMAGE_SHAPE) + self.assertTrue( + image._place()._equals(places[idx]), + msg=get_place_string(image._place()) + ' vs ' + + get_place_string(places[idx])) + if self.drop_last: + self.assertEquals(image.shape()[0], BATCH_SIZE) + else: + self.assertTrue(image.shape()[0] == BATCH_SIZE or + image.shape()[0] == BATCH_SIZE / 2) + + self.assertEquals(label.shape()[1:], LABEL_SHAPE) + self.assertTrue(label._place()._equals(places[idx])) + if self.drop_last: + self.assertEquals(label.shape()[0], BATCH_SIZE) + else: + self.assertTrue(label.shape()[0] == BATCH_SIZE or + label.shape()[0] == BATCH_SIZE / 2) + + self.assertEquals(image.shape()[0], label.shape()[0]) + + if image.shape()[0] == BATCH_SIZE: + has_complete_batch = True + + exe.run(prog, feed=data) + + self.assertTrue(has_complete_batch) + + def get_all_places(self): + p = [fluid.CPUPlace()] + if fluid.is_compiled_with_cuda(): + p.append(fluid.CUDAPlace(0)) + return p + + def test_batch_number_with_same_length_files(self): + for p in self.get_all_places(): + with fluid.scope_guard(fluid.Scope()): + self.check_batch_number(place=p, randomize_batch_num=False) + + def test_batch_number_with_different_length_files(self): + for p in self.get_all_places(): + with fluid.scope_guard(fluid.Scope()): + self.check_batch_number(place=p, randomize_batch_num=True) + + +class QueueDatasetTestWithoutDropLast(DatasetLoaderTestBase): + def setUp(self): + self.dataset_name = "QueueDataset" + self.drop_last = True + + +class InMemoryDatasetTestWithoutDropLast(DatasetLoaderTestBase): + def setUp(self): + 
self.dataset_name = "InMemoryDataset" + self.drop_last = False + + +class InMemoryDatasetTestWithDropLast(DatasetLoaderTestBase): + def setUp(self): + self.dataset_name = "InMemoryDataset" + self.drop_last = True + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_decoupled_py_reader_data_check.py b/python/paddle/fluid/tests/unittests/test_decoupled_py_reader_data_check.py index 3a1b6837957..4d767709ef5 100644 --- a/python/paddle/fluid/tests/unittests/test_decoupled_py_reader_data_check.py +++ b/python/paddle/fluid/tests/unittests/test_decoupled_py_reader_data_check.py @@ -22,21 +22,23 @@ import six class TestClass(unittest.TestCase): def setUp(self): self.use_double_buffer = True + self.use_py_reader = True def test_reader_data(self): img_shape = [28, 31] label_shape = [1] batch_size = 32 + batch_num = 10 def fake_reader(): - for _ in six.moves.range(batch_size * 10): + for _ in six.moves.range(batch_size * batch_num): img = np.random.random(size=img_shape).astype('float32') label = np.random.random_integers( low=0, high=9, size=label_shape).astype('int64') yield img, label - reader = paddle.reader.cache(fake_reader) - batch_reader = paddle.batch(reader, batch_size=batch_size) + reader = fluid.io.cache(fake_reader) + batch_reader = fluid.io.batch(reader, batch_size=batch_size) places = [fluid.CPUPlace()] if fluid.core.is_compiled_with_cuda(): @@ -58,37 +60,67 @@ class TestClass(unittest.TestCase): ) and not use_double_buffer: use_double_buffer = True - py_reader = fluid.io.PyReader( - feed_list=[img, label], - capacity=4, - iterable=True, - use_double_buffer=use_double_buffer) - py_reader.decorate_sample_list_generator(batch_reader, places=p) + if self.use_py_reader: + py_reader = fluid.io.PyReader( + feed_list=[img, label], + capacity=4, + iterable=True, + use_double_buffer=use_double_buffer) + py_reader.decorate_sample_list_generator( + batch_reader, places=p) + else: + py_reader = 
fluid.io.DataLoader.from_generator( + feed_list=[img, label], + capacity=4, + iterable=True, + use_double_buffer=use_double_buffer + ).set_sample_list_generator( + batch_reader, places=p) + + for break_beforehand in [True, False]: + for epoch_id in six.moves.range(10): + gen = batch_reader() + batch_id = 0 + for d in py_reader(): + feed = feeder.feed(next(gen)) + I1, L1 = feed['image'], feed['label'] + I2, L2 = d[0]['image'], d[0]['label'] + + I1 = np.array(I1) + I2 = np.array(I2) + L1 = np.array(L1) + L2 = np.array(L2) + + self.assertTrue(np.array_equal(I1, I2)) + self.assertTrue(np.array_equal(L1, L2)) + + batch_id += 1 + if break_beforehand and batch_id >= int(batch_num / + 2): + break + + if break_beforehand: + self.assertTrue(next(gen, None) is not None) + else: + self.assertTrue(next(gen, None) is None) - for epoch_id in six.moves.range(10): - gen = batch_reader() - batch_id = 0 - for d in py_reader(): - feed = feeder.feed(next(gen)) - I1, L1 = feed['image'], feed['label'] - I2, L2 = d[0]['image'], d[0]['label'] - I1 = np.array(I1) - I2 = np.array(I2) - L1 = np.array(L1) - L2 = np.array(L2) - - self.assertTrue(np.array_equal(I1, I2)) - self.assertTrue(np.array_equal(L1, L2)) +class TestClass2(TestClass): + def setUp(self): + self.use_double_buffer = False + self.use_py_reader = True - batch_id += 1 - self.assertTrue(next(gen, None) is None) +class TestClass3(TestClass): + def setUp(self): + self.use_double_buffer = True + self.use_py_reader = False -class TestClass2(TestClass): +class TestClass4(TestClass): def setUp(self): self.use_double_buffer = False + self.use_py_reader = False if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_generator_dataloader.py b/python/paddle/fluid/tests/unittests/test_generator_dataloader.py new file mode 100644 index 00000000000..0945b59321a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_generator_dataloader.py @@ -0,0 +1,196 @@ +# Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import time +import six +import unittest +from paddle.fluid.reader import DataLoaderBase + +EPOCH_NUM = 20 +BATCH_SIZE = 32 +BATCH_NUM = 20 +CLASS_NUM = 10 + + +def random_reader(): + np.random.seed(1) + for i in range(BATCH_SIZE * BATCH_NUM): + image = np.random.random([784]) + label = np.random.random_integers(low=0, high=CLASS_NUM - 1) + yield image, label + + +def simple_fc_net(places, use_legacy_py_reader, use_double_buffer): + startup_prog = fluid.Program() + main_prog = fluid.Program() + startup_prog.random_seed = 1 + main_prog.random_seed = 1 + + with fluid.unique_name.guard(): + with fluid.program_guard(main_prog, startup_prog): + image = fluid.layers.data( + name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + py_reader = fluid.io.DataLoader.from_generator( + feed_list=[image, label], + capacity=4, + iterable=not use_legacy_py_reader, + use_double_buffer=use_double_buffer) + hidden = image + for hidden_size in [10, 20, 30]: + hidden = fluid.layers.fc( + hidden, + size=hidden_size, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + + predict_label = fluid.layers.fc(hidden, + size=CLASS_NUM, + act='softmax') + loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=predict_label, label=label)) + + optimizer = 
fluid.optimizer.Adam() + optimizer.minimize(loss) + return startup_prog, main_prog, py_reader, loss + + +class TestBase(unittest.TestCase): + def run_main(self, use_legacy_py_reader, with_data_parallel, places, + use_double_buffer): + scope = fluid.Scope() + with fluid.scope_guard(scope): + startup_prog, main_prog, py_reader, loss = simple_fc_net( + places, use_legacy_py_reader, use_double_buffer) + + reader = paddle.batch(random_reader, batch_size=BATCH_SIZE) + + ps = places if use_double_buffer else fluid.cpu_places(len(places)) + + py_reader.set_sample_list_generator( + reader, places=ps if py_reader.iterable else None) + + exe = fluid.Executor(place=places[0]) + exe.run(startup_prog) + + prog = fluid.CompiledProgram(main_prog) + if with_data_parallel: + prog = prog.with_data_parallel( + loss_name=loss.name, places=places) + + step = 0 + step_list = [] + loss_list = [] + start_t = time.time() + if not py_reader.iterable: + for _ in six.moves.range(EPOCH_NUM): + step = 0 + py_reader.start() + while True: + try: + L, = exe.run(program=prog, + fetch_list=[loss], + use_program_cache=True) + loss_list.append(np.mean(L)) + step += 1 + except fluid.core.EOFException: + py_reader.reset() + break + step_list.append(step) + else: + for _ in six.moves.range(EPOCH_NUM): + step = 0 + for d in py_reader(): + print(d) + assert len(d) == len(places), "{} != {}".format( + len(d), len(places)) + for i, item in enumerate(d): + image = item['image'] + label = item['label'] + assert image.shape() == [BATCH_SIZE, 784] + assert label.shape() == [BATCH_SIZE, 1] + assert image._place()._equals(ps[i]) + assert label._place()._equals(ps[i]) + L, = exe.run(program=prog, + feed=d, + fetch_list=[loss], + use_program_cache=True) + loss_list.append(np.mean(L)) + step += 1 + step_list.append(step) + end_t = time.time() + ret = { + "time": end_t - start_t, + "step": step_list, + "loss": np.array(loss_list) + } + return ret + + def prepare_places(self, with_data_parallel, with_cpu=True, 
with_gpu=True): + places = [] + if with_cpu: + places.append([fluid.CPUPlace()]) + if with_data_parallel: + places.append([fluid.CPUPlace()] * 2) + + if with_gpu and fluid.core.is_compiled_with_cuda(): + tmp = fluid.cuda_places() + assert len(tmp) > 0, "no gpu detected" + if with_data_parallel: + places.append(tmp) + places.append([tmp[0]]) + return places + + def test_main(self): + for with_data_parallel in [True, False]: + for p in self.prepare_places(with_data_parallel): + for use_double_buffer in [False, True]: + results = [] + for use_legacy_py_reader in [False, True]: + print(p, use_double_buffer, use_legacy_py_reader) + ret = self.run_main( + use_legacy_py_reader=use_legacy_py_reader, + with_data_parallel=with_data_parallel, + places=p, + use_double_buffer=use_double_buffer) + results.append(ret) + if not use_double_buffer: + diff = np.max( + np.abs(results[0]['loss'] - results[1]['loss'])) + self.assertLess(diff, 1e-3) + + +class TestDataLoaderBaseAbstract(unittest.TestCase): + def test_main(self): + loader = DataLoaderBase() + try: + loader.__iter__() + self.assertTrue(False) + except NotImplementedError: + self.assertTrue(True) + + try: + loader.__next__() + self.assertTrue(False) + except NotImplementedError: + self.assertTrue(True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pyreader.py b/python/paddle/fluid/tests/unittests/test_py_reader_return_list.py similarity index 87% rename from python/paddle/fluid/tests/unittests/test_pyreader.py rename to python/paddle/fluid/tests/unittests/test_py_reader_return_list.py index c65adf1595d..c6e18565078 100644 --- a/python/paddle/fluid/tests/unittests/test_pyreader.py +++ b/python/paddle/fluid/tests/unittests/test_py_reader_return_list.py @@ -55,19 +55,12 @@ class TestPyReader(unittest.TestCase): for _ in range(self.epoch_num): for data in reader(): if return_list: - executor.run(feed={"image": data[0]}) + executor.run(feed={"image": data[0][0]}) else: 
executor.run(feed=data) with fluid.dygraph.guard(): - batch_py_reader = fluid.io.PyReader( - feed_list=[ - np.empty( - [self.batch_size, 784, 784], dtype='float32') - ], - capacity=2, - use_double_buffer=True, - return_list=return_list) + batch_py_reader = fluid.io.PyReader(capacity=2) user_defined_reader = reader_creator_random_image(784, 784) batch_py_reader.decorate_sample_generator( user_defined_reader, -- GitLab