未验证 提交 c5023bea 编写于 作者: Y yuyang18

Merge branch 'feature/rewrite_open_files' into feature/combine_open_files_and_double_buffer

......@@ -12,8 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include <stdexcept>
#include <thread> // NOLINT
#include "ThreadPool.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
......@@ -21,141 +24,187 @@ namespace paddle {
namespace operators {
namespace reader {
class MultiFileReader : public framework::ReaderBase {
class IReaderContainer {
public:
virtual ~IReaderContainer() {}
virtual void AppendReader(
std::unique_ptr<framework::ReaderBase>&& readers) = 0;
virtual void Stop() = 0;
virtual void Start() = 0;
virtual void ReadNext(std::vector<framework::LoDTensor>* out) = 0;
};
class OrderedReaderContainer : public IReaderContainer {
public:
MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num,
size_t buffer_size)
: buffer_size_(buffer_size) {
readers_.reserve(file_names.size());
for (const std::string& f_name : file_names) {
readers_.emplace_back(CreateReaderByFileName(f_name));
void AppendReader(std::unique_ptr<framework::ReaderBase>&& reader) override {
pending_.emplace(std::move(reader));
}
void Stop() override {
while (!pending_.empty()) {
MoveFrontPendingToDone();
}
prefetchers_.resize(thread_num);
StartNewScheduler();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
void Start() override { std::swap(done_, pending_); }
~MultiFileReader() { EndScheduler(); }
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!pending_.empty()) {
pending_.front()->ReadNext(out);
if (out->empty()) {
MoveFrontPendingToDone();
ReadNext(out);
}
} else {
out->clear();
}
}
private:
void ShutdownImpl() override { EndScheduler(); }
void StartImpl() override { StartNewScheduler(); }
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
size_t buffer_size_;
reader::BlockingQueue<size_t>* waiting_reader_idx_;
reader::BlockingQueue<size_t>* available_thread_idx_;
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
void MoveFrontPendingToDone() {
pending_.front()->Shutdown();
pending_.front()->Start();
done_.emplace(move(pending_.front()));
pending_.pop();
}
std::queue<std::unique_ptr<framework::ReaderBase>> pending_;
std::queue<std::unique_ptr<framework::ReaderBase>> done_;
};
void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
if (!buffer_->Receive(out)) {
out->clear();
class PreemptiveReaderContainer : public IReaderContainer {
using ReaderList = std::list<std::unique_ptr<framework::ReaderBase>>;
struct FutureItem {
std::vector<framework::LoDTensor> data_;
ReaderList::iterator reader_it_;
std::exception_ptr exception_;
};
using FutureList = std::list<std::future<FutureItem>>;
public:
explicit PreemptiveReaderContainer(size_t thread_num) : pool_(thread_num) {}
void Stop() override {
if (!pending_.empty()) {
for (auto& reader : pending_) {
reader->Shutdown();
}
}
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
buffer_size_);
for (size_t i = 0; i < readers_.size(); ++i) {
waiting_reader_idx_->Send(i);
}
waiting_reader_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(i);
}
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultiFileReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_reader_idx_->Close();
if (scheduler_.joinable()) {
scheduler_.join();
}
delete buffer_;
delete available_thread_idx_;
delete waiting_reader_idx_;
}
void MultiFileReader::ScheduleThreadFunc() {
VLOG(5) << "MultiFileReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t reader_idx;
if (waiting_reader_idx_->Receive(&reader_idx)) {
// Still have files to read. Start a new prefetch thread.
prefetcher = std::thread([this, reader_idx, thread_idx] {
PrefetchThreadFunc(reader_idx, thread_idx);
});
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
for (auto& fu : futures_) {
fu.wait();
}
futures_.clear();
for (auto& reader : pending_) {
reader->Start();
done_.emplace_back(std::move(reader));
}
pending_.clear();
bool timeout;
complete_queue_.PopAll(1000, &timeout);
PADDLE_ENFORCE(!timeout);
}
}
// If users invoke Shutdown() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
if (p.joinable()) {
p.join();
void Start() override {
for (auto& reader : done_) {
AppendReader(std::move(reader));
}
done_.clear();
}
VLOG(5) << "MultiFileReader schedule thread terminates.";
}
void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts.";
std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
while (true) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->Shutdown();
reader->Start();
break;
void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!pending_.empty()) {
auto future_it = complete_queue_.Pop();
FutureItem item = future_it->get();
if (item.exception_) {
for (auto it = futures_.begin(); it != futures_.end(); ++it) {
if (it != future_it) {
it->wait(); // Wait all other threads complete.
}
}
std::rethrow_exception(item.exception_);
} else if (item.data_.empty()) { // reader done.
done_.emplace_back(std::move(*item.reader_it_));
pending_.erase(item.reader_it_);
futures_.erase(future_it);
ReadNext(out);
} else {
*out = item.data_;
// continue read async
AsyncRead(item.reader_it_, &future_it);
}
} else {
out->clear();
}
}
private:
void AppendReader(std::unique_ptr<framework::ReaderBase>&& reader) override {
pending_.emplace_back(std::move(reader));
auto reader_it = pending_.end();
--reader_it;
futures_.emplace_back();
auto future_it = futures_.end();
--future_it;
AsyncRead(reader_it, &future_it);
}
void AsyncRead(const ReaderList::iterator& reader_it,
FutureList::iterator* future_it_ptr) {
auto& future_it = *future_it_ptr;
*future_it = pool_.enqueue([reader_it, future_it, this] {
try {
buffer_->Send(std::move(ins));
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file idx '"
<< reader_idx << "' will terminate.";
break;
FutureItem item;
item.reader_it_ = reader_it;
(*reader_it)->ReadNext(&item.data_);
if (item.data_.empty()) {
(*reader_it)->Shutdown();
(*reader_it)->Start();
}
complete_queue_.Push(future_it);
return item;
} catch (...) {
FutureItem item;
item.exception_ = std::current_exception();
complete_queue_.Push(future_it);
return item;
}
});
}
FutureList futures_;
ThreadPool pool_;
framework::BlockingQueue<FutureList::iterator> complete_queue_;
std::list<std::unique_ptr<framework::ReaderBase>> pending_;
std::list<std::unique_ptr<framework::ReaderBase>> done_;
};
class MultiFileReader : public framework::ReaderBase {
public:
MultiFileReader(const std::vector<std::string>& file_names,
std::unique_ptr<IReaderContainer>&& container)
: container_(std::move(container)) {
for (auto& fn : file_names) {
container_->AppendReader(CreateReaderByFileName(fn));
}
}
if (!available_thread_idx_->Send(thread_idx)) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
~MultiFileReader() { container_->Stop(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
container_->ReadNext(out);
}
VLOG(5) << "The prefetch thread of file idx '" << reader_idx
<< "' terminates.";
}
void ShutdownImpl() override { container_->Stop(); }
void StartImpl() override { container_->Start(); }
private:
std::unique_ptr<IReaderContainer> container_;
};
class OpenFilesOp : public framework::OperatorBase {
public:
......@@ -173,13 +222,22 @@ class OpenFilesOp : public framework::OperatorBase {
"shape concat's length.");
const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num");
const size_t buffer_size = Attr<int>("buffer_size");
bool is_test = Attr<bool>("is_test");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
std::unique_ptr<IReaderContainer> container;
if (is_test) {
container.reset(new OrderedReaderContainer());
} else {
container.reset(new PreemptiveReaderContainer(
std::min(file_names.size(),
static_cast<size_t>(std::thread::hardware_concurrency()))));
}
out->Reset(
std::make_shared<MultiFileReader>(file_names, thread_num, buffer_size));
std::make_shared<MultiFileReader>(file_names, std::move(container)));
}
};
......@@ -187,9 +245,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
protected:
void Apply() override {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
AddAttr<bool>("is_test", "Used for testing data.").SetDefault(false);
AddComment(R"DOC(
OpenFiles Operator
......
......@@ -28,6 +28,7 @@ Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
Scanner::Scanner(const std::string &filename)
: stream_(new std::ifstream(filename)), parser_(*stream_) {
PADDLE_ENFORCE(static_cast<bool>(*stream_), "Cannot open file %s", filename);
Reset();
}
......
......@@ -20,6 +20,7 @@ from control_flow import BlockGuard
from ..layer_helper import LayerHelper
from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc
import sys
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
......@@ -532,10 +533,10 @@ def open_files(filenames,
shapes,
lod_levels,
dtypes,
thread_num=1,
thread_num=None,
buffer_size=None,
pass_num=1,
for_parallel=True):
is_test=None):
"""
Open files
......@@ -548,14 +549,15 @@ def open_files(filenames,
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int|None): The size of prefetch buffer. If it is setted None,
buffer size will be thread_num * 3.
Default: None
thread_num(None): Deprecated argument. It will be set by open_files
automatically.
buffer_size(None): Deprecated argument. It will be set by open_files
automatically.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Default: True
is_test(bool|None): Whether `open_files` used for testing or not. If it
is used for testing, the order of data generated is same as the file
order. Otherwise, it is not guaranteed the order of data is same
between every epoch. [Default: False].
Returns:
Variable: A Reader Variable via which we can get file data.
......@@ -567,15 +569,20 @@ def open_files(filenames,
'./data2.recordio'],
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=2,
buffer_size=2)
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
"""
if buffer_size is None:
buffer_size = thread_num * 3
if thread_num is not None:
print >> sys.stderr, "thread_num parameter of open_files is " \
"deprecated. It will be ignored and set " \
"automatically by open_files "
if buffer_size is not None:
print >> sys.stderr, "buffer_size parameter of open_files is " \
"deprecated. It will be ignored and set " \
"automatically by open_files "
if isinstance(filenames, basestring):
filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
......@@ -589,17 +596,16 @@ def open_files(filenames,
multi_file_reader_name = unique_name('multi_file_reader')
startup_blk = default_startup_program().current_block()
startup_reader = startup_blk.create_var(name=multi_file_reader_name)
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_reader]},
attrs={
attrs = {
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num,
'buffer_size': buffer_size
})
'file_names': filenames
}
if is_test is not None:
attrs['is_test'] = is_test
startup_blk.append_op(
type='open_files', outputs={'Out': [startup_reader]}, attrs=attrs)
startup_reader.desc.set_dtypes(dtypes)
startup_reader.persistable = True
......
......@@ -31,7 +31,10 @@ def load_vocab(filename):
# load word dict with paddle inner function
word_dict = load_vocab(sys.argv[1])
if len(sys.argv) == 1:
word_dict = paddle.dataset.imdb.word_dict()
else:
word_dict = load_vocab(sys.argv[1])
word_dict["<unk>"] = len(word_dict)
print "Dict dim = ", len(word_dict)
......
......@@ -41,16 +41,14 @@ def network_cfg(is_train, pass_num=100):
pass_num=pass_num,
shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'],
thread_num=1)
dtypes=['int64', 'int64'])
test_file_obj = fluid.layers.open_files(
filenames=TEST_FILES,
pass_num=1,
shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'],
thread_num=1)
dtypes=['int64', 'int64'])
if is_train:
file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000)
......
......@@ -39,17 +39,17 @@ class TestMultipleReader(unittest.TestCase):
copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio')
def main(self, thread_num):
def main(self, is_test=False):
file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
]
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_files = fluid.layers.open_files(
filenames=file_list,
thread_num=thread_num,
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
dtypes=['float32', 'int64'],
is_test=is_test)
img, label = fluid.layers.read_file(data_files)
if fluid.core.is_compiled_with_cuda():
......@@ -71,6 +71,9 @@ class TestMultipleReader(unittest.TestCase):
self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self):
self.main(thread_num=3) # thread number equals to file number
self.main(thread_num=10) # thread number is larger than file number
self.main(thread_num=2) # thread number is less than file number
self.main(is_test=False)
self.main(is_test=True)
if __name__ == '__main__':
unittest.main()
......@@ -32,9 +32,7 @@ def simple_fc_net(use_feed):
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
......@@ -60,9 +58,7 @@ def fc_with_batchnorm(use_feed):
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册