Unverified commit ebe3b5e7 authored by Yu Yang, committed by GitHub

Merge pull request #11853 from sneaxiy/complete_py_reader_python

Add Python Reader Op (Python side and unittests)
......@@ -29,11 +29,11 @@ enum ReaderStatus { kRunning, kStopped };
class ReaderBase {
 public:
-  void ReadNext(std::vector<LoDTensor>* out);
+  virtual void ReadNext(std::vector<LoDTensor>* out);

-  void Shutdown();
+  virtual void Shutdown();

-  void Start();
+  virtual void Start();

  // Return the readers which are the end of decorating chain. Basically
  // they are readers just before read op.
......@@ -42,7 +42,7 @@ class ReaderBase {
  virtual ~ReaderBase();

 protected:
-  virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
+  virtual void ReadNextImpl(std::vector<LoDTensor>* out) {}

  virtual void ShutdownImpl() {}
......
......@@ -81,6 +81,15 @@ class BlockingQueue {
    }
  }

  void ReOpen() {
    std::lock_guard<std::mutex> lock(mutex_);
    closed_ = false;
    std::deque<T> new_deque;
    queue_.swap(new_deque);
    send_cv_.notify_all();
    receive_cv_.notify_all();
  }

  void Close() {
    std::lock_guard<std::mutex> lock(mutex_);
    closed_ = true;
......
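For reference, the semantics this hunk gives the queue can be sketched in pure Python (illustrative only, not part of the patch): `Close()` wakes every blocked producer and consumer, while `ReOpen()` drops leftover items and makes the queue usable again, which is what a restartable reader needs between passes.

```python
import threading
from collections import deque


class BlockingQueueSketch(object):
    """Pure-Python sketch of the C++ BlockingQueue above (illustrative only)."""

    def __init__(self, capacity):
        self._capacity = capacity
        self._queue = deque()
        self._closed = False
        self._mutex = threading.Lock()
        self._send_cv = threading.Condition(self._mutex)
        self._receive_cv = threading.Condition(self._mutex)

    def push(self, item):
        with self._mutex:
            # Block while full, unless the queue gets closed meanwhile.
            while len(self._queue) >= self._capacity and not self._closed:
                self._send_cv.wait()
            if self._closed:
                return False  # mirrors a failed send on a closed queue
            self._queue.append(item)
            self._receive_cv.notify()
            return True

    def pop(self):
        with self._mutex:
            while not self._queue and not self._closed:
                self._receive_cv.wait()
            if not self._queue:
                return None, False  # closed and fully drained
            item = self._queue.popleft()
            self._send_cv.notify()
            return item, True

    def close(self):
        with self._mutex:
            self._closed = True
            self._send_cv.notify_all()
            self._receive_cv.notify_all()

    def reopen(self):
        with self._mutex:
            self._closed = False
            self._queue.clear()  # drop leftovers, like queue_.swap(new_deque)
            self._send_cv.notify_all()
            self._receive_cv.notify_all()
```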
......@@ -27,19 +27,17 @@ class PyReader : public framework::FileReader {
    queue_ = queue;
  }

-  void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
+  void ReadNext(std::vector<framework::LoDTensor>* out) override {
    bool success;
    *out = queue_->Pop(&success);
    if (!success) out->clear();
  }

- private:
-  void ShutdownImpl() override { /* TODO */
-  }
+  void Shutdown() override { queue_->Close(); }

-  void StartImpl() override { /* TODO */
-  }
+  void Start() override { queue_->ReOpen(); }

+ private:
  std::shared_ptr<LoDTensorBlockingQueue> queue_;
};
......
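The PyReader change above routes the reader lifecycle straight onto the queue: `Start()` reopens it, `Shutdown()` closes it, and an unsuccessful pop yields an empty batch. A rough Python analogue, reusing the `BlockingQueueSketch` from the previous hunk (illustrative only, not the actual op):

```python
class PyReaderSketch(object):
    """Rough Python analogue of the C++ PyReader above (illustrative only)."""

    def __init__(self, queue):
        self._queue = queue

    def read_next(self):
        # Blocks until data arrives or the queue is closed; an
        # unsuccessful pop yields an empty batch, like out->clear().
        item, success = self._queue.pop()
        return item if success else []

    def shutdown(self):
        self._queue.close()

    def start(self):
        self._queue.reopen()
```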
......@@ -58,12 +58,15 @@ class LoDTensorBlockingQueue {
  inline size_t Size() const { return queue_.Size(); }

-  inline void Close() { return queue_.Close(); }
+  inline void ReOpen() { queue_.ReOpen(); }
+  inline void Close() { queue_.Close(); }

  inline bool IsClosed() const { return queue_.IsClosed(); }

 private:
-  void CheckDims(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
+  void CheckDims(
+      const std::vector<framework::LoDTensor>& lod_tensor_vec) const {
    PADDLE_ENFORCE(dims_.size() == lod_tensor_vec.size(),
                   "Expect input size is %d but found %s", dims_.size(),
                   lod_tensor_vec.size());
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <Python.h>
#include <algorithm>
#include <map>
#include <memory>
#include <mutex> // NOLINT // for call_once
#include <string>
#include <unordered_map>
......@@ -310,7 +311,8 @@ All parameter, weight, gradient are variables in Paddle.
      ::paddle::operators::reader::LoDTensorBlockingQueue;
  using LoDTensorBlockingQueueHolder =
      ::paddle::operators::reader::LoDTensorBlockingQueueHolder;

-  py::class_<LoDTensorBlockingQueue>(m, "LoDTensorBlockingQueue", "")
+  py::class_<LoDTensorBlockingQueue, std::shared_ptr<LoDTensorBlockingQueue>>(
+      m, "LoDTensorBlockingQueue", "")
      .def("push",
           [](LoDTensorBlockingQueue &self,
              const std::vector<framework::LoDTensor> &lod_tensor_vec) {
......@@ -325,7 +327,7 @@ All parameter, weight, gradient are variables in Paddle.
  m.def("init_lod_tensor_blocking_queue",
        [](Variable &var, size_t capacity,
           const std::vector<std::vector<int64_t>> &shapes)
-            -> LoDTensorBlockingQueue * {
+            -> std::shared_ptr<LoDTensorBlockingQueue> {
          std::vector<DDim> dims(shapes.size());
          std::transform(shapes.begin(), shapes.end(), dims.begin(),
                         [](const std::vector<int64_t> &shape) {
......@@ -333,9 +335,9 @@ All parameter, weight, gradient are variables in Paddle.
                         });
          auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
          holder->InitOnce(capacity, dims);
-          return holder->GetQueue().get();
+          return holder->GetQueue();
        },
-        py::return_value_policy::reference);
+        py::return_value_policy::copy);
  py::class_<Scope>(m, "Scope", "")
      .def("var",
......@@ -543,6 +545,8 @@ All parameter, weight, gradient are variables in Paddle.
      });

  py::class_<LoDTensorArray>(m, "LoDTensorArray")
      .def("__init__",
           [](LoDTensorArray &instance) { new (&instance) LoDTensorArray(); })
      .def("__getitem__",
           [](LoDTensorArray &self, size_t i) { return &self.at(i); },
           py::return_value_policy::reference)
......
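Binding the queue with a `std::shared_ptr` holder and returning the `shared_ptr` itself (with `return_value_policy::copy`) lets Python share ownership, so the queue stays alive for as long as either side still references it, instead of handing Python a raw pointer into the holder. A hedged sketch of the resulting Python-side usage (the variable name and data below are made up; assumes the usual fluid helpers):

```python
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core

var = fluid.global_scope().var('feed_queue')  # hypothetical variable name
queue = core.init_lod_tensor_blocking_queue(var, 10, [[-1, 3], [-1, 1]])

# push() takes one LoDTensor per declared shape; only dims are checked here.
batch = core.LoDTensorArray()
for shape in [(4, 3), (4, 1)]:
    t = core.LoDTensor()
    t.set(np.random.random(shape).astype('float32'), fluid.CPUPlace())
    batch.append(t)
queue.push(batch)

queue.close()  # unblocks any pop() pending on the C++ side
```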
......@@ -44,7 +44,7 @@ import metrics
import transpiler
from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder
-from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
+from core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
from transpiler import DistributeTranspiler, InferenceTranspiler, \
memory_optimize, release_memory
from concurrency import (Go, make_channel, channel_send, channel_recv,
......@@ -72,6 +72,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
    'backward',
    'regularizer',
    'LoDTensor',
    'LoDTensorArray',
    'CPUPlace',
    'CUDAPlace',
    'CUDAPinnedPlace',
......
......@@ -24,7 +24,8 @@ from layer_function_generator import generate_layer_fn, templatedoc
__all__ = [
    'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
    'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
-    'double_buffer', 'random_data_generator', 'Preprocessor', 'load'
+    'double_buffer', 'random_data_generator', 'py_reader', 'Preprocessor',
+    'load'
]
......@@ -445,6 +446,88 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
    return monkey_patch_reader_methods(main_prog_var)


def py_reader(capacity, shapes, dtypes, lod_levels=None):
    """
    Create a reader and blocking queue for data feeding in Python.

    This layer returns a Reader Variable and a BlockingQueue. The
    BlockingQueue provides a `push()` method to push a `LoDTensorArray`
    object into the queue from the Python side. On the C++ side, the
    Reader Variable invokes the queue's `pop()` method to retrieve the
    fed data. Feeding data on the Python side and fetching data on the
    C++ side can therefore run in parallel. The BlockingQueue should be
    closed with its `close()` method when it is no longer needed.

    Args:
        capacity(int): The maximum capacity of the BlockingQueue.
        shapes(list): List of tuples which declare the data shapes.
        dtypes(list): List of strs which declare the data types.
        lod_levels(list): List of ints which declare the data lod_levels.

    Returns:
        tuple(Variable, BlockingQueue):
        A Reader Variable from which we can get the fed data.
        A BlockingQueue object for data feeding.

    Examples:

        .. code-block:: python

            reader, queue = fluid.layers.py_reader(
                capacity=10,
                shapes=[[-1, 3, 224, 224], [-1, 1]],
                dtypes=['float32', 'int64'])

            # Via the reader, we can use the 'read_file' layer to get data:
            image, label = fluid.layers.read_file(reader)

            # Via the blocking queue, we can feed data using threads:
            def feed_data(queue, feed_images, feed_labels):
                for feed_image, feed_label in zip(feed_images, feed_labels):
                    data = core.LoDTensorArray()
                    data.append(feed_image)
                    data.append(feed_label)
                    queue.push(data)

            thread = threading.Thread(
                target=feed_data, args=(queue, feed_images, feed_labels))
            thread.start()
    """
    dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
    shape_concat = []
    ranks = []

    for shape in shapes:
        shape_concat.extend(shape)
        ranks.append(len(shape))

    if lod_levels is None:
        lod_levels = [0] * len(shapes)

    queue_name = unique_name('lod_tensor_blocking_queue')
    var = global_scope().var(queue_name)
    feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)

    startup_blk = default_startup_program().current_block()
    startup_var = startup_blk.create_var(name=unique_name('create_py_reader'))
    startup_blk.append_op(
        type='create_py_reader',
        inputs={'blocking_queue': queue_name},
        outputs={'Out': [startup_var]},
        attrs={
            'shape_concat': shape_concat,
            'lod_levels': lod_levels,
            'ranks': ranks
        })
    startup_var.desc.set_dtypes(dtypes)
    startup_var.persistable = True

    main_prog_var = _copy_reader_var_(default_main_program().current_block(),
                                      startup_var)

    return monkey_patch_reader_methods(main_prog_var), feed_queue
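For illustration, a minimal end-to-end sketch of driving `py_reader` from a feeding thread (not part of this patch; the `feed` helper and the random `batches` below are hypothetical). The feeding thread exits either after all batches are pushed or once `push()` fails because the queue was closed:

```python
import threading

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core

reader, queue = fluid.layers.py_reader(
    capacity=8,
    shapes=[[-1, 32], [-1, 1]],
    dtypes=['float32', 'int64'])
image, label = fluid.layers.read_file(reader)
# ... build a network on top of image/label here ...

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())


def feed(queue, batches):
    # Wrap each numpy pair into a LoDTensorArray and push it.
    for feat, lbl in batches:
        arr = core.LoDTensorArray()
        for data in (feat, lbl):
            t = core.LoDTensor()
            t.set(data, place)
            arr.append(t)
        if not queue.push(arr):  # push fails once the queue is closed
            break


batches = [(np.random.rand(4, 32).astype('float32'),
            np.random.randint(10, size=(4, 1)).astype('int64'))
           for _ in range(5)]
thread = threading.Thread(target=feed, args=(queue, batches))
thread.start()

for _ in range(len(batches)):
    exe.run(fetch_list=[image, label])

queue.close()
thread.join()
```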


def open_files(filenames,
               shapes,
               lod_levels,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import numpy as np
from threading import Thread


def feed_data(feed_queue, inputs):
    for in_data in inputs:
        feed_queue.push(in_data)


class TestPyReader(unittest.TestCase):
    def setUp(self):
        self.capacity = 10
        self.batch_size_min = 10
        self.batch_size_max = 20
        self.shapes = [(-1, 3, 2, 1), (-1, 1)]
        self.lod_levels = [0, 0]
        self.dtypes = ['float32', 'int64']
        self.iterations = 20

    def test_single_thread_main(self):
        self.main(use_thread=False)

    def test_multiple_thread_main(self):
        self.main(use_thread=True)

    def main(self, use_thread=False):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
            ) else fluid.CPUPlace()
            executor = fluid.Executor(place)

            data_file, feed_queue = fluid.layers.py_reader(
                capacity=self.capacity,
                dtypes=self.dtypes,
                lod_levels=self.lod_levels,
                shapes=self.shapes)
            read_out_data = fluid.layers.read_file(data_file)

            self.inputs = []
            for i in range(self.iterations):
                in_data = fluid.LoDTensorArray()
                batch_size = np.random.random_integers(self.batch_size_min,
                                                       self.batch_size_max)
                for shape, dtype in zip(self.shapes, self.dtypes):
                    next_data = np.random.uniform(
                        low=0, high=1000,
                        size=(batch_size, ) + shape[1:]).astype(dtype)
                    in_data.append(executor.as_lodtensor(next_data))
                self.inputs.append(in_data)

            executor.run(fluid.default_startup_program())
            self.outputs = []
            if use_thread:
                thread = Thread(
                    target=feed_data, args=(feed_queue, self.inputs))
                thread.start()
                for in_data in self.inputs:
                    self.outputs.append(
                        executor.run(fetch_list=list(read_out_data)))
            else:
                for in_data in self.inputs:
                    feed_queue.push(in_data)
                    self.outputs.append(
                        executor.run(fetch_list=list(read_out_data)))

            feed_queue.close()
            self.validate()

    def validate(self):
        self.assertEqual(len(self.inputs), len(self.outputs))
        for in_data_list, out_data_list in zip(self.inputs, self.outputs):
            self.assertEqual(len(in_data_list), len(out_data_list))
            in_data_list_np = [
                np.array(in_lod_tensor) for in_lod_tensor in in_data_list
            ]
            for in_data, out_data in zip(in_data_list_np, out_data_list):
                self.assertTrue((in_data == out_data).all())


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import threading
import multiprocessing
import os


def as_tensor(np_array_or_tensor, place=None):
    if isinstance(np_array_or_tensor, fluid.LoDTensor):
        return np_array_or_tensor

    if place is None:
        place = fluid.CPUPlace()

    tensor = fluid.LoDTensor()
    tensor.set(np_array_or_tensor, place)
    return tensor


def as_numpy(tensor_or_numpy):
    return tensor_or_numpy if isinstance(
        tensor_or_numpy, np.ndarray) else np.array(tensor_or_numpy)


def feed_data(feed_queue, reader):
    data_generator = reader()
    while True:
        data = next(data_generator, None)
        if data is None or not feed_queue.push(data):
            break


def simple_fc_net(in_size,
                  class_num,
                  hidden_sizes,
                  batch_size,
                  queue_capacity,
                  use_double_buffer=False):
    reader, feed_queue = fluid.layers.py_reader(
        capacity=queue_capacity,
        shapes=[[-1, in_size], [-1, 1]],
        lod_levels=[0, 0],
        dtypes=['float32', 'int64'])
    reader = fluid.layers.batch(reader, batch_size=batch_size)
    if use_double_buffer:
        reader = fluid.layers.double_buffer(reader)
    in_data, label = fluid.layers.read_file(reader)

    hidden = in_data
    for hidden_size in hidden_sizes:
        hidden = fluid.layers.fc(
            hidden,
            size=hidden_size,
            act='tanh',
            bias_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=1.0)))

    predict_label = fluid.layers.fc(hidden, size=class_num, act='softmax')
    loss = fluid.layers.mean(
        fluid.layers.cross_entropy(
            input=predict_label, label=label))

    optimizer = fluid.optimizer.Adam()
    optimizer.minimize(loss)
    return in_data, label, loss, optimizer, feed_queue


class TestPyReaderUsingExecutor(unittest.TestCase):
    def setUp(self):
        self.in_size = 1000
        self.hidden_sizes = [50, 30, 20]
        self.class_num = 10
        self.batch_size = 32
        self.iterations = 10
        self.queue_capacity = 50

    def test(self):
        for use_cuda in [False, True]:
            for use_parallel_executor in [False, True]:
                for use_double_buffer in [False, True]:
                    print('Test Parameters:'),
                    print({
                        'use_cuda': use_cuda,
                        'use_parallel_executor': use_parallel_executor,
                        'use_double_buffer': use_double_buffer
                    })
                    self.main(use_cuda, use_parallel_executor,
                              use_double_buffer)

    def random_reader(self):
        def reader():
            self.inputs = []
            cnt = 0
            while True:
                tensors = fluid.LoDTensorArray()
                in_data = np.random.uniform(
                    low=0, high=1, size=(1, self.in_size)).astype('float32')
                tensors.append(as_tensor(in_data))
                label = np.random.random_integers(
                    low=0, high=self.class_num - 1,
                    size=(1, 1)).astype('int64')
                tensors.append(as_tensor(label))

                if cnt < self.iterations * self.batch_size * self.batch_size_times:
                    if cnt % (self.batch_size * self.batch_size_times) == 0:
                        self.inputs.append([in_data, label])
                    else:
                        self.inputs[-1][0] = np.concatenate(
                            (self.inputs[-1][0], in_data), axis=0)
                        self.inputs[-1][1] = np.concatenate(
                            (self.inputs[-1][1], label), axis=0)
                elif not self.use_double_buffer:
                    break

                yield tensors
                cnt += 1

            yield None

        return reader

    def main(self,
             use_cuda=True,
             use_parallel_executor=False,
             use_double_buffer=False):
        assert not use_cuda or use_cuda and core.is_compiled_with_cuda()

        self.use_cuda = use_cuda
        self.use_parallel_executor = use_parallel_executor
        self.use_double_buffer = use_double_buffer

        startup_program = fluid.Program()
        main_program = fluid.Program()

        with fluid.program_guard(main_program, startup_program):
            in_data, label, loss, optimizer, feed_queue = simple_fc_net(
                in_size=self.in_size,
                class_num=self.class_num,
                hidden_sizes=self.hidden_sizes,
                batch_size=self.batch_size,
                queue_capacity=self.queue_capacity,
                use_double_buffer=self.use_double_buffer)

            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

            startup_exe = fluid.Executor(place)
            startup_exe.run(startup_program)

            if use_parallel_executor:
                main_exe = fluid.ParallelExecutor(use_cuda, loss_name=loss.name)
                if use_cuda:
                    self.batch_size_times = core.get_cuda_device_count()
                else:
                    self.batch_size_times = int(
                        os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
            else:
                main_exe = startup_exe
                self.batch_size_times = 1

            reader = self.random_reader()
            thread = threading.Thread(
                target=feed_data, args=(feed_queue, reader))
            thread.start()

            self.outputs = []
            for _ in range(self.iterations):
                fetches = main_exe.run(fetch_list=[in_data.name, label.name])
                fetches = [as_numpy(fetch) for fetch in fetches]
                self.outputs.append(fetches)

            feed_queue.close()
            self.validate()

    def validate(self):
        self.assertEqual(len(self.inputs), len(self.outputs))
        for batch_in, batch_out in zip(self.inputs, self.outputs):
            self.assertEqual(len(batch_in), len(batch_out))
            if self.use_parallel_executor and not self.use_double_buffer:
                self.validate_unordered_batch(batch_in, batch_out)
            else:
                for in_data, out_data in zip(batch_in, batch_out):
                    self.assertEqual(in_data.shape, out_data.shape)
                    if not self.use_parallel_executor:
                        self.assertTrue((in_data == out_data).all())

    def validate_unordered_batch(self, batch_in, batch_out):
        out_index_left_set = set(
            range(self.batch_size * self.batch_size_times))
        mapping_num = 0
        for i in range(self.batch_size * self.batch_size_times):
            for j in out_index_left_set:
                flag = True
                for k in range(len(batch_in)):
                    in_data = batch_in[k][i]
                    out_data = batch_out[k][j]
                    if (in_data != out_data).any():
                        flag = False
                        break

                if flag:
                    out_index_left_set.remove(j)
                    mapping_num += 1
                    break

        self.assertEqual(mapping_num, self.batch_size * self.batch_size_times)


if __name__ == '__main__':
    unittest.main()
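A note on `validate_unordered_batch` above: when ParallelExecutor runs without double buffering, samples may be scattered across devices and row order is not preserved, so the test greedily matches each input row to some not-yet-claimed output row. A standalone toy version of that matching (illustrative only; `rows_match_unordered` is a made-up helper):

```python
import numpy as np


def rows_match_unordered(rows_in, rows_out):
    """Return True if rows_out contains the same rows as rows_in, in any order."""
    unclaimed = list(range(len(rows_out)))
    for row in rows_in:
        for idx in unclaimed:
            if (row == rows_out[idx]).all():
                unclaimed.remove(idx)  # each output row may be claimed once
                break
        else:
            return False  # no unclaimed output row matched this input row
    return True


a = np.arange(6).reshape(3, 2)
assert rows_match_unordered(a, a[::-1])    # same rows, reversed order
assert not rows_match_unordered(a, a + 1)  # different content
```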