ctr_reader.h 6.0 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

Q
Qiao Longfei 已提交
17 18
#include <sys/time.h>

S
fix bug  
sneaxiy 已提交
19
#include <algorithm>
Q
Qiao Longfei 已提交
20
#include <chrono>  // NOLINT
Q
Qiao Longfei 已提交
21 22 23
#include <cstdlib>
#include <fstream>
#include <iostream>
S
sneaxiy 已提交
24
#include <memory>
Q
Qiao Longfei 已提交
25 26 27
#include <sstream>
#include <string>
#include <unordered_map>
Q
Qiao Longfei 已提交
28
#include <vector>
Q
Qiao Longfei 已提交
29

Q
Qiao Longfei 已提交
30
#include "paddle/fluid/framework/reader.h"
Q
Qiao Longfei 已提交
31
#include "paddle/fluid/framework/threadpool.h"
Q
Qiao Longfei 已提交
32 33 34 35 36 37
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"

namespace paddle {
namespace operators {
namespace reader {

Q
Qiao Longfei 已提交
38 39
// Lifecycle state of a single reader thread. CTRReader keeps one entry per
// reader thread in a std::vector<ReaderThreadStatus>; that vector is handed to
// both ReadThread and MonitorThread (declared below).
enum ReaderThreadStatus {
  Running = 0,
  Stopped = 1,
};

Q
Qiao Longfei 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
// Immutable description of the input data a CTRReader consumes: batch size,
// the list of input files, their compression/format, and the slot layout.
// All fields are const and fixed at construction.
struct DataDesc {
  // batch:        number of samples per batch.
  // files:        input file paths.
  // type:         file compression type ("gzip" or "plain").
  // format:       record format ("csv" or "svm").
  // dense_index:  dense slot column indices (csv format).
  // sparse_index: sparse slot column indices (csv format).
  // sparse_ids:   sparse slot ids (svm format).
  DataDesc(int batch, const std::vector<std::string>& files,
           const std::string& type, const std::string& format,
           const std::vector<int>& dense_index,
           const std::vector<int>& sparse_index,
           const std::vector<std::string>& sparse_ids)
      : batch_size_(batch),
        file_names_(files),
        file_type_(type),
        file_format_(format),
        dense_slot_index_(dense_index),
        sparse_slot_index_(sparse_index),
        sparse_slot_ids_(sparse_ids) {}

  const int batch_size_;
  const std::vector<std::string> file_names_;
  const std::string file_type_;    // gzip or plain
  const std::string file_format_;  // csv or svm
  // used for csv data format
  const std::vector<int> dense_slot_index_;
  const std::vector<int> sparse_slot_index_;
  // used for svm data format
  const std::vector<std::string> sparse_slot_ids_;
};

// Writes each element of `items` to `os` followed by a comma, so a non-empty
// list prints with a trailing "," (e.g. "1,2,") and an empty one prints nothing.
template <typename T>
void WriteCommaTerminated(std::ostream& os, const std::vector<T>& items) {
  for (typename std::vector<T>::const_iterator it = items.begin();
       it != items.end(); ++it) {
    os << *it << ",";
  }
}

// Human-readable multi-line dump of a DataDesc, intended for logging.
inline std::ostream& operator<<(std::ostream& os, const DataDesc& data_desc) {
  os << "data_desc:\n"
     << "\tbatch_size -> " << data_desc.batch_size_ << "\n"
     << "\tfile_type -> " << data_desc.file_type_ << "\n"
     << "\tfile_format -> " << data_desc.file_format_ << "\n";

  os << "\tfile_names -> {";
  WriteCommaTerminated(os, data_desc.file_names_);
  os << "}\n";

  os << "\tdense_slot_index -> {";
  WriteCommaTerminated(os, data_desc.dense_slot_index_);
  os << "}\n";

  os << "\tsparse_slot_index_ -> {";
  WriteCommaTerminated(os, data_desc.sparse_slot_index_);
  os << "}\n";

  os << "\tsparse_slot_ids_ -> {";
  WriteCommaTerminated(os, data_desc.sparse_slot_ids_);
  os << "}\n";

  return os;
}

Q
Qiao Longfei 已提交
94
void ReadThread(const std::vector<std::string>& file_list,
Q
Qiao Longfei 已提交
95 96
                const DataDesc& data_desc, int thread_id,
                std::vector<ReaderThreadStatus>* thread_status,
Q
Qiao Longfei 已提交
97 98
                std::shared_ptr<LoDTensorBlockingQueue> queue);

Q
Qiao Longfei 已提交
99 100 101 102 103
// Entry point of the supervisor thread (declared here, defined out of line).
// Watches every entry of *thread_status; once all reader threads are Stopped,
// it signals end-of-data by pushing an empty data into the
// LoDTensorBlockingQueue `queue`.
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
                   std::shared_ptr<LoDTensorBlockingQueue> queue);

Q
Qiao Longfei 已提交
104 105
class CTRReader : public framework::FileReader {
 public:
Q
Qiao Longfei 已提交
106 107 108
  CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
            int thread_num, const DataDesc& data_desc)
      : data_desc_(data_desc) {
Q
Qiao Longfei 已提交
109
    PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
Q
Qiao Longfei 已提交
110
    PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
Q
Qiao Longfei 已提交
111 112 113 114
    PADDLE_ENFORCE_GT(data_desc_.file_names_.size(), 0,
                      "file list should not be empty");

    thread_num_ = std::min<size_t>(data_desc_.file_names_.size(), thread_num);
Q
Qiao Longfei 已提交
115
    queue_ = queue;
Q
Qiao Longfei 已提交
116
    SplitFiles();
Q
Qiao Longfei 已提交
117
    for (size_t i = 0; i < thread_num_; ++i) {
Q
Qiao Longfei 已提交
118 119
      read_thread_status_.push_back(Stopped);
    }
Q
Qiao Longfei 已提交
120 121
  }

Q
Qiao Longfei 已提交
122
  ~CTRReader() { Shutdown(); }
Q
Qiao Longfei 已提交
123

Q
Qiao Longfei 已提交
124 125 126 127 128 129
  void ReadNext(std::vector<framework::LoDTensor>* out) override {
    bool success;
    *out = queue_->Pop(&success);
    if (!success) out->clear();
  }

Q
Qiao Longfei 已提交
130 131
  void Shutdown() override {
    VLOG(3) << "Shutdown reader";
Q
Qiao Longfei 已提交
132 133 134
    if (status_ == ReaderStatus::kStopped) {
      return;
    }
Q
Qiao Longfei 已提交
135
    // shutdown should stop all the reader thread
Q
Qiao Longfei 已提交
136 137 138
    for (auto& read_thread : read_threads_) {
      read_thread->join();
    }
Q
Qiao Longfei 已提交
139 140 141 142

    if (monitor_thread_) {
      monitor_thread_->join();
    }
Q
Qiao Longfei 已提交
143

Q
Qiao Longfei 已提交
144
    read_threads_.clear();
Q
Qiao Longfei 已提交
145
    monitor_thread_.reset(nullptr);
Q
Qiao Longfei 已提交
146
    queue_->Close();
Q
Qiao Longfei 已提交
147
    status_ = ReaderStatus::kStopped;
Q
Qiao Longfei 已提交
148
  }
Q
Qiao Longfei 已提交
149

Q
Qiao Longfei 已提交
150
  void Start() override {
Q
Qiao Longfei 已提交
151
    VLOG(3) << "Start reader";
Q
Qiao Longfei 已提交
152
    PADDLE_ENFORCE_EQ(read_threads_.size(), 0, "read thread should be empty!");
Q
Qiao Longfei 已提交
153
    queue_->ReOpen();
Q
Qiao Longfei 已提交
154 155
    VLOG(3) << "reopen success";
    VLOG(3) << "thread_num " << thread_num_;
S
sneaxiy 已提交
156
    for (size_t thread_id = 0; thread_id < thread_num_; thread_id++) {
Q
Qiao Longfei 已提交
157 158 159
      read_threads_.emplace_back(new std::thread(std::bind(
          &ReadThread, file_groups_[thread_id], data_desc_,
          static_cast<int>(thread_id), &read_thread_status_, queue_)));
Q
Qiao Longfei 已提交
160
    }
Q
Qiao Longfei 已提交
161 162 163
    monitor_thread_.reset(new std::thread(
        std::bind(&MonitorThread, &read_thread_status_, queue_)));
    status_ = ReaderStatus::kRunning;
Q
Qiao Longfei 已提交
164 165 166
  }

 private:
Q
Qiao Longfei 已提交
167
  void SplitFiles() {
Q
Qiao Longfei 已提交
168
    file_groups_.resize(thread_num_);
Q
Qiao Longfei 已提交
169 170
    for (size_t i = 0; i < data_desc_.file_names_.size(); ++i) {
      auto& file_name = data_desc_.file_names_[i];
Q
Qiao Longfei 已提交
171 172 173
      std::ifstream f(file_name.c_str());
      PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name);
      file_groups_[i % thread_num_].push_back(file_name);
Q
Qiao Longfei 已提交
174 175
    }
  }
Q
Qiao Longfei 已提交
176 177

 private:
Q
Qiao Longfei 已提交
178
  size_t thread_num_;
Q
Qiao Longfei 已提交
179
  const DataDesc data_desc_;
Q
Qiao Longfei 已提交
180
  std::shared_ptr<LoDTensorBlockingQueue> queue_;
Q
Qiao Longfei 已提交
181
  std::vector<std::unique_ptr<std::thread>> read_threads_;
Q
Qiao Longfei 已提交
182
  std::unique_ptr<std::thread> monitor_thread_;
Q
Qiao Longfei 已提交
183
  std::vector<ReaderThreadStatus> read_thread_status_;
Q
Qiao Longfei 已提交
184
  std::vector<std::vector<std::string>> file_groups_;
Q
Qiao Longfei 已提交
185 186 187 188 189
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle