提交 af242901 编写于 作者: F fengjiayi

Add 'buffer_size' api for open_files op

上级 f2c0b886
......@@ -38,8 +38,9 @@ class MultipleReader : public framework::ReaderBase {
};
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) {
const std::vector<framework::DDim>& dims, size_t thread_num,
size_t buffer_size)
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
prefetchers_.resize(thread_num);
StartNewScheduler();
}
......@@ -60,6 +61,7 @@ class MultipleReader : public framework::ReaderBase {
std::vector<framework::DDim> dims_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
size_t buffer_size_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
......@@ -92,7 +94,7 @@ void MultipleReader::StartNewScheduler() {
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
......@@ -197,11 +199,13 @@ class OpenFilesOp : public framework::OperatorBase {
const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num");
const size_t buffer_size = Attr<int>("buffer_size");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader(
file_names, RestoreShapes(shape_concat, ranks), thread_num));
out->Reset(new MultipleReader(file_names,
RestoreShapes(shape_concat, ranks),
thread_num, buffer_size));
}
};
......@@ -212,6 +216,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
AddComment(R"DOC(
OpenFiles Operator
......
......@@ -287,7 +287,14 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
def open_files(filenames,
shapes,
lod_levels,
dtypes,
thread_num,
buffer_size=None):
if buffer_size is None:
buffer_size = thread_num
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
......@@ -308,7 +315,8 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num
'thread_num': thread_num,
'buffer_size': buffer_size
})
startup_var.desc.set_dtypes(dtypes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册