From 4051fb36b55357fb4c5587aa9436651e4db34db8 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 21 Oct 2018 21:54:47 +0800 Subject: [PATCH] add monitor thread --- paddle/fluid/operators/reader/ctr_reader.cc | 20 +++++++++++++++++++ paddle/fluid/operators/reader/ctr_reader.h | 19 +++++++++++++++++- .../fluid/operators/reader/ctr_reader_test.cc | 9 ++++++++- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index 0002e80a306..3156070e2c4 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -123,6 +123,26 @@ class MultiGzipReader : public Reader { size_t current_reader_index_ = 0; }; +void MonitorThread(std::vector* thread_status, + std::shared_ptr queue) { + VLOG(3) << "monitor thread in"; + bool reader_thread_is_running = true; + while (reader_thread_is_running) { + VLOG(3) << "reader_thread_is_running"; + reader_thread_is_running = false; + for (size_t i = 0; i < (*thread_status).size(); ++i) { + if ((*thread_status)[i] == Running) { + VLOG(3) << "reader is running!"; + reader_thread_is_running = true; + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + VLOG(3) << "all reader thread is stopped, push empty data into queue"; + queue->Push({}); + VLOG(3) << "monitor thread exited"; +} + void ReadThread(const std::vector& file_list, const std::vector& slots, int batch_size, int thread_id, std::vector* thread_status, diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index 244a5e2e775..9b2a11bae12 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -16,6 +16,7 @@ #include +#include // NOLINT #include #include #include @@ -39,6 +40,11 @@ void ReadThread(const std::vector& file_list, int thread_id, std::vector* thread_status, std::shared_ptr queue); +// monitor all running thread, if they are all stopped, +// then push an empty data into LoDTensorBlockingQueue +void MonitorThread(std::vector* thread_status, + std::shared_ptr queue); + class CTRReader : public framework::FileReader { public: explicit CTRReader(const std::shared_ptr& queue, @@ -58,7 +64,7 @@ class CTRReader : public framework::FileReader { } } - ~CTRReader() { Shutdown(); } + ~CTRReader() {} void ReadNext(std::vector* out) override { bool success; @@ -68,12 +74,19 @@ class CTRReader : public framework::FileReader { void Shutdown() override { VLOG(3) << "Shutdown reader"; + if (status_ == ReaderStatus::kStopped) { + return; + } // shutdown should stop all the reader thread for (auto& read_thread : read_threads_) { read_thread->join(); } + monitor_thread_->join(); + read_threads_.clear(); + monitor_thread_.reset(nullptr); queue_->Close(); + status_ = ReaderStatus::kStopped; } void Start() override { @@ -87,6 +100,9 @@ class CTRReader : public framework::FileReader { std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_, thread_id, &read_thread_status_, queue_))); } + monitor_thread_.reset(new std::thread( + std::bind(&MonitorThread, &read_thread_status_, queue_))); + status_ = ReaderStatus::kRunning; } private: @@ -107,6 +123,7 @@ class CTRReader : public framework::FileReader { const std::vector file_list_; std::shared_ptr queue_; std::vector> read_threads_; + std::unique_ptr monitor_thread_; std::vector read_thread_status_; std::vector> file_groups_; }; diff --git a/paddle/fluid/operators/reader/ctr_reader_test.cc b/paddle/fluid/operators/reader/ctr_reader_test.cc index 0b8a053a86d..190182f45c5 100644 --- a/paddle/fluid/operators/reader/ctr_reader_test.cc +++ b/paddle/fluid/operators/reader/ctr_reader_test.cc @@ -107,8 +107,8 @@ TEST(CTR_READER, read_data) { size_t batch_num = std::ceil(static_cast(ctr_data.size()) / batch_size) * thread_num; + std::vector out; for (size_t i = 0; i < batch_num; ++i) { - std::vector out; reader.ReadNext(&out); ASSERT_EQ(out.size(), slots.size() + 1); auto& label_tensor = out.back(); @@ -126,5 +126,12 @@ TEST(CTR_READER, read_data) { tensor_6002.dims()[1] * sizeof(int64_t)), 0); } + reader.ReadNext(&out); + ASSERT_EQ(out.size(), 0); ASSERT_EQ(queue->Size(), 0); + reader.Shutdown(); + + reader.Start(); + reader.Shutdown(); + ASSERT_EQ(queue->Size(), 5); } -- GitLab