ctr_reader.h 6.1 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

Q
Qiao Longfei 已提交
17 18
#include <sys/time.h>

S
fix bug  
sneaxiy 已提交
19
#include <algorithm>
Q
Qiao Longfei 已提交
20
#include <chrono>  // NOLINT
Q
Qiao Longfei 已提交
21 22 23 24 25 26
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>
Q
Qiao Longfei 已提交
27
#include <vector>
Q
Qiao Longfei 已提交
28

Q
Qiao Longfei 已提交
29
#include "paddle/fluid/framework/reader.h"
Q
Qiao Longfei 已提交
30
#include "paddle/fluid/framework/threadpool.h"
Q
Qiao Longfei 已提交
31 32 33 34 35 36
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"

namespace paddle {
namespace operators {
namespace reader {

Q
Qiao Longfei 已提交
37 38
// State of a single reader thread, one entry per thread in CTRReader's
// read_thread_status_ vector.
enum ReaderThreadStatus {
  Running,  // == 0
  Stopped,  // == 1
};

Q
Qiao Longfei 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
// Immutable description of the data a CTRReader consumes: the batch size, the
// files to read, how those files are typed/formatted, and which slots to
// extract. Every member is const, so a DataDesc can be copied but not
// assigned.
struct DataDesc {
  DataDesc(int batch, const std::vector<std::string>& files,
           const std::string& type, const std::string& format,
           const std::vector<int>& dense_index,
           const std::vector<int>& sparse_index,
           const std::vector<std::string>& sparse_ids)
      : batch_size_(batch),
        file_names_(files),
        file_type_(type),
        file_format_(format),
        dense_slot_index_(dense_index),
        sparse_slot_index_(sparse_index),
        sparse_slot_ids_(sparse_ids) {}

  // Batch size.
  const int batch_size_;
  // Input files; CTRReader distributes these among its reader threads.
  const std::vector<std::string> file_names_;
  // File type: "gzip" or "plain".
  const std::string file_type_;
  // Record format: "csv" or "svm".
  const std::string file_format_;

  // Slot indices, used for the csv data format.
  const std::vector<int> dense_slot_index_;
  const std::vector<int> sparse_slot_index_;

  // Sparse slot ids, used for the svm data format.
  const std::vector<std::string> sparse_slot_ids_;
};

Q
Qiao Longfei 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
// Writes `header`, then every element of `values` followed by a comma
// (including the last one), then "}\n".
template <typename Sequence>
inline void PrintDataDescSequence(std::ostream& os, const char* header,
                                  const Sequence& values) {
  os << header;
  for (const auto& value : values) {
    os << value << ",";
  }
  os << "}\n";
}

// Human-readable, multi-line dump of every DataDesc field (for logging and
// debugging). Returns `os` so calls can be chained.
inline std::ostream& operator<<(std::ostream& os, const DataDesc& data_desc) {
  os << "data_desc:\n"
     << "\tbatch_size -> " << data_desc.batch_size_ << "\n"
     << "\tfile_type -> " << data_desc.file_type_ << "\n"
     << "\tfile_format -> " << data_desc.file_format_ << "\n";
  PrintDataDescSequence(os, "\tfile_names -> {", data_desc.file_names_);
  PrintDataDescSequence(os, "\tdense_slot_index -> {",
                        data_desc.dense_slot_index_);
  PrintDataDescSequence(os, "\tsparse_slot_index_ -> {",
                        data_desc.sparse_slot_index_);
  PrintDataDescSequence(os, "\tsparse_slot_ids_ -> {",
                        data_desc.sparse_slot_ids_);
  return os;
}

Q
Qiao Longfei 已提交
93
void ReadThread(const std::vector<std::string>& file_list,
Q
Qiao Longfei 已提交
94 95
                const DataDesc& data_desc, int thread_id,
                std::vector<ReaderThreadStatus>* thread_status,
Q
Qiao Longfei 已提交
96 97
                std::shared_ptr<LoDTensorBlockingQueue> queue);

Q
Qiao Longfei 已提交
98 99 100 101 102
// Watches the status of every reader thread; once all of them are stopped,
// pushes empty data into the LoDTensorBlockingQueue (presumably so that
// consumers see an empty batch as the end of the data). Defined out of line.
void MonitorThread(
    std::vector<ReaderThreadStatus>* thread_status,
    std::shared_ptr<LoDTensorBlockingQueue> queue);

Q
Qiao Longfei 已提交
103 104
class CTRReader : public framework::FileReader {
 public:
Q
Qiao Longfei 已提交
105
  explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
106
                     int batch_size, size_t thread_num,
Q
Qiao Longfei 已提交
107 108
                     const std::vector<std::string>& slots,
                     const std::vector<std::string>& file_list)
Q
Qiao Longfei 已提交
109
      : batch_size_(batch_size), slots_(slots), file_list_(file_list) {
Q
Qiao Longfei 已提交
110
    PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
Q
Qiao Longfei 已提交
111
    PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
Q
Qiao Longfei 已提交
112
    PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
S
fix bug  
sneaxiy 已提交
113
    thread_num_ = std::min<size_t>(file_list_.size(), thread_num);
Q
Qiao Longfei 已提交
114
    queue_ = queue;
Q
Qiao Longfei 已提交
115
    SplitFiles();
Q
Qiao Longfei 已提交
116
    for (size_t i = 0; i < thread_num_; ++i) {
Q
Qiao Longfei 已提交
117 118
      read_thread_status_.push_back(Stopped);
    }
Q
Qiao Longfei 已提交
119 120
  }

Q
Qiao Longfei 已提交
121
  ~CTRReader() { Shutdown(); }
Q
Qiao Longfei 已提交
122

Q
Qiao Longfei 已提交
123 124 125 126 127 128
  void ReadNext(std::vector<framework::LoDTensor>* out) override {
    bool success;
    *out = queue_->Pop(&success);
    if (!success) out->clear();
  }

Q
Qiao Longfei 已提交
129 130
  void Shutdown() override {
    VLOG(3) << "Shutdown reader";
Q
Qiao Longfei 已提交
131 132 133
    if (status_ == ReaderStatus::kStopped) {
      return;
    }
Q
Qiao Longfei 已提交
134
    // shutdown should stop all the reader thread
Q
Qiao Longfei 已提交
135 136 137
    for (auto& read_thread : read_threads_) {
      read_thread->join();
    }
Q
Qiao Longfei 已提交
138 139 140 141

    if (monitor_thread_) {
      monitor_thread_->join();
    }
Q
Qiao Longfei 已提交
142

Q
Qiao Longfei 已提交
143
    read_threads_.clear();
Q
Qiao Longfei 已提交
144
    monitor_thread_.reset(nullptr);
Q
Qiao Longfei 已提交
145
    queue_->Close();
Q
Qiao Longfei 已提交
146
    status_ = ReaderStatus::kStopped;
Q
Qiao Longfei 已提交
147
  }
Q
Qiao Longfei 已提交
148

Q
Qiao Longfei 已提交
149
  void Start() override {
Q
Qiao Longfei 已提交
150
    VLOG(3) << "Start reader";
Q
Qiao Longfei 已提交
151
    PADDLE_ENFORCE_EQ(read_threads_.size(), 0, "read thread should be empty!");
Q
Qiao Longfei 已提交
152
    queue_->ReOpen();
Q
Qiao Longfei 已提交
153 154 155
    VLOG(3) << "reopen success";
    VLOG(3) << "thread_num " << thread_num_;
    for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
Q
Qiao Longfei 已提交
156 157 158
      read_threads_.emplace_back(new std::thread(std::bind(
          &ReadThread, file_groups_[thread_id], data_desc_,
          static_cast<int>(thread_id), &read_thread_status_, queue_)));
Q
Qiao Longfei 已提交
159
    }
Q
Qiao Longfei 已提交
160 161 162
    monitor_thread_.reset(new std::thread(
        std::bind(&MonitorThread, &read_thread_status_, queue_)));
    status_ = ReaderStatus::kRunning;
Q
Qiao Longfei 已提交
163 164 165
  }

 private:
Q
Qiao Longfei 已提交
166
  void SplitFiles() {
Q
Qiao Longfei 已提交
167
    file_groups_.resize(thread_num_);
Q
Qiao Longfei 已提交
168 169
    for (size_t i = 0; i < data_desc_.file_names_.size(); ++i) {
      auto& file_name = data_desc_.file_names_[i];
Q
Qiao Longfei 已提交
170 171 172
      std::ifstream f(file_name.c_str());
      PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name);
      file_groups_[i % thread_num_].push_back(file_name);
Q
Qiao Longfei 已提交
173 174
    }
  }
Q
Qiao Longfei 已提交
175 176

 private:
Q
Qiao Longfei 已提交
177
  size_t thread_num_;
Q
Qiao Longfei 已提交
178
  const DataDesc data_desc_;
Q
Qiao Longfei 已提交
179
  std::shared_ptr<LoDTensorBlockingQueue> queue_;
Q
Qiao Longfei 已提交
180
  std::vector<std::unique_ptr<std::thread>> read_threads_;
Q
Qiao Longfei 已提交
181
  std::unique_ptr<std::thread> monitor_thread_;
Q
Qiao Longfei 已提交
182
  std::vector<ReaderThreadStatus> read_thread_status_;
Q
Qiao Longfei 已提交
183
  std::vector<std::vector<std::string>> file_groups_;
Q
Qiao Longfei 已提交
184 185 186 187 188
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle