diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index f288b90b4dbaff9b286b7b74d8cdf7d5a5963650..0f2f4387aa4326f3571d667513930aec8027cf06 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -16,8 +16,29 @@ namespace paddle { namespace framework { + +void ReaderBase::ReadNext(std::vector *out) { + if (status_ != ReaderStatus::kRunning) { + PADDLE_THROW("The reader is not at the status of 'running'."); + } + ReadNextImpl(out); +} + +void ReaderBase::Shutdown() { + if (status_ != ReaderStatus::kStopped) { + ShutdownImpl(); + status_ = ReaderStatus::kStopped; + } +} + +void ReaderBase::Start() { + if (status_ != ReaderStatus::kRunning) { + StartImpl(); + status_ = ReaderStatus::kRunning; + } +} + ReaderBase::~ReaderBase() {} -void FileReader::ReadNext(std::vector *out) { ReadNextImpl(out); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 823d58af5ea010c797dd2078759509fbe30f3a6a..6b62d11802bd4c819374c0ba4432bbd9290377dc 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -24,13 +24,26 @@ namespace paddle { namespace framework { +enum ReaderStatus { kRunning, kStopped }; + class ReaderBase { public: - virtual void ReadNext(std::vector* out) = 0; + void ReadNext(std::vector* out); + + void Shutdown(); - virtual void ReInit() = 0; + void Start(); virtual ~ReaderBase(); + + protected: + virtual void ReadNextImpl(std::vector* out) = 0; + + virtual void ShutdownImpl() = 0; + + virtual void StartImpl() = 0; + + std::atomic status_{kStopped}; }; class DecoratedReader : public ReaderBase { @@ -40,9 +53,11 @@ class DecoratedReader : public ReaderBase { PADDLE_ENFORCE_NOT_NULL(reader_); } - void ReInit() override { reader_->ReInit(); } - protected: + void ShutdownImpl() override { reader_->Shutdown(); } + + void StartImpl() override { reader_->Start(); } + std::shared_ptr reader_; }; @@ -50,10 +65,10 @@ class FileReader : public ReaderBase { public: FileReader() : ReaderBase() {} - void ReadNext(std::vector* out) override; - protected: - virtual void ReadNextImpl(std::vector* out) = 0; + void ShutdownImpl() override {} + + void StartImpl() override {} }; // The ReaderHolder is used as reader' unified wrapper, @@ -68,9 +83,19 @@ class ReaderHolder { PADDLE_ENFORCE_NOT_NULL(reader_); reader_->ReadNext(out); } - void ReInit() { + + void ResetAll() { + // TODO(fengjiayi): The interface of reseting all. + } + + void Shutdown() { + PADDLE_ENFORCE_NOT_NULL(reader_); + reader_->Shutdown(); + } + + void Start() { PADDLE_ENFORCE_NOT_NULL(reader_); - reader_->ReInit(); + reader_->Start(); } private: diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index ecbae3894d551186f53625a6cc9cfdb36adc8d2d..429313a339fe022d3bacc5da1fbc1ddb1f35241f 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -23,9 +23,10 @@ class BatchReader : public framework::DecoratedReader { BatchReader(const std::shared_ptr& reader, int batch_size) : DecoratedReader(reader), batch_size_(batch_size) { buffer_.reserve(batch_size_); + Start(); } - void ReadNext(std::vector* out) override; + void ReadNextImpl(std::vector* out) override; private: int batch_size_; @@ -66,7 +67,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { } }; -void BatchReader::ReadNext(std::vector* out) { +void BatchReader::ReadNextImpl(std::vector* out) { buffer_.clear(); buffer_.reserve(batch_size_); for (int i = 0; i < batch_size_; ++i) { diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index a75c6d4c567ac93f37b38070421133af305f20a3..334ba4cf6e811e370924d0536aad595691e49f24 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -31,9 +31,11 @@ class CustomReader : public framework::DecoratedReader { sub_block_id_(sub_block.ID()), exe_(framework::Executor(platform::CPUPlace())), source_var_names_(source_var_names), - sink_var_names_(sink_var_names) {} + sink_var_names_(sink_var_names) { + Start(); + } - void ReadNext(std::vector* out) override; + void ReadNextImpl(std::vector* out) override; private: const framework::ProgramDesc program_; @@ -143,7 +145,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference { } }; -void CustomReader::ReadNext(std::vector* out) { +void CustomReader::ReadNextImpl(std::vector* out) { out->clear(); std::vector underlying_outs; reader_->ReadNext(&underlying_outs); diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 0d2ff2e8e4882dc8c6572d173e85e695c42898b4..88c34187ea76305a0c0965d23acc30bc7554df24 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -47,15 +47,24 @@ class DoubleBufferReader : public framework::DecoratedReader { } } #endif - StartPrefetcher(); + Start(); } - void ReadNext(std::vector* out) override; - void ReInit() override; + void ReadNextImpl(std::vector* out) override; - ~DoubleBufferReader() { EndPrefetcher(); } + ~DoubleBufferReader() { Shutdown(); } private: + void ShutdownImpl() override { + EndPrefetcher(); + reader_->Shutdown(); + } + + void StartImpl() override { + reader_->Start(); + StartPrefetcher(); + } + void StartPrefetcher() { channel_ = new reader::BlockingQueue(kChannelSize); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); @@ -136,7 +145,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { } }; -void DoubleBufferReader::ReadNext(std::vector* out) { +void DoubleBufferReader::ReadNextImpl(std::vector* out) { size_t cached_tensor_id; if (channel_->Receive(&cached_tensor_id)) { if (platform::is_gpu_place(place_)) { @@ -150,12 +159,6 @@ void DoubleBufferReader::ReadNext(std::vector* out) { } } -void DoubleBufferReader::ReInit() { - EndPrefetcher(); - reader_->ReInit(); - StartPrefetcher(); -} - void DoubleBufferReader::PrefetchThreadFunc() { VLOG(5) << "A new prefetch thread starts."; size_t cached_tensor_id = 0; diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc index 19b54110b9aeece33b8d6c73612ae0e12dbfafbd..7c8aa975a15ef9ff33554cc9452b037c45f7658e 100644 --- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -22,25 +22,28 @@ namespace reader { class MultiPassReader : public framework::DecoratedReader { public: MultiPassReader(const std::shared_ptr& reader, int pass_num) - : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} + : DecoratedReader(reader), pass_num_(pass_num) { + Start(); + } - void ReadNext(std::vector* out) override { + void ReadNextImpl(std::vector* out) override { reader_->ReadNext(out); if (out->empty()) { ++pass_count_; if (pass_count_ < pass_num_) { - reader_->ReInit(); + reader_->Shutdown(); + reader_->Start(); reader_->ReadNext(out); } } } - void ReInit() override { + private: + void StartImpl() override { pass_count_ = 0; - reader_->ReInit(); + reader_->Start(); } - private: int pass_num_; mutable int pass_count_; }; diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc index 84ea72379b406da2c91504c985a523ce663cc75d..9b4c6412e65d173e156e062ec416c7bc43995c9d 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -33,9 +33,13 @@ class PyReader : public framework::FileReader { if (!success) out->clear(); } - void ReInit() override {} - private: + void ShutdownImpl() override { /* TODO */ + } + + void StartImpl() override { /* TODO */ + } + std::shared_ptr queue_; }; diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 7cbc2882fdd6aa4ab80c9360014343d670734749..c92a8b49b58bf10da1f1c569d947a4ea699f6c19 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -30,6 +30,7 @@ class RandomDataGenerator : public framework::FileReader { unsigned int seed = std::random_device()(); engine_.seed(seed); dist_ = std::uniform_real_distribution(low_, high_); + Start(); } void ReadNextImpl(std::vector* out) override { @@ -51,8 +52,6 @@ class RandomDataGenerator : public framework::FileReader { } } - void ReInit() override { return; } - private: float low_; float high_; diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index c032acdffdaacc15d99dd36048d32faf0468e254..7a44bd14eb2b8f052f71fc0782ae1d1420b7b61c 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -30,10 +30,9 @@ class RecordIOFileReader : public framework::FileReader { mutex_.reset(new std::mutex()); } LOG(INFO) << "Creating file reader" << filename; + Start(); } - void ReInit() override { scanner_.Reset(); } - protected: void ReadNextImpl(std::vector* out) override { if (ThreadSafe) { @@ -44,6 +43,8 @@ class RecordIOFileReader : public framework::FileReader { } } + void ShutdownImpl() override { scanner_.Reset(); } + private: std::unique_ptr mutex_; recordio::Scanner scanner_; diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 57e8e21214b7c99e52550fe51a67c9b5201cb46f..3cee9bfd6439fcedd3d987bf11bbe17befdeafc0 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -31,10 +31,10 @@ class ShuffleReader : public framework::DecoratedReader { std::random_device device; seed_ = device(); } - ReloadBuffer(); + Start(); } - void ReadNext(std::vector* out) override { + void ReadNextImpl(std::vector* out) override { out->clear(); if (iteration_pos_ >= buffer_.size()) { VLOG(10) << "Resetting shuffle buffer"; @@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader { } private: + void ShutdownImpl() override { + buffer_.clear(); + iteration_pos_ = 0; + reader_->Shutdown(); + } + + void StartImpl() override { + reader_->Start(); + ReloadBuffer(); + } + void ReloadBuffer() { buffer_.clear(); buffer_.reserve(buffer_size_); diff --git a/paddle/fluid/operators/reader/create_threaded_reader_op.cc b/paddle/fluid/operators/reader/create_threaded_reader_op.cc index 3798015146f4ffb085aa82e23ca3f1fb3c5cf5a4..76b853527cd99ee3f4f8e434c58ce23bdc0e6eca 100644 --- a/paddle/fluid/operators/reader/create_threaded_reader_op.cc +++ b/paddle/fluid/operators/reader/create_threaded_reader_op.cc @@ -22,16 +22,26 @@ namespace reader { class ThreadedReader : public framework::DecoratedReader { public: explicit ThreadedReader(const std::shared_ptr& reader) - : DecoratedReader(reader) {} + : DecoratedReader(reader) { + Start(); + } - void ReadNext(std::vector* out) override { + void ReadNextImpl(std::vector* out) override { std::lock_guard lock(mutex_); reader_->ReadNext(out); } - void ReInit() override { reader_->ReInit(); } - private: + void ShutdownImpl() override { + std::lock_guard lock(mutex_); + reader_->Shutdown(); + } + + void StartImpl() override { + std::lock_guard lock(mutex_); + reader_->Start(); + } + std::mutex mutex_; }; @@ -62,9 +72,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { This operator creates a threaded reader. A threaded reader's 'ReadNext()' can be invoked by several threads at the same time. - When the attribute 'safe_mode' is true, the threaded reader's - 'ReInit()' is disabled to avoid unexpected bugs in multi-thread - environment. )DOC"); } }; diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 6e34a5bd008a0d24df25a93a49a965c31cbb758c..85127d93b20a985fe7daba3f3f74e45665bc2990 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -31,17 +31,16 @@ class MultiFileReader : public framework::ReaderBase { readers_.emplace_back(CreateReaderByFileName(f_name)); } prefetchers_.resize(thread_num); - StartNewScheduler(); + Start(); } - void ReadNext(std::vector* out) override; - void ReInit() override; + void ReadNextImpl(std::vector* out) override; - ~MultiFileReader() { EndScheduler(); } + ~MultiFileReader() { Shutdown(); } private: - void StartNewScheduler(); - void EndScheduler(); + void StartImpl() override; + void ShutdownImpl() override; void ScheduleThreadFunc(); void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx); @@ -54,18 +53,13 @@ class MultiFileReader : public framework::ReaderBase { reader::BlockingQueue>* buffer_; }; -void MultiFileReader::ReadNext(std::vector* out) { +void MultiFileReader::ReadNextImpl(std::vector* out) { if (!buffer_->Receive(out)) { out->clear(); } } -void MultiFileReader::ReInit() { - EndScheduler(); - StartNewScheduler(); -} - -void MultiFileReader::StartNewScheduler() { +void MultiFileReader::StartImpl() { size_t thread_num = prefetchers_.size(); waiting_reader_idx_ = new reader::BlockingQueue(readers_.size()); available_thread_idx_ = new reader::BlockingQueue(thread_num); @@ -83,7 +77,7 @@ void MultiFileReader::StartNewScheduler() { scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); } -void MultiFileReader::EndScheduler() { +void MultiFileReader::ShutdownImpl() { available_thread_idx_->Close(); buffer_->Close(); waiting_reader_idx_->Close(); @@ -119,7 +113,7 @@ void MultiFileReader::ScheduleThreadFunc() { } } } - // If users invoke ReInit() when scheduler is running, it will close the + // If users invoke Shutdown() when scheduler is running, it will close the // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler // to release their resource. So a check is needed before scheduler ends. for (auto& p : prefetchers_) { @@ -137,7 +131,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) { std::vector ins; reader->ReadNext(&ins); if (ins.empty()) { - reader->ReInit(); + reader->Shutdown(); + reader->Start(); break; } try { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 7a8bb712452538b7e2a349d56a15de3284f82b39..0c523b6f176345c0407b8541c04fb8c3b27f7c60 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle. py::return_value_policy::reference); py::class_(m, "Reader", "") - .def("reset", &framework::ReaderHolder::ReInit); + .def("reset", &framework::ReaderHolder::ResetAll); using LoDTensorBlockingQueue = ::paddle::operators::reader::LoDTensorBlockingQueue;