提交 4051fb36 编写于 作者: Q Qiao Longfei

add monitor thread

上级 e6778337
...@@ -123,6 +123,26 @@ class MultiGzipReader : public Reader { ...@@ -123,6 +123,26 @@ class MultiGzipReader : public Reader {
size_t current_reader_index_ = 0; size_t current_reader_index_ = 0;
}; };
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "monitor thread in";
bool reader_thread_is_running = true;
while (reader_thread_is_running) {
VLOG(3) << "reader_thread_is_running";
reader_thread_is_running = false;
for (size_t i = 0; i < (*thread_status).size(); ++i) {
if ((*thread_status)[i] == Running) {
VLOG(3) << "reader is running!";
reader_thread_is_running = true;
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "all reader thread is stopped, push empty data into queue";
queue->Push({});
VLOG(3) << "monitor thread exited";
}
void ReadThread(const std::vector<std::string>& file_list, void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size, const std::vector<std::string>& slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status, int thread_id, std::vector<ReaderThreadStatus>* thread_status,
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <sys/time.h> #include <sys/time.h>
#include <chrono> // NOLINT
#include <cstdlib> #include <cstdlib>
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
...@@ -39,6 +40,11 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -39,6 +40,11 @@ void ReadThread(const std::vector<std::string>& file_list,
int thread_id, std::vector<ReaderThreadStatus>* thread_status, int thread_id, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue); std::shared_ptr<LoDTensorBlockingQueue> queue);
// monitor all running thread, if they are all stopped,
// then push an empty data into LoDTensorBlockingQueue
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue);
class CTRReader : public framework::FileReader { class CTRReader : public framework::FileReader {
public: public:
explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue, explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
...@@ -58,7 +64,7 @@ class CTRReader : public framework::FileReader { ...@@ -58,7 +64,7 @@ class CTRReader : public framework::FileReader {
} }
} }
~CTRReader() { Shutdown(); } ~CTRReader() {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success; bool success;
...@@ -68,12 +74,19 @@ class CTRReader : public framework::FileReader { ...@@ -68,12 +74,19 @@ class CTRReader : public framework::FileReader {
void Shutdown() override { void Shutdown() override {
VLOG(3) << "Shutdown reader"; VLOG(3) << "Shutdown reader";
if (status_ == ReaderStatus::kStopped) {
return;
}
// shutdown should stop all the reader thread // shutdown should stop all the reader thread
for (auto& read_thread : read_threads_) { for (auto& read_thread : read_threads_) {
read_thread->join(); read_thread->join();
} }
monitor_thread_->join();
read_threads_.clear(); read_threads_.clear();
monitor_thread_.reset(nullptr);
queue_->Close(); queue_->Close();
status_ = ReaderStatus::kStopped;
} }
void Start() override { void Start() override {
...@@ -87,6 +100,9 @@ class CTRReader : public framework::FileReader { ...@@ -87,6 +100,9 @@ class CTRReader : public framework::FileReader {
std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_, std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_,
thread_id, &read_thread_status_, queue_))); thread_id, &read_thread_status_, queue_)));
} }
monitor_thread_.reset(new std::thread(
std::bind(&MonitorThread, &read_thread_status_, queue_)));
status_ = ReaderStatus::kRunning;
} }
private: private:
...@@ -107,6 +123,7 @@ class CTRReader : public framework::FileReader { ...@@ -107,6 +123,7 @@ class CTRReader : public framework::FileReader {
const std::vector<std::string> file_list_; const std::vector<std::string> file_list_;
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
std::vector<std::unique_ptr<std::thread>> read_threads_; std::vector<std::unique_ptr<std::thread>> read_threads_;
std::unique_ptr<std::thread> monitor_thread_;
std::vector<ReaderThreadStatus> read_thread_status_; std::vector<ReaderThreadStatus> read_thread_status_;
std::vector<std::vector<std::string>> file_groups_; std::vector<std::vector<std::string>> file_groups_;
}; };
......
...@@ -107,8 +107,8 @@ TEST(CTR_READER, read_data) { ...@@ -107,8 +107,8 @@ TEST(CTR_READER, read_data) {
size_t batch_num = size_t batch_num =
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num; std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
for (size_t i = 0; i < batch_num; ++i) {
std::vector<LoDTensor> out; std::vector<LoDTensor> out;
for (size_t i = 0; i < batch_num; ++i) {
reader.ReadNext(&out); reader.ReadNext(&out);
ASSERT_EQ(out.size(), slots.size() + 1); ASSERT_EQ(out.size(), slots.size() + 1);
auto& label_tensor = out.back(); auto& label_tensor = out.back();
...@@ -126,5 +126,12 @@ TEST(CTR_READER, read_data) { ...@@ -126,5 +126,12 @@ TEST(CTR_READER, read_data) {
tensor_6002.dims()[1] * sizeof(int64_t)), tensor_6002.dims()[1] * sizeof(int64_t)),
0); 0);
} }
reader.ReadNext(&out);
ASSERT_EQ(out.size(), 0);
ASSERT_EQ(queue->Size(), 0); ASSERT_EQ(queue->Size(), 0);
reader.Shutdown();
reader.Start();
reader.Shutdown();
ASSERT_EQ(queue->Size(), 5);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册