From 1b59220d50576661d6afb4bd7f6ea799a460bd6f Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 27 Jun 2018 05:42:19 +0000 Subject: [PATCH] complete python reader op python side --- paddle/fluid/pybind/pybind.cc | 2 + python/paddle/fluid/__init__.py | 3 +- python/paddle/fluid/layers/io.py | 89 ++++++- .../unittests/test_py_reader_push_pop.py | 99 ++++++++ .../test_py_reader_using_executor.py | 224 ++++++++++++++++++ 5 files changed, 413 insertions(+), 4 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py create mode 100644 python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c88fbef63cf..069a59c7b49 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -449,6 +449,8 @@ All parameter, weight, gradient are variables in Paddle. }); py::class_(m, "LoDTensorArray") + .def("__init__", + [](LoDTensorArray &instance) { new (&instance) LoDTensorArray(); }) .def("__getitem__", [](LoDTensorArray &self, size_t i) { return &self.at(i); }, py::return_value_policy::reference) diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index bd985ad733a..e1944ebc7b6 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -44,7 +44,7 @@ import metrics import transpiler from param_attr import ParamAttr, WeightNormParamAttr from data_feeder import DataFeeder -from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace +from core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope from transpiler import DistributeTranspiler, InferenceTranspiler, \ memory_optimize, release_memory from concurrency import (Go, make_channel, channel_send, channel_recv, @@ -72,6 +72,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \ 'backward', 'regularizer', 'LoDTensor', + 'LoDTensorArray', 'CPUPlace', 'CUDAPlace', 'CUDAPinnedPlace', diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 9de88e2c320..d883c9e4ade 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -22,9 +22,10 @@ from ..executor import global_scope from layer_function_generator import generate_layer_fn, templatedoc __all__ = [ - 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer', - 'random_data_generator', 'Preprocessor', 'load' + 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv', + 'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch', + 'double_buffer', 'random_data_generator', 'py_reader', 'Preprocessor', + 'load' ] @@ -431,6 +432,88 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True): return monkey_patch_reader_methods(main_prog_var) +def py_reader(capacity, shapes, lod_levels, dtypes, for_parallel=True): + """ + Create a reader and blocking queue for data feeding in Python + + This layer returns a Reader Variable and a BlockingQueue. + The BlockingQueue provides `push()` method to push a + `LoDTensorArray` object into the queue in Python side. In C++ + side, the Reader Variable would invoke `pop()` method of the + queue to retrieve the feeding data. The process of feeding data + in Python side and fetching data in C++ side can run in parallel. + The BlockingQueue should be closed using `close()` method when + unused. + + Args: + capacity(int): The maximum capacity of the BlockingQueue. + shapes(list): List of tuples which declaring data shapes. + lod_levels(list): List of ints which declaring data lod_level. + dtypes(list): List of strs which declaring data type. + for_parallel(Bool): Set it as True if you are going to run + subsequent operators in parallel. + + Returns: + Variable: A Reader Variable from which we can get feeding data. + BlockingQueue: A blocking queue for data feeding. + + Examples: + + .. code-block:: python + + reader, queue = fluid.layers.py_reader( + capacity=10, + shapes=[[-1,3,224,224], [-1,1]], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + # Via the reader, we can use 'read_file' layer to get data: + image, label = fluid.layers.read_file(reader) + # Via the blocking queue, we can feed data using threads + def feed_data(queue, feed_images, feed_labels): + for feed_image, feed_label in zip(feed_images, feed_labels): + data = core.LoDTensorArray() + data.append(feed_image) + data.append(feed_label) + queue.push(data) + + thread = threading.Thread(target=feed_data, args=(queue, feed_images, feed_labels)) + """ + dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] + shape_concat = [] + ranks = [] + + for shape in shapes: + shape_concat.extend(shape) + ranks.append(len(shape)) + + queue_name = unique_name('lod_tensor_blocking_queue') + var = global_scope().var(queue_name) + feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes) + + startup_blk = default_startup_program().current_block() + startup_var = startup_blk.create_var(name=unique_name('create_py_reader')) + startup_blk.append_op( + type='create_py_reader', + inputs={'blocking_queue': queue_name}, + outputs={'Out': [startup_var]}, + attrs={ + 'shape_concat': shape_concat, + 'lod_levels': lod_levels, + 'ranks': ranks + }) + + startup_var.desc.set_dtypes(dtypes) + startup_var.persistable = True + + main_prog_var = _copy_reader_var_(default_main_program().current_block(), + startup_var) + + if for_parallel: + main_prog_var = parallel(reader=main_prog_var) + + return monkey_patch_reader_methods(main_prog_var), feed_queue + + def open_files(filenames, shapes, lod_levels, diff --git a/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py b/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py new file mode 100644 index 00000000000..05715464848 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py @@ -0,0 +1,99 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid as fluid +import numpy as np +from threading import Thread + + +def feed_data(feed_queue, inputs): + for in_data in inputs: + feed_queue.push(in_data) + + +class TestPyReader(unittest.TestCase): + def setUp(self): + self.capacity = 10 + self.batch_size_min = 10 + self.batch_size_max = 20 + self.shapes = [(-1, 3, 2, 1), (-1, 1)] + self.lod_levels = [0, 0] + self.dtypes = ['float32', 'int64'] + self.iterations = 20 + + def test_single_thread_main(self): + self.main(use_thread=False) + + def test_multiple_thread_main(self): + self.main(use_thread=True) + + def main(self, use_thread=False): + with fluid.program_guard(fluid.Program(), fluid.Program()): + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + executor = fluid.Executor(place) + + data_file, feed_queue = fluid.layers.py_reader( + capacity=self.capacity, + dtypes=self.dtypes, + lod_levels=self.lod_levels, + shapes=self.shapes) + + read_out_data = fluid.layers.read_file(data_file) + self.inputs = [] + + for i in range(self.iterations): + in_data = fluid.LoDTensorArray() + batch_size = np.random.random_integers(self.batch_size_min, + self.batch_size_max) + for shape, dtype in zip(self.shapes, self.dtypes): + next_data = np.random.uniform( + low=0, high=1000, + size=(batch_size, ) + shape[1:]).astype(dtype) + in_data.append(executor.as_lodtensor(next_data)) + + self.inputs.append(in_data) + + executor.run(fluid.default_startup_program()) + self.outputs = [] + if use_thread: + thread = Thread( + target=feed_data, args=(feed_queue, self.inputs)) + thread.start() + for in_data in self.inputs: + self.outputs.append( + executor.run(fetch_list=list(read_out_data))) + else: + for in_data in self.inputs: + feed_queue.push(in_data) + self.outputs.append( + executor.run(fetch_list=list(read_out_data))) + + feed_queue.close() + self.validate() + + def validate(self): + self.assertEqual(len(self.inputs), len(self.outputs)) + for in_data_list, out_data_list in zip(self.inputs, self.outputs): + self.assertEqual(len(in_data_list), len(out_data_list)) + in_data_list_np = [ + np.array(in_lod_tensor) for in_lod_tensor in in_data_list + ] + for in_data, out_data in zip(in_data_list_np, out_data_list): + self.assertTrue((in_data == out_data).all()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py b/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py new file mode 100644 index 00000000000..98d48097ee1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py @@ -0,0 +1,224 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np +import threading +import multiprocessing +import os + + +def as_tensor(np_array_or_tensor, place=None): + if isinstance(np_array_or_tensor, fluid.LoDTensor): + return np_array_or_tensor + + if place is None: + place = fluid.CPUPlace() + + tensor = fluid.LoDTensor() + tensor.set(np_array_or_tensor, place) + return tensor + + +def as_numpy(tensor_or_numpy): + return tensor_or_numpy if isinstance( + tensor_or_numpy, np.ndarray) else np.array(tensor_or_numpy) + + +def feed_data(feed_queue, reader): + data_generator = reader() + while True: + data = next(data_generator, None) + if data is None or not feed_queue.push(data): + break + + +def simple_fc_net(in_size, + class_num, + hidden_sizes, + batch_size, + queue_capacity, + use_double_buffer=False): + reader, feed_queue = fluid.layers.py_reader( + capacity=queue_capacity, + shapes=[[-1, in_size], [-1, 1]], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + reader = fluid.layers.batch(reader, batch_size=batch_size) + if use_double_buffer: + reader = fluid.layers.double_buffer(reader) + + in_data, label = fluid.layers.read_file(reader) + + hidden = in_data + for hidden_size in hidden_sizes: + hidden = fluid.layers.fc( + hidden, + size=hidden_size, + act='tanh', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + + predict_label = fluid.layers.fc(hidden, size=class_num, act='softmax') + loss = fluid.layers.mean( + fluid.layers.cross_entropy( + input=predict_label, label=label)) + + optimizer = fluid.optimizer.Adam() + optimizer.minimize(loss) + return in_data, label, loss, optimizer, feed_queue + + +class TestPyReaderUsingExecutor(unittest.TestCase): + def setUp(self): + self.in_size = 1000 + self.hidden_sizes = [50, 30, 20] + self.class_num = 10 + self.batch_size = 32 + self.iterations = 10 + self.queue_capacity = 50 + + def test(self): + for use_cuda in [False, True]: + for use_parallel_executor in [False, True]: + for use_double_buffer in [False, True]: + print('Test Parameters:'), + print({ + 'use_cuda': use_cuda, + 'use_parallel_executor': use_parallel_executor, + 'use_double_buffer': use_double_buffer + }) + self.main(use_cuda, use_parallel_executor, + use_double_buffer) + + def random_reader(self): + def reader(): + self.inputs = [] + cnt = 0 + while True: + tensors = fluid.LoDTensorArray() + in_data = np.random.uniform( + low=0, high=1, size=(1, self.in_size)).astype('float32') + tensors.append(as_tensor(in_data)) + label = np.random.random_integers( + low=0, high=self.class_num - 1, size=(1, 1)).astype('int64') + tensors.append(as_tensor(label)) + + if cnt < self.iterations * self.batch_size * self.batch_size_times: + if cnt % (self.batch_size * self.batch_size_times) == 0: + self.inputs.append([in_data, label]) + else: + self.inputs[-1][0] = np.concatenate( + (self.inputs[-1][0], in_data), axis=0) + self.inputs[-1][1] = np.concatenate( + (self.inputs[-1][1], label), axis=0) + elif not self.use_double_buffer: + break + + yield tensors + cnt += 1 + + yield None + + return reader + + def main(self, + use_cuda=True, + use_parallel_executor=False, + use_double_buffer=False): + assert not use_cuda or use_cuda and core.is_compiled_with_cuda() + + self.use_cuda = use_cuda + self.use_parallel_executor = use_parallel_executor + self.use_double_buffer = use_double_buffer + + startup_program = fluid.Program() + main_program = fluid.Program() + + with fluid.program_guard(main_program, startup_program): + in_data, label, loss, optimizer, feed_queue = simple_fc_net( + in_size=self.in_size, + class_num=self.class_num, + hidden_sizes=self.hidden_sizes, + batch_size=self.batch_size, + queue_capacity=self.queue_capacity, + use_double_buffer=self.use_double_buffer) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + + startup_exe = fluid.Executor(place) + startup_exe.run(startup_program) + + if use_parallel_executor: + main_exe = fluid.ParallelExecutor(use_cuda, loss_name=loss.name) + if use_cuda: + self.batch_size_times = core.get_cuda_device_count() + else: + self.batch_size_times = int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + else: + main_exe = startup_exe + self.batch_size_times = 1 + + reader = self.random_reader() + thread = threading.Thread( + target=feed_data, args=(feed_queue, reader)) + thread.start() + + self.outputs = [] + for _ in range(self.iterations): + fetches = main_exe.run(fetch_list=[in_data.name, label.name]) + fetches = [as_numpy(fetch) for fetch in fetches] + self.outputs.append(fetches) + + self.validate() + feed_queue.close() + + def validate(self): + self.assertEqual(len(self.inputs), len(self.outputs)) + for batch_in, batch_out in zip(self.inputs, self.outputs): + self.assertEqual(len(batch_in), len(batch_out)) + if self.use_parallel_executor and not self.use_double_buffer: + self.validate_unordered_batch(batch_in, batch_out) + else: + for in_data, out_data in zip(batch_in, batch_out): + self.assertEqual(in_data.shape, out_data.shape) + if not self.use_parallel_executor: + self.assertTrue((in_data == out_data).all()) + + def validate_unordered_batch(self, batch_in, batch_out): + out_index_left_set = set(range(self.batch_size * self.batch_size_times)) + mapping_num = 0 + for i in range(self.batch_size * self.batch_size_times): + for j in out_index_left_set: + flag = True + for k in range(len(batch_in)): + in_data = batch_in[k][i] + out_data = batch_out[k][j] + if (in_data != out_data).any(): + flag = False + break + + if flag: + out_index_left_set.remove(j) + mapping_num += 1 + break + + self.assertEqual(mapping_num, self.batch_size * self.batch_size_times) + + +if __name__ == '__main__': + unittest.main() -- GitLab