diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index 9849eb6aef5f03dffe57d946f786f52b64a99927..60e8d1250df6680d0028746a873a209dd317a0c4 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -124,7 +124,10 @@ class MultiGzipReader : public Reader { void ReadThread(const std::vector& file_list, const std::vector& slots, int batch_size, + int thread_id, std::vector* thread_status, std::shared_ptr queue) { + (*thread_status)[thread_id] = Running; + std::string line; std::vector>> batch_data; @@ -181,6 +184,8 @@ void ReadThread(const std::vector& file_list, queue->Push(lod_datas); } + + (*thread_status)[thread_id] = Stopped; } } // namespace reader diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index ef319c86326d747701df12275a19774cb0e768bf..1006ea96c9e3e3fe59807e8fd9b995724741395d 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -30,8 +30,11 @@ namespace paddle { namespace operators { namespace reader { +enum ReaderThreadStatus { Running, Stopped }; + void ReadThread(const std::vector& file_list, const std::vector& slots, int batch_size, + int thread_id, std::vector* thread_status, std::shared_ptr queue); class CTRReader : public framework::FileReader { @@ -40,13 +43,16 @@ class CTRReader : public framework::FileReader { int batch_size, int thread_num, const std::vector& slots, const std::vector& file_list) - : thread_num_(thread_num), - batch_size_(batch_size), - slots_(slots), - file_list_(file_list) { + : batch_size_(batch_size), slots_(slots), file_list_(file_list) { PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); + PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); + thread_num_ = + file_list_.size() > thread_num_ ? thread_num_ : file_list_.size(); queue_ = queue; SplitFiles(); + for (int i = 0; i < thread_num; ++i) { + read_thread_status_.push_back(Stopped); + } } ~CTRReader() { queue_->Close(); } @@ -69,28 +75,29 @@ class CTRReader : public framework::FileReader { void Start() override { VLOG(3) << "Start reader"; queue_->ReOpen(); - for (int i = 0; i < file_groups_.size(); i++) { - read_threads_.emplace_back(new std::thread(std::bind( - &ReadThread, file_groups_[i], slots_, batch_size_, queue_))); + for (int thread_id = 0; thread_id < file_groups_.size(); thread_id++) { + read_threads_.emplace_back(new std::thread( + std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_, + thread_id, &read_thread_status_, queue_))); } } private: void SplitFiles() { - file_groups_.resize(file_list_.size() > thread_num_ ? thread_num_ - : file_list_.size()); + file_groups_.resize(thread_num_); for (int i = 0; i < file_list_.size(); ++i) { file_groups_[i % thread_num_].push_back(file_list_[i]); } } private: - const int thread_num_; + int thread_num_; const int batch_size_; const std::vector slots_; const std::vector file_list_; std::shared_ptr queue_; std::vector> read_threads_; + std::vector read_thread_status_; std::vector> file_groups_; };