提交 4051fb36 编写于 作者: Q Qiao Longfei

add monitor thread

上级 e6778337
......@@ -123,6 +123,26 @@ class MultiGzipReader : public Reader {
size_t current_reader_index_ = 0;
};
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "monitor thread in";
bool reader_thread_is_running = true;
while (reader_thread_is_running) {
VLOG(3) << "reader_thread_is_running";
reader_thread_is_running = false;
for (size_t i = 0; i < (*thread_status).size(); ++i) {
if ((*thread_status)[i] == Running) {
VLOG(3) << "reader is running!";
reader_thread_is_running = true;
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "all reader thread is stopped, push empty data into queue";
queue->Push({});
VLOG(3) << "monitor thread exited";
}
void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
......
......@@ -16,6 +16,7 @@
#include <sys/time.h>
#include <chrono> // NOLINT
#include <cstdlib>
#include <fstream>
#include <iostream>
......@@ -39,6 +40,11 @@ void ReadThread(const std::vector<std::string>& file_list,
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue);
// monitor all running thread, if they are all stopped,
// then push an empty data into LoDTensorBlockingQueue
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue);
class CTRReader : public framework::FileReader {
public:
explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
......@@ -58,7 +64,7 @@ class CTRReader : public framework::FileReader {
}
}
~CTRReader() { Shutdown(); }
~CTRReader() {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success;
......@@ -68,12 +74,19 @@ class CTRReader : public framework::FileReader {
void Shutdown() override {
VLOG(3) << "Shutdown reader";
if (status_ == ReaderStatus::kStopped) {
return;
}
// shutdown should stop all the reader thread
for (auto& read_thread : read_threads_) {
read_thread->join();
}
monitor_thread_->join();
read_threads_.clear();
monitor_thread_.reset(nullptr);
queue_->Close();
status_ = ReaderStatus::kStopped;
}
void Start() override {
......@@ -87,6 +100,9 @@ class CTRReader : public framework::FileReader {
std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_,
thread_id, &read_thread_status_, queue_)));
}
monitor_thread_.reset(new std::thread(
std::bind(&MonitorThread, &read_thread_status_, queue_)));
status_ = ReaderStatus::kRunning;
}
private:
......@@ -107,6 +123,7 @@ class CTRReader : public framework::FileReader {
const std::vector<std::string> file_list_;
std::shared_ptr<LoDTensorBlockingQueue> queue_;
std::vector<std::unique_ptr<std::thread>> read_threads_;
std::unique_ptr<std::thread> monitor_thread_;
std::vector<ReaderThreadStatus> read_thread_status_;
std::vector<std::vector<std::string>> file_groups_;
};
......
......@@ -107,8 +107,8 @@ TEST(CTR_READER, read_data) {
size_t batch_num =
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
for (size_t i = 0; i < batch_num; ++i) {
std::vector<LoDTensor> out;
for (size_t i = 0; i < batch_num; ++i) {
reader.ReadNext(&out);
ASSERT_EQ(out.size(), slots.size() + 1);
auto& label_tensor = out.back();
......@@ -126,5 +126,12 @@ TEST(CTR_READER, read_data) {
tensor_6002.dims()[1] * sizeof(int64_t)),
0);
}
reader.ReadNext(&out);
ASSERT_EQ(out.size(), 0);
ASSERT_EQ(queue->Size(), 0);
reader.Shutdown();
reader.Start();
reader.Shutdown();
ASSERT_EQ(queue->Size(), 5);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册