提交 20f181cd 编写于 作者: Q Qiao Longfei

init ctr_reader

上级 d26e4507
......@@ -16,7 +16,7 @@ function(reader_library TARGET_NAME)
endfunction()
cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
cc_library(ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool)
cc_library(ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool boost)
reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader)
reader_library(create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
......
......@@ -41,7 +41,13 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue()));
int thread_num = Attr<int>("thread_num");
std::vector<std::string> slots = Attr<std::vector<std::string>>("slots");
int batch_size = Attr<int>("batch_size");
std::vector<std::string> file_list =
Attr<std::vector<std::string>>("file_list");
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), batch_size,
thread_num, slots, file_list));
}
};
......@@ -50,6 +56,12 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
void Apply() override {
AddInput("blocking_queue",
"Name of the `LoDTensorBlockingQueueHolder` variable");
AddAttr<int>("thread_num", "the thread num to read data");
AddAttr<int>("batch_size", "the batch size of read data");
AddAttr<std::vector<std::string>>("file_list",
"The list of files that need to read");
AddAttr<std::vector<std::string>>(
"slots", "the slots that should be extract from file");
AddComment(R"DOC(
Create CTRReader to support read ctr data with cpp.
......
......@@ -14,8 +14,137 @@
#include "paddle/fluid/operators/reader/ctr_reader.h"
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <algorithm>
#include <random>
#include <boost/iostreams/copy.hpp>
#include <boost/iostreams/filter/gzip.hpp>
#include <boost/iostreams/filtering_streambuf.hpp>
namespace paddle {
namespace operators {
namespace reader {} // namespace reader
namespace reader {
static inline void string_split(const std::string& s, const char delimiter,
std::vector<std::string>* output) {
size_t start = 0;
size_t end = s.find_first_of(delimiter);
while (end <= std::string::npos) {
output->emplace_back(s.substr(start, end - start));
if (end == std::string::npos) {
break;
}
start = end + 1;
end = s.find_first_of(delimiter, start);
}
}
static inline void parse_line(
const std::string& line, const std::vector<std::string>& slots,
int64_t* label,
std::unordered_map<std::string, std::vector<int64_t>>* slots_to_data) {
std::vector<std::string> ret;
string_split(line, ' ', &ret);
*label = std::stoi(ret[2]) > 0;
for (size_t i = 3; i < ret.size(); ++i) {
const std::string& item = ret[i];
std::vector<std::string> slot_and_feasign;
string_split(item, ':', &slot_and_feasign);
if (slot_and_feasign.size() == 2) {
const std::string& slot = slot_and_feasign[1];
int64_t feasign = std::strtoll(slot_and_feasign[0].c_str(), NULL, 10);
(*slots_to_data)[slot_and_feasign[1]].push_back(feasign);
}
}
}
// class Reader {
// public:
// virtual ~Reader() {}
// virtual bool HasNext() = 0;
// virtual void NextLine(std::string& line) = 0;
//};
class GzipReader {
public:
explicit GzipReader(const std::string& file_name) : instream_(&inbuf_) {
file_ = std::ifstream(file_name, std::ios_base::in | std::ios_base::binary);
inbuf_.push(boost::iostreams::gzip_decompressor());
inbuf_.push(file_);
// Convert streambuf to istream
}
~GzipReader() { file_.close(); }
bool HasNext() { return instream_.peek() != EOF; }
void NextLine(std::string& line) { std::getline(instream_, line); } // NOLINT
private:
boost::iostreams::filtering_streambuf<boost::iostreams::input> inbuf_;
std::ifstream file_;
std::istream instream_;
};
class MultiGzipReader {
public:
explicit MultiGzipReader(const std::vector<std::string>& file_list) {
for (auto& file : file_list) {
readers_.emplace_back(std::make_shared<GzipReader>(file));
}
}
bool HasNext() {
if (current_reader_index_ >= readers_.size()) {
return false;
}
if (!readers_[current_reader_index_]->HasNext()) {
current_reader_index_++;
return HasNext();
}
return true;
}
void NextLine(std::string& line) { // NOLINT
readers_[current_reader_index_]->NextLine(line);
}
private:
std::vector<std::shared_ptr<GzipReader>> readers_;
size_t current_reader_index_ = 0;
};
// void CTRReader::ReadThread(
// const std::vector<std::string> &file_list,
// const std::vector<std::string>& slots,
// int batch_size,
// std::shared_ptr<LoDTensorBlockingQueue>& queue) {}
void CTRReader::ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots,
int batch_size,
std::shared_ptr<LoDTensorBlockingQueue>* queue) {
std::string line;
// read all files
std::vector<std::string> all_lines;
MultiGzipReader reader(file_list);
for (int j = 0; j < all_lines.size(); ++j) {
std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
int64_t label;
parse_line(all_lines[j], slots, &label, &slots_to_data);
}
}
} // namespace reader
} // namespace operators
} // namespace paddle
......@@ -14,8 +14,20 @@
#pragma once
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <boost/iostreams/copy.hpp>
#include <boost/iostreams/filter/gzip.hpp>
#include <boost/iostreams/filtering_streambuf.hpp>
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
namespace paddle {
......@@ -24,26 +36,50 @@ namespace reader {
class CTRReader : public framework::FileReader {
public:
explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
int batch_size, int thread_num,
const std::vector<std::string>& slots,
const std::vector<std::string>& file_list)
: framework::FileReader() {
thread_num_ = thread_num;
batch_size_ = batch_size;
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
queue_ = queue;
slots_ = slots;
file_list_ = file_list;
}
~CTRReader() { queue_->Close(); }
void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success;
*out = queue_->Pop(&success);
if (!success) out->clear();
}
~CTRReader() { queue_->Close(); }
void Shutdown() override { queue_->Close(); }
void Start() override { queue_->ReOpen(); }
void Start() override {
queue_->ReOpen();
for (int i = 0; i < thread_num_; i++) {
read_threads_.emplace_back(
new std::thread(std::bind(&CTRReader::ReadThread, this, file_list_,
slots_, batch_size_, queue_)));
}
}
private:
void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size,
std::shared_ptr<LoDTensorBlockingQueue>* queue);
private:
std::shared_ptr<LoDTensorBlockingQueue> queue_;
std::vector<std::unique_ptr<std::thread>> read_threads_;
int thread_num_;
int batch_size_;
std::vector<std::string> slots_;
std::vector<std::string> file_list_;
};
} // namespace reader
......
set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder)
set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder boost)
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc)
if(NOT WIN32)
list(APPEND PYBIND_DEPS parallel_executor profiler)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册