未验证 提交 c5023bea 编写于 作者: Y yuyang18

Merge branch 'feature/rewrite_open_files' into feature/combine_open_files_and_double_buffer

...@@ -12,8 +12,11 @@ ...@@ -12,8 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cmath>
#include <stdexcept>
#include <thread> // NOLINT #include <thread> // NOLINT
#include "ThreadPool.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
...@@ -21,141 +24,187 @@ namespace paddle { ...@@ -21,141 +24,187 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class MultiFileReader : public framework::ReaderBase { class IReaderContainer {
public:
virtual ~IReaderContainer() {}
virtual void AppendReader(
std::unique_ptr<framework::ReaderBase>&& readers) = 0;
virtual void Stop() = 0;
virtual void Start() = 0;
virtual void ReadNext(std::vector<framework::LoDTensor>* out) = 0;
};
class OrderedReaderContainer : public IReaderContainer {
public: public:
MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num, void AppendReader(std::unique_ptr<framework::ReaderBase>&& reader) override {
size_t buffer_size) pending_.emplace(std::move(reader));
: buffer_size_(buffer_size) { }
readers_.reserve(file_names.size());
for (const std::string& f_name : file_names) { void Stop() override {
readers_.emplace_back(CreateReaderByFileName(f_name)); while (!pending_.empty()) {
MoveFrontPendingToDone();
} }
prefetchers_.resize(thread_num);
StartNewScheduler();
} }
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override; void Start() override { std::swap(done_, pending_); }
~MultiFileReader() { EndScheduler(); } void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!pending_.empty()) {
pending_.front()->ReadNext(out);
if (out->empty()) {
MoveFrontPendingToDone();
ReadNext(out);
}
} else {
out->clear();
}
}
private: private:
void ShutdownImpl() override { EndScheduler(); } void MoveFrontPendingToDone() {
pending_.front()->Shutdown();
void StartImpl() override { StartNewScheduler(); } pending_.front()->Start();
done_.emplace(move(pending_.front()));
void StartNewScheduler(); pending_.pop();
void EndScheduler(); }
void ScheduleThreadFunc();
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx); std::queue<std::unique_ptr<framework::ReaderBase>> pending_;
std::queue<std::unique_ptr<framework::ReaderBase>> done_;
std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
size_t buffer_size_;
reader::BlockingQueue<size_t>* waiting_reader_idx_;
reader::BlockingQueue<size_t>* available_thread_idx_;
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
}; };
void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) { class PreemptiveReaderContainer : public IReaderContainer {
if (!buffer_->Receive(out)) { using ReaderList = std::list<std::unique_ptr<framework::ReaderBase>>;
out->clear();
struct FutureItem {
std::vector<framework::LoDTensor> data_;
ReaderList::iterator reader_it_;
std::exception_ptr exception_;
};
using FutureList = std::list<std::future<FutureItem>>;
public:
explicit PreemptiveReaderContainer(size_t thread_num) : pool_(thread_num) {}
void Stop() override {
if (!pending_.empty()) {
for (auto& reader : pending_) {
reader->Shutdown();
} }
} for (auto& fu : futures_) {
fu.wait();
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
buffer_size_);
for (size_t i = 0; i < readers_.size(); ++i) {
waiting_reader_idx_->Send(i);
}
waiting_reader_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(i);
}
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultiFileReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_reader_idx_->Close();
if (scheduler_.joinable()) {
scheduler_.join();
}
delete buffer_;
delete available_thread_idx_;
delete waiting_reader_idx_;
}
void MultiFileReader::ScheduleThreadFunc() {
VLOG(5) << "MultiFileReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t reader_idx;
if (waiting_reader_idx_->Receive(&reader_idx)) {
// Still have files to read. Start a new prefetch thread.
prefetcher = std::thread([this, reader_idx, thread_idx] {
PrefetchThreadFunc(reader_idx, thread_idx);
});
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
} }
futures_.clear();
for (auto& reader : pending_) {
reader->Start();
done_.emplace_back(std::move(reader));
}
pending_.clear();
bool timeout;
complete_queue_.PopAll(1000, &timeout);
PADDLE_ENFORCE(!timeout);
} }
} }
// If users invoke Shutdown() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler void Start() override {
// to release their resource. So a check is needed before scheduler ends. for (auto& reader : done_) {
for (auto& p : prefetchers_) { AppendReader(std::move(reader));
if (p.joinable()) {
p.join();
} }
done_.clear();
} }
VLOG(5) << "MultiFileReader schedule thread terminates.";
}
void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) { void ReadNext(std::vector<framework::LoDTensor>* out) override {
VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts."; if (!pending_.empty()) {
std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx]; auto future_it = complete_queue_.Pop();
while (true) { FutureItem item = future_it->get();
std::vector<framework::LoDTensor> ins; if (item.exception_) {
reader->ReadNext(&ins); for (auto it = futures_.begin(); it != futures_.end(); ++it) {
if (ins.empty()) { if (it != future_it) {
reader->Shutdown(); it->wait(); // Wait all other threads complete.
reader->Start(); }
break; }
std::rethrow_exception(item.exception_);
} else if (item.data_.empty()) { // reader done.
done_.emplace_back(std::move(*item.reader_it_));
pending_.erase(item.reader_it_);
futures_.erase(future_it);
ReadNext(out);
} else {
*out = item.data_;
// continue read async
AsyncRead(item.reader_it_, &future_it);
}
} else {
out->clear();
}
} }
private:
void AppendReader(std::unique_ptr<framework::ReaderBase>&& reader) override {
pending_.emplace_back(std::move(reader));
auto reader_it = pending_.end();
--reader_it;
futures_.emplace_back();
auto future_it = futures_.end();
--future_it;
AsyncRead(reader_it, &future_it);
}
void AsyncRead(const ReaderList::iterator& reader_it,
FutureList::iterator* future_it_ptr) {
auto& future_it = *future_it_ptr;
*future_it = pool_.enqueue([reader_it, future_it, this] {
try { try {
buffer_->Send(std::move(ins)); FutureItem item;
} catch (paddle::platform::EnforceNotMet e) { item.reader_it_ = reader_it;
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " (*reader_it)->ReadNext(&item.data_);
"thread of file idx '" if (item.data_.empty()) {
<< reader_idx << "' will terminate."; (*reader_it)->Shutdown();
break; (*reader_it)->Start();
}
complete_queue_.Push(future_it);
return item;
} catch (...) {
FutureItem item;
item.exception_ = std::current_exception();
complete_queue_.Push(future_it);
return item;
}
});
}
FutureList futures_;
ThreadPool pool_;
framework::BlockingQueue<FutureList::iterator> complete_queue_;
std::list<std::unique_ptr<framework::ReaderBase>> pending_;
std::list<std::unique_ptr<framework::ReaderBase>> done_;
};
class MultiFileReader : public framework::ReaderBase {
public:
MultiFileReader(const std::vector<std::string>& file_names,
std::unique_ptr<IReaderContainer>&& container)
: container_(std::move(container)) {
for (auto& fn : file_names) {
container_->AppendReader(CreateReaderByFileName(fn));
} }
} }
if (!available_thread_idx_->Send(thread_idx)) { ~MultiFileReader() { container_->Stop(); }
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx."; protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
container_->ReadNext(out);
} }
VLOG(5) << "The prefetch thread of file idx '" << reader_idx void ShutdownImpl() override { container_->Stop(); }
<< "' terminates."; void StartImpl() override { container_->Start(); }
}
private:
std::unique_ptr<IReaderContainer> container_;
};
class OpenFilesOp : public framework::OperatorBase { class OpenFilesOp : public framework::OperatorBase {
public: public:
...@@ -173,13 +222,22 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -173,13 +222,22 @@ class OpenFilesOp : public framework::OperatorBase {
"shape concat's length."); "shape concat's length.");
const auto& file_names = Attr<std::vector<std::string>>("file_names"); const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!"); PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num"); bool is_test = Attr<bool>("is_test");
const size_t buffer_size = Attr<int>("buffer_size");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
std::unique_ptr<IReaderContainer> container;
if (is_test) {
container.reset(new OrderedReaderContainer());
} else {
container.reset(new PreemptiveReaderContainer(
std::min(file_names.size(),
static_cast<size_t>(std::thread::hardware_concurrency()))));
}
out->Reset( out->Reset(
std::make_shared<MultiFileReader>(file_names, thread_num, buffer_size)); std::make_shared<MultiFileReader>(file_names, std::move(container)));
} }
}; };
...@@ -187,9 +245,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase { ...@@ -187,9 +245,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
protected: protected:
void Apply() override { void Apply() override {
AddAttr<std::vector<std::string>>("file_names", "Files to be read."); AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.") AddAttr<bool>("is_test", "Used for testing data.").SetDefault(false);
.GreaterThan(0);
AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
AddComment(R"DOC( AddComment(R"DOC(
OpenFiles Operator OpenFiles Operator
......
...@@ -28,6 +28,7 @@ Scanner::Scanner(std::unique_ptr<std::istream> &&stream) ...@@ -28,6 +28,7 @@ Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
Scanner::Scanner(const std::string &filename) Scanner::Scanner(const std::string &filename)
: stream_(new std::ifstream(filename)), parser_(*stream_) { : stream_(new std::ifstream(filename)), parser_(*stream_) {
PADDLE_ENFORCE(static_cast<bool>(*stream_), "Cannot open file %s", filename);
Reset(); Reset();
} }
......
...@@ -20,6 +20,7 @@ from control_flow import BlockGuard ...@@ -20,6 +20,7 @@ from control_flow import BlockGuard
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..executor import global_scope from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc from layer_function_generator import generate_layer_fn, templatedoc
import sys
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
...@@ -532,10 +533,10 @@ def open_files(filenames, ...@@ -532,10 +533,10 @@ def open_files(filenames,
shapes, shapes,
lod_levels, lod_levels,
dtypes, dtypes,
thread_num=1, thread_num=None,
buffer_size=None, buffer_size=None,
pass_num=1, pass_num=1,
for_parallel=True): is_test=None):
""" """
Open files Open files
...@@ -548,14 +549,15 @@ def open_files(filenames, ...@@ -548,14 +549,15 @@ def open_files(filenames,
shapes(list): List of tuples which declaring data shapes. shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level. lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type. dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number. thread_num(None): Deprecated argument. It will be set by open_files
buffer_size(int|None): The size of prefetch buffer. If it is setted None, automatically.
buffer size will be thread_num * 3. buffer_size(None): Deprecated argument. It will be set by open_files
Default: None automatically.
pass_num(int): Number of passes to run. pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run is_test(bool|None): Whether `open_files` used for testing or not. If it
subsequent operators in parallel. is used for testing, the order of data generated is same as the file
Default: True order. Otherwise, it is not guaranteed the order of data is same
between every epoch. [Default: False].
Returns: Returns:
Variable: A Reader Variable via which we can get file data. Variable: A Reader Variable via which we can get file data.
...@@ -567,15 +569,20 @@ def open_files(filenames, ...@@ -567,15 +569,20 @@ def open_files(filenames,
'./data2.recordio'], './data2.recordio'],
shapes=[(3,224,224), (1)], shapes=[(3,224,224), (1)],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64'], dtypes=['float32', 'int64'])
thread_num=2,
buffer_size=2)
# Via the reader, we can use 'read_file' layer to get data: # Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader) image, label = fluid.layers.io.read_file(reader)
""" """
if buffer_size is None: if thread_num is not None:
buffer_size = thread_num * 3 print >> sys.stderr, "thread_num parameter of open_files is " \
"deprecated. It will be ignored and set " \
"automatically by open_files "
if buffer_size is not None:
print >> sys.stderr, "buffer_size parameter of open_files is " \
"deprecated. It will be ignored and set " \
"automatically by open_files "
if isinstance(filenames, basestring): if isinstance(filenames, basestring):
filenames = [filenames] filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
...@@ -589,17 +596,16 @@ def open_files(filenames, ...@@ -589,17 +596,16 @@ def open_files(filenames,
multi_file_reader_name = unique_name('multi_file_reader') multi_file_reader_name = unique_name('multi_file_reader')
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_reader = startup_blk.create_var(name=multi_file_reader_name) startup_reader = startup_blk.create_var(name=multi_file_reader_name)
startup_blk.append_op( attrs = {
type='open_files',
outputs={'Out': [startup_reader]},
attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
'ranks': ranks, 'ranks': ranks,
'file_names': filenames, 'file_names': filenames
'thread_num': thread_num, }
'buffer_size': buffer_size if is_test is not None:
}) attrs['is_test'] = is_test
startup_blk.append_op(
type='open_files', outputs={'Out': [startup_reader]}, attrs=attrs)
startup_reader.desc.set_dtypes(dtypes) startup_reader.desc.set_dtypes(dtypes)
startup_reader.persistable = True startup_reader.persistable = True
......
...@@ -31,7 +31,10 @@ def load_vocab(filename): ...@@ -31,7 +31,10 @@ def load_vocab(filename):
# load word dict with paddle inner function # load word dict with paddle inner function
word_dict = load_vocab(sys.argv[1]) if len(sys.argv) == 1:
word_dict = paddle.dataset.imdb.word_dict()
else:
word_dict = load_vocab(sys.argv[1])
word_dict["<unk>"] = len(word_dict) word_dict["<unk>"] = len(word_dict)
print "Dict dim = ", len(word_dict) print "Dict dim = ", len(word_dict)
......
...@@ -41,16 +41,14 @@ def network_cfg(is_train, pass_num=100): ...@@ -41,16 +41,14 @@ def network_cfg(is_train, pass_num=100):
pass_num=pass_num, pass_num=pass_num,
shapes=[[-1, 1], [-1, 1]], shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0], lod_levels=[1, 0],
dtypes=['int64', 'int64'], dtypes=['int64', 'int64'])
thread_num=1)
test_file_obj = fluid.layers.open_files( test_file_obj = fluid.layers.open_files(
filenames=TEST_FILES, filenames=TEST_FILES,
pass_num=1, pass_num=1,
shapes=[[-1, 1], [-1, 1]], shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0], lod_levels=[1, 0],
dtypes=['int64', 'int64'], dtypes=['int64', 'int64'])
thread_num=1)
if is_train: if is_train:
file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000) file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000)
......
...@@ -39,17 +39,17 @@ class TestMultipleReader(unittest.TestCase): ...@@ -39,17 +39,17 @@ class TestMultipleReader(unittest.TestCase):
copyfile('./mnist_0.recordio', './mnist_1.recordio') copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio') copyfile('./mnist_0.recordio', './mnist_2.recordio')
def main(self, thread_num): def main(self, is_test=False):
file_list = [ file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
] ]
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data_files = fluid.layers.open_files( data_files = fluid.layers.open_files(
filenames=file_list, filenames=file_list,
thread_num=thread_num,
shapes=[(-1, 784), (-1, 1)], shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'],
is_test=is_test)
img, label = fluid.layers.read_file(data_files) img, label = fluid.layers.read_file(data_files)
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
...@@ -71,6 +71,9 @@ class TestMultipleReader(unittest.TestCase): ...@@ -71,6 +71,9 @@ class TestMultipleReader(unittest.TestCase):
self.assertEqual(batch_count, self.num_batch * 3) self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self): def test_main(self):
self.main(thread_num=3) # thread number equals to file number self.main(is_test=False)
self.main(thread_num=10) # thread number is larger than file number self.main(is_test=True)
self.main(thread_num=2) # thread number is less than file number
if __name__ == '__main__':
unittest.main()
...@@ -32,9 +32,7 @@ def simple_fc_net(use_feed): ...@@ -32,9 +32,7 @@ def simple_fc_net(use_feed):
filenames=[MNIST_RECORDIO_FILE], filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64'], dtypes=['float32', 'int64'])
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader) reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader) img, label = fluid.layers.read_file(reader)
hidden = img hidden = img
...@@ -60,9 +58,7 @@ def fc_with_batchnorm(use_feed): ...@@ -60,9 +58,7 @@ def fc_with_batchnorm(use_feed):
filenames=[MNIST_RECORDIO_FILE], filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64'], dtypes=['float32', 'int64'])
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader) reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader) img, label = fluid.layers.read_file(reader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册