diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 744bd3b7ef71f83ad82979eb966369c2e9456a7d..4a43fc02d2189ec2eb4bff12769c11cb8dae8193 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -15,6 +15,7 @@ function(reader_library TARGET_NAME) PARENT_SCOPE) endfunction() +reader_library(open_files_op SRCS open_files_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index bd0bb2ee3b0252f47318c59d9940d8dd478723de..76cdb794ccdb4a015ae8630940a5c26845e7a7b3 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { }; void DoubleBufferReader::ReadNext(std::vector* out) { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + if (local_buffer_.payloads_.empty()) { buffer_->Receive(&local_buffer_); } - *out = local_buffer_.payloads_; local_buffer_.payloads_.clear(); if (local_buffer_.ctx_) { diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..414c76fea0bb916dfeafe38c0448a7a800889e03 --- /dev/null +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -0,0 +1,212 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/operators/reader/reader_op_registry.h" + +namespace paddle { +namespace operators { +namespace reader { + +class MultipleReader : public framework::ReaderBase { + public: + MultipleReader(const std::vector& file_names, + const std::vector& dims, size_t thread_num) + : file_names_(file_names), dims_(dims) { + prefetchers_.resize(thread_num); + StartNewScheduler(); + } + + void ReadNext(std::vector* out) override; + bool HasNext() const override; + void ReInit() override; + + ~MultipleReader() { EndScheduler(); } + + private: + void StartNewScheduler(); + void EndScheduler(); + void ScheduleThreadFunc(); + void PrefetchThreadFunc(std::string file_name, size_t thread_idx); + + std::vector file_names_; + std::vector dims_; + std::thread scheduler_; + std::vector prefetchers_; + framework::Channel* waiting_file_idx_; + framework::Channel* available_thread_idx_; + framework::Channel>* buffer_; + mutable std::vector local_buffer_; +}; + +void MultipleReader::ReadNext(std::vector* out) { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + + if (local_buffer_.empty()) { + buffer_->Receive(&local_buffer_); + } + *out = local_buffer_; + local_buffer_.clear(); +} + +bool MultipleReader::HasNext() const { + return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true; +} + +void MultipleReader::ReInit() { + EndScheduler(); + local_buffer_.clear(); + StartNewScheduler(); +} + +void MultipleReader::StartNewScheduler() { + size_t thread_num = prefetchers_.size(); + waiting_file_idx_ = framework::MakeChannel(file_names_.size()); + available_thread_idx_ = framework::MakeChannel(thread_num); + buffer_ = + framework::MakeChannel>(thread_num); + + for (size_t i = 0; i < file_names_.size(); ++i) { + waiting_file_idx_->Send(&i); + } + waiting_file_idx_->Close(); + for (size_t i = 0; i < thread_num; ++i) { + available_thread_idx_->Send(&i); + } + + scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); +} + +void MultipleReader::EndScheduler() { + available_thread_idx_->Close(); + buffer_->Close(); + waiting_file_idx_->Close(); + if (scheduler_.joinable()) { + scheduler_.join(); + } + delete buffer_; + delete available_thread_idx_; + delete waiting_file_idx_; +} + +void MultipleReader::ScheduleThreadFunc() { + VLOG(5) << "MultipleReader schedule thread starts."; + size_t completed_thread_num = 0; + size_t thread_idx; + while (available_thread_idx_->Receive(&thread_idx)) { + std::thread& prefetcher = prefetchers_[thread_idx]; + if (prefetcher.joinable()) { + prefetcher.join(); + } + size_t file_idx; + if (waiting_file_idx_->Receive(&file_idx)) { + // Still have files to read. Start a new prefetch thread. + std::string file_name = file_names_[file_idx]; + prefetcher = std::thread([this, file_name, thread_idx] { + PrefetchThreadFunc(file_name, thread_idx); + }); + } else { + // No more file to read. + ++completed_thread_num; + if (completed_thread_num == prefetchers_.size()) { + buffer_->Close(); + break; + } + } + } + // If users invoke ReInit() when scheduler is running, it will close the + // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler + // to release their resource. So a check is needed before scheduler ends. + for (auto& p : prefetchers_) { + if (p.joinable()) { + p.join(); + } + } + VLOG(5) << "MultipleReader schedule thread terminates."; +} + +void MultipleReader::PrefetchThreadFunc(std::string file_name, + size_t thread_idx) { + VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; + std::unique_ptr reader = + CreateReaderByFileName(file_name, dims_); + while (reader->HasNext()) { + std::vector ins; + reader->ReadNext(&ins); + if (!buffer_->Send(&ins)) { + VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " + "thread of file '" + << file_name << "' will terminate."; + break; + } + } + if (!available_thread_idx_->Send(&thread_idx)) { + VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. " + "Fail to send thread_idx."; + } + VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; +} + +class OpenFilesOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& shape_concat = Attr>("shape_concat"); + const auto& ranks = Attr>("ranks"); + PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); + PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), + int(shape_concat.size()), + "The accumulate of all ranks should be equal to the " + "shape concat's length."); + const auto& file_names = Attr>("file_names"); + PADDLE_ENFORCE(!file_names.empty(), "No file to be read!"); + const size_t thread_num = Attr("thread_num"); + + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + out->Reset(new MultipleReader( + file_names, RestoreShapes(shape_concat, ranks), thread_num)); + } +}; + +class OpenFilesOpMaker : public FileReaderMakerBase { + public: + OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : FileReaderMakerBase(op_proto, op_checker) { + AddAttr>("file_names", "Files to be read."); + AddAttr("thread_num", "The maximal concurrent prefetch thread number.") + .GreaterThan(0); + + AddComment(R"DOC( + OpenFiles Operator + + An OpenFilesOp creates a MultipleReader, which is able to + read data multi-threaded from multiple files. + )DOC"); + } +}; + +} // namespace reader +} // namespace operators +} // namespace paddle + +namespace reader = paddle::operators::reader; + +REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp, + reader::OpenFilesOpMaker); diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 0ba4f3854431742eb354f8c90eb395f5d7b32b2e..fc8dc747ff0c2286f4516d8350f75d9887361924 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -36,6 +36,21 @@ std::unordered_map& FileReaderRegistry() { return regs; } +std::unique_ptr CreateReaderByFileName( + const std::string& file_name, const std::vector& dims) { + size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); + PADDLE_ENFORCE_NE(separator_pos, std::string::npos, + "File name illegal! A legal file name should be like: " + "[file_name].[file_format] (e.g., 'data_file.recordio')."); + std::string filetype = file_name.substr(separator_pos + 1); + + auto itor = FileReaderRegistry().find(filetype); + PADDLE_ENFORCE(itor != FileReaderRegistry().end(), + "No file reader registered for '%s' format.", filetype); + framework::ReaderBase* reader = (itor->second)(file_name, dims); + return std::unique_ptr(reader); +} + FileReaderMakerBase::FileReaderMakerBase( framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpAttrChecker* op_checker) diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index 58f9b4ba35546571fd3b1d0c3ce128f18e248f01..929d32ad8b367865e33530f8517343c513ee9878 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -21,6 +21,8 @@ namespace paddle { namespace operators { namespace reader { +static constexpr char kFileFormatSeparator[] = "."; + using FileReaderCreator = std::function&)>; @@ -29,12 +31,15 @@ std::unordered_map& FileReaderRegistry(); template int RegisterFileReader(const std::string& filetype) { FileReaderRegistry()[filetype] = []( - const std::string& fn, const std::vector& dim) { - return new Reader(fn, dim); + const std::string& fn, const std::vector& dims) { + return new Reader(fn, dims); }; return 0; } +std::unique_ptr CreateReaderByFileName( + const std::string& file_name, const std::vector& dims); + extern std::vector RestoreShapes( const std::vector& shape_concat, const std::vector& ranks); diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 9c91f395e7c9d7ca76c1a5cc310bc3bbc06daec9..f169642eaa44eadcef8ff0bc6745183a9fec20e8 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,8 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader' + 'open_files', 'read_file', 'create_shuffle_reader', + 'create_double_buffer_reader' ] @@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): startup_var) +def open_files(filenames, thread_num, shapes, lod_levels, dtypes): + dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] + shape_concat = [] + ranks = [] + + for shape in shapes: + shape_concat.extend(shape) + ranks.append(len(shape)) + + var_name = unique_name('multiple_reader') + + startup_blk = default_startup_program().current_block() + startup_var = startup_blk.create_var(name=var_name) + startup_blk.append_op( + type='open_files', + outputs={'Out': [startup_var]}, + attrs={ + 'shape_concat': shape_concat, + 'lod_levels': lod_levels, + 'ranks': ranks, + 'file_names': filenames, + 'thread_num': thread_num + }) + + startup_var.desc.set_dtypes(dtypes) + startup_var.persistable = True + return _copy_reader_var_(default_main_program().current_block(), + startup_var) + + def __create_decorated_reader__(op_type, reader, attrs): var_name = unique_name(op_type) startup_blk = default_startup_program().current_block() diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore index 6b3fc2a83c649c28d21c9a8a0b35c2f2fa04f269..ad02bdecf436bba925e2e3b7efb20c878df70dfd 100644 --- a/python/paddle/fluid/tests/unittests/.gitignore +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -1 +1,4 @@ mnist.recordio +mnist_0.recordio +mnist_1.recordio +mnist_2.recordio diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..69f8acf81efaba8fc0f3df4cfe3a42dc4e477df2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py @@ -0,0 +1,74 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist +from shutil import copyfile + + +class TestMultipleReader(unittest.TestCase): + def setUp(self): + self.batch_size = 64 + # Convert mnist to recordio file + with fluid.program_guard(fluid.Program(), fluid.Program()): + reader = paddle.batch(mnist.train(), batch_size=self.batch_size) + feeder = fluid.DataFeeder( + feed_list=[ # order is image and label + fluid.layers.data( + name='image', shape=[784]), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file( + './mnist_0.recordio', reader, feeder) + copyfile('./mnist_0.recordio', './mnist_1.recordio') + copyfile('./mnist_0.recordio', './mnist_2.recordio') + + def main(self, thread_num): + file_list = [ + './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' + ] + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_files = fluid.layers.open_files( + filenames=file_list, + thread_num=thread_num, + shapes=[(-1, 784), (-1, 1)], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + img, label = fluid.layers.read_file(data_files) + + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + batch_count = 0 + while not data_files.eof(): + img_val, = exe.run(fetch_list=[img]) + batch_count += 1 + self.assertLessEqual(img_val.shape[0], self.batch_size) + data_files.reset() + self.assertEqual(batch_count, self.num_batch * 3) + + def test_main(self): + self.main(thread_num=3) # thread number equals to file number + self.main(thread_num=10) # thread number is larger than file number + self.main(thread_num=2) # thread number is less than file number