diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc
index eacedeea8835d27b712b287824b9d30b03ebebbf..db4e619e7b1cee2c934aa5da38a95c4a99c4ead3 100644
--- a/paddle/fluid/operators/reader/open_files_op.cc
+++ b/paddle/fluid/operators/reader/open_files_op.cc
@@ -38,8 +38,9 @@ class MultipleReader : public framework::ReaderBase {
   };
 
   MultipleReader(const std::vector<std::string>& file_names,
-                 const std::vector<framework::DDim>& dims, size_t thread_num)
-      : file_names_(file_names), dims_(dims) {
+                 const std::vector<framework::DDim>& dims, size_t thread_num,
+                 size_t buffer_size)
+      : file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
     prefetchers_.resize(thread_num);
     StartNewScheduler();
   }
@@ -60,6 +61,7 @@ class MultipleReader : public framework::ReaderBase {
   std::vector<framework::DDim> dims_;
   std::thread scheduler_;
   std::vector<std::thread> prefetchers_;
+  size_t buffer_size_;
   framework::Channel<size_t>* waiting_file_idx_;
   framework::Channel<size_t>* available_thread_idx_;
   framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
@@ -92,7 +94,7 @@ void MultipleReader::StartNewScheduler() {
   waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
   available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
   buffer_ =
-      framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
+      framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
 
   for (size_t i = 0; i < file_names_.size(); ++i) {
     waiting_file_idx_->Send(&i);
@@ -197,11 +199,13 @@ class OpenFilesOp : public framework::OperatorBase {
     const auto& file_names = Attr<std::vector<std::string>>("file_names");
     PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
     const size_t thread_num = Attr<int>("thread_num");
+    const size_t buffer_size = Attr<int>("buffer_size");
 
     auto* out = scope.FindVar(Output("Out"))
                     ->template GetMutable<framework::ReaderHolder>();
-    out->Reset(new MultipleReader(
-        file_names, RestoreShapes(shape_concat, ranks), thread_num));
+    out->Reset(new MultipleReader(file_names,
+                                  RestoreShapes(shape_concat, ranks),
+                                  thread_num, buffer_size));
   }
 };
 
@@ -212,6 +216,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
     AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
     AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
         .GreaterThan(0);
+    AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
 
     AddComment(R"DOC(
       OpenFiles Operator
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index bd7e9c30fed2c38a206bf17a646d8a4433af4099..da5b4853d34241ad70d2d75f665dffb4bae1e821 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -287,7 +287,14 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
                              startup_var)
 
 
-def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
+def open_files(filenames,
+               shapes,
+               lod_levels,
+               dtypes,
+               thread_num,
+               buffer_size=None):
+    if buffer_size is None:
+        buffer_size = thread_num
     dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
     shape_concat = []
     ranks = []
@@ -308,7 +315,8 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
             'lod_levels': lod_levels,
             'ranks': ranks,
             'file_names': filenames,
-            'thread_num': thread_num
+            'thread_num': thread_num,
+            'buffer_size': buffer_size
         })
 
     startup_var.desc.set_dtypes(dtypes)
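
For reference, below is a minimal usage sketch of the updated `fluid.layers.open_files` Python API. It is illustrative only: the file names, shapes, and dtypes are made-up placeholders, and `fluid.layers.read_file` is shown for context but is not part of this diff. The only new behavior is the `buffer_size` keyword, which falls back to `thread_num` when omitted and sets the capacity of the prefetch channel created in `MultipleReader::StartNewScheduler`.

import paddle.fluid as fluid

# Hypothetical RecordIO shards; shapes, lod_levels and dtypes are placeholders.
reader = fluid.layers.open_files(
    filenames=['./mnist-part-00.recordio', './mnist-part-01.recordio'],
    shapes=[[-1, 784], [-1, 1]],
    lod_levels=[0, 0],
    dtypes=['float32', 'int64'],
    thread_num=2,
    # New in this patch: the prefetch channel holds at most 8 items before the
    # prefetch threads block; omitting it falls back to thread_num (here 2).
    buffer_size=8)

# Decoding the reader into variables is unchanged by this patch.
image, label = fluid.layers.read_file(reader)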