From b4f0e579564d246b43388b872c34e7ef9baccfeb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 8 Jul 2018 22:30:33 +0800 Subject: [PATCH] fix errors --- paddle/fluid/framework/reader.h | 3 ++- .../operators/reader/create_batch_reader_op.cc | 1 - .../operators/reader/create_custom_reader_op.cc | 4 +--- .../reader/create_double_buffer_reader_op.cc | 4 ++-- .../reader/create_multi_pass_reader_op.cc | 4 +--- .../reader/create_random_data_generator_op.cc | 1 - .../reader/create_recordio_file_reader_op.cc | 1 - .../operators/reader/create_shuffle_reader_op.cc | 2 +- .../reader/create_threaded_reader_op.cc | 4 +--- paddle/fluid/operators/reader/open_files_op.cc | 16 ++++++++++------ 10 files changed, 18 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 6b62d11802b..91108544ac1 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include @@ -43,7 +44,7 @@ class ReaderBase { virtual void StartImpl() = 0; - std::atomic status_{kStopped}; + std::atomic status_{kRunning}; }; class DecoratedReader : public ReaderBase { diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index 4d16b82e677..e5b69dabcbd 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -26,7 +26,6 @@ class BatchReader : public framework::DecoratedReader { batch_size_(batch_size), discard_leftover_(discard_leftover) { buffer_.reserve(batch_size_); - Start(); } void ReadNextImpl(std::vector* out) override; diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index 334ba4cf6e8..a53dceced32 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -31,9 +31,7 @@ class CustomReader : public framework::DecoratedReader { sub_block_id_(sub_block.ID()), exe_(framework::Executor(platform::CPUPlace())), source_var_names_(source_var_names), - sink_var_names_(sink_var_names) { - Start(); - } + sink_var_names_(sink_var_names) {} void ReadNextImpl(std::vector* out) override; diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index efca6fe225b..1bf6a86a5ad 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -47,12 +47,12 @@ class DoubleBufferReader : public framework::DecoratedReader { } } #endif - Start(); + StartPrefetcher(); } void ReadNextImpl(std::vector* out) override; - ~DoubleBufferReader() { Shutdown(); } + ~DoubleBufferReader() { EndPrefetcher(); } private: void ShutdownImpl() override { diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc index 82331cb2726..f26a470c258 100644 --- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -22,9 +22,7 @@ namespace reader { class MultiPassReader : public framework::DecoratedReader { public: MultiPassReader(const std::shared_ptr& reader, int pass_num) - : DecoratedReader(reader), pass_num_(pass_num) { - Start(); - } + : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} void ReadNextImpl(std::vector* out) override { reader_->ReadNext(out); diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index c92a8b49b58..9f7e3fd2d87 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -30,7 +30,6 @@ class RandomDataGenerator : public framework::FileReader { unsigned int seed = std::random_device()(); engine_.seed(seed); dist_ = std::uniform_real_distribution(low_, high_); - Start(); } void ReadNextImpl(std::vector* out) override { diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 7a44bd14eb2..66f209b04e5 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -30,7 +30,6 @@ class RecordIOFileReader : public framework::FileReader { mutex_.reset(new std::mutex()); } LOG(INFO) << "Creating file reader" << filename; - Start(); } protected: diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 3cee9bfd643..1d3d85b9e4e 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -31,7 +31,7 @@ class ShuffleReader : public framework::DecoratedReader { std::random_device device; seed_ = device(); } - Start(); + ReloadBuffer(); } void ReadNextImpl(std::vector* out) override { diff --git a/paddle/fluid/operators/reader/create_threaded_reader_op.cc b/paddle/fluid/operators/reader/create_threaded_reader_op.cc index 76b853527cd..88a2bcab8df 100644 --- a/paddle/fluid/operators/reader/create_threaded_reader_op.cc +++ b/paddle/fluid/operators/reader/create_threaded_reader_op.cc @@ -22,9 +22,7 @@ namespace reader { class ThreadedReader : public framework::DecoratedReader { public: explicit ThreadedReader(const std::shared_ptr& reader) - : DecoratedReader(reader) { - Start(); - } + : DecoratedReader(reader) {} void ReadNextImpl(std::vector* out) override { std::lock_guard lock(mutex_); diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 85127d93b20..c657ffc5359 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -31,16 +31,20 @@ class MultiFileReader : public framework::ReaderBase { readers_.emplace_back(CreateReaderByFileName(f_name)); } prefetchers_.resize(thread_num); - Start(); + StartNewScheduler(); } void ReadNextImpl(std::vector* out) override; - ~MultiFileReader() { Shutdown(); } + ~MultiFileReader() { EndScheduler(); } private: - void StartImpl() override; - void ShutdownImpl() override; + void ShutdownImpl() override { EndScheduler(); } + + void StartImpl() override { StartNewScheduler(); } + + void StartNewScheduler(); + void EndScheduler(); void ScheduleThreadFunc(); void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx); @@ -59,7 +63,7 @@ void MultiFileReader::ReadNextImpl(std::vector* out) { } } -void MultiFileReader::StartImpl() { +void MultiFileReader::StartNewScheduler() { size_t thread_num = prefetchers_.size(); waiting_reader_idx_ = new reader::BlockingQueue(readers_.size()); available_thread_idx_ = new reader::BlockingQueue(thread_num); @@ -77,7 +81,7 @@ void MultiFileReader::StartImpl() { scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); } -void MultiFileReader::ShutdownImpl() { +void MultiFileReader::EndScheduler() { available_thread_idx_->Close(); buffer_->Close(); waiting_reader_idx_->Close(); -- GitLab