提交 af242901 编写于 作者: F fengjiayi

Add 'buffer_size' API for the open_files op

上级 f2c0b886
...@@ -38,8 +38,9 @@ class MultipleReader : public framework::ReaderBase { ...@@ -38,8 +38,9 @@ class MultipleReader : public framework::ReaderBase {
}; };
MultipleReader(const std::vector<std::string>& file_names, MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num) const std::vector<framework::DDim>& dims, size_t thread_num,
: file_names_(file_names), dims_(dims) { size_t buffer_size)
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
prefetchers_.resize(thread_num); prefetchers_.resize(thread_num);
StartNewScheduler(); StartNewScheduler();
} }
...@@ -60,6 +61,7 @@ class MultipleReader : public framework::ReaderBase { ...@@ -60,6 +61,7 @@ class MultipleReader : public framework::ReaderBase {
std::vector<framework::DDim> dims_; std::vector<framework::DDim> dims_;
std::thread scheduler_; std::thread scheduler_;
std::vector<std::thread> prefetchers_; std::vector<std::thread> prefetchers_;
size_t buffer_size_;
framework::Channel<size_t>* waiting_file_idx_; framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_; framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_; framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
...@@ -92,7 +94,7 @@ void MultipleReader::StartNewScheduler() { ...@@ -92,7 +94,7 @@ void MultipleReader::StartNewScheduler() {
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size()); waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num); available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ = buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num); framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
for (size_t i = 0; i < file_names_.size(); ++i) { for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i); waiting_file_idx_->Send(&i);
...@@ -197,11 +199,13 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -197,11 +199,13 @@ class OpenFilesOp : public framework::OperatorBase {
const auto& file_names = Attr<std::vector<std::string>>("file_names"); const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!"); PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num"); const size_t thread_num = Attr<int>("thread_num");
const size_t buffer_size = Attr<int>("buffer_size");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader( out->Reset(new MultipleReader(file_names,
file_names, RestoreShapes(shape_concat, ranks), thread_num)); RestoreShapes(shape_concat, ranks),
thread_num, buffer_size));
} }
}; };
...@@ -212,6 +216,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase { ...@@ -212,6 +216,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
AddAttr<std::vector<std::string>>("file_names", "Files to be read."); AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.") AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0); .GreaterThan(0);
AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
AddComment(R"DOC( AddComment(R"DOC(
OpenFiles Operator OpenFiles Operator
......
...@@ -287,7 +287,14 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): ...@@ -287,7 +287,14 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var) startup_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes): def open_files(filenames,
shapes,
lod_levels,
dtypes,
thread_num,
buffer_size=None):
if buffer_size is None:
buffer_size = thread_num
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = [] shape_concat = []
ranks = [] ranks = []
...@@ -308,7 +315,8 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes): ...@@ -308,7 +315,8 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
'lod_levels': lod_levels, 'lod_levels': lod_levels,
'ranks': ranks, 'ranks': ranks,
'file_names': filenames, 'file_names': filenames,
'thread_num': thread_num 'thread_num': thread_num,
'buffer_size': buffer_size
}) })
startup_var.desc.set_dtypes(dtypes) startup_var.desc.set_dtypes(dtypes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册