diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index d4f1da69f0a6aee0d915e053ee868f69aecd9348..341aeda4e41a533f517e47ad16a3868714775c3c 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -16,7 +16,7 @@ function(reader_library TARGET_NAME) endfunction() cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool) -cc_library(ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool) +cc_library(ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool boost) reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader) reader_library(create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) diff --git a/paddle/fluid/operators/reader/create_ctr_reader_op.cc b/paddle/fluid/operators/reader/create_ctr_reader_op.cc index e182521f9ab868c4de78d465f92008a84f5403e2..58a465d87a8c0da50e3eb80fefe32d50217f6990 100644 --- a/paddle/fluid/operators/reader/create_ctr_reader_op.cc +++ b/paddle/fluid/operators/reader/create_ctr_reader_op.cc @@ -41,7 +41,13 @@ class CreateCTRReaderOp : public framework::OperatorBase { auto* queue_holder = queue_holder_var->template GetMutable(); - out->Reset(std::make_shared(queue_holder->GetQueue())); + int thread_num = Attr("thread_num"); + std::vector slots = Attr>("slots"); + int batch_size = Attr("batch_size"); + std::vector file_list = + Attr>("file_list"); + out->Reset(std::make_shared(queue_holder->GetQueue(), batch_size, + thread_num, slots, file_list)); } }; @@ -50,6 +56,12 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase { void Apply() override { AddInput("blocking_queue", "Name of the `LoDTensorBlockingQueueHolder` variable"); + AddAttr("thread_num", "the thread num to read data"); + AddAttr("batch_size", "the batch size of read data"); + AddAttr>("file_list", + "The list of files that need to read"); + AddAttr>( + "slots", "the slots that should be extract from file"); AddComment(R"DOC( Create CTRReader to support read ctr data with cpp. diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index bcf49fc967cab6a275e154d1b6f18034734f9200..a4197a54349eeb1582bbba2cf0c65bc8e0d20b9c 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -14,8 +14,137 @@ #include "paddle/fluid/operators/reader/ctr_reader.h" +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + namespace paddle { namespace operators { -namespace reader {} // namespace reader +namespace reader { + +static inline void string_split(const std::string& s, const char delimiter, + std::vector* output) { + size_t start = 0; + size_t end = s.find_first_of(delimiter); + + while (end <= std::string::npos) { + output->emplace_back(s.substr(start, end - start)); + if (end == std::string::npos) { + break; + } + start = end + 1; + end = s.find_first_of(delimiter, start); + } +} + +static inline void parse_line( + const std::string& line, const std::vector& slots, + int64_t* label, + std::unordered_map>* slots_to_data) { + std::vector ret; + string_split(line, ' ', &ret); + *label = std::stoi(ret[2]) > 0; + for (size_t i = 3; i < ret.size(); ++i) { + const std::string& item = ret[i]; + std::vector slot_and_feasign; + string_split(item, ':', &slot_and_feasign); + if (slot_and_feasign.size() == 2) { + const std::string& slot = slot_and_feasign[1]; + int64_t feasign = std::strtoll(slot_and_feasign[0].c_str(), NULL, 10); + (*slots_to_data)[slot_and_feasign[1]].push_back(feasign); + } + } +} + +// class Reader { +// public: +// virtual ~Reader() {} +// virtual bool HasNext() = 0; +// virtual void NextLine(std::string& line) = 0; +//}; + +class GzipReader { + public: + explicit GzipReader(const std::string& file_name) : instream_(&inbuf_) { + file_ = std::ifstream(file_name, std::ios_base::in | std::ios_base::binary); + inbuf_.push(boost::iostreams::gzip_decompressor()); + inbuf_.push(file_); + // Convert streambuf to istream + } + + ~GzipReader() { file_.close(); } + + bool HasNext() { return instream_.peek() != EOF; } + + void NextLine(std::string& line) { std::getline(instream_, line); } // NOLINT + + private: + boost::iostreams::filtering_streambuf inbuf_; + std::ifstream file_; + std::istream instream_; +}; + +class MultiGzipReader { + public: + explicit MultiGzipReader(const std::vector& file_list) { + for (auto& file : file_list) { + readers_.emplace_back(std::make_shared(file)); + } + } + + bool HasNext() { + if (current_reader_index_ >= readers_.size()) { + return false; + } + if (!readers_[current_reader_index_]->HasNext()) { + current_reader_index_++; + return HasNext(); + } + return true; + } + + void NextLine(std::string& line) { // NOLINT + readers_[current_reader_index_]->NextLine(line); + } + + private: + std::vector> readers_; + size_t current_reader_index_ = 0; +}; + +// void CTRReader::ReadThread( +// const std::vector &file_list, +// const std::vector& slots, +// int batch_size, +// std::shared_ptr& queue) {} + +void CTRReader::ReadThread(const std::vector& file_list, + const std::vector& slots, + int batch_size, + std::shared_ptr* queue) { + std::string line; + + // read all files + std::vector all_lines; + MultiGzipReader reader(file_list); + + for (int j = 0; j < all_lines.size(); ++j) { + std::unordered_map> slots_to_data; + int64_t label; + parse_line(all_lines[j], slots, &label, &slots_to_data); + } +} + +} // namespace reader } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index c3cf78e5f43aa99df42de2981a0a5ec5d96e8133..8a25993699213ae0710cff00bd469e5b1725a44f 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -14,8 +14,20 @@ #pragma once +#include +#include +#include +#include +#include +#include #include + +#include +#include +#include + #include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" namespace paddle { @@ -24,26 +36,50 @@ namespace reader { class CTRReader : public framework::FileReader { public: - explicit CTRReader(const std::shared_ptr& queue) + explicit CTRReader(const std::shared_ptr& queue, + int batch_size, int thread_num, + const std::vector& slots, + const std::vector& file_list) : framework::FileReader() { + thread_num_ = thread_num; + batch_size_ = batch_size; PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); queue_ = queue; + slots_ = slots; + file_list_ = file_list; } + ~CTRReader() { queue_->Close(); } + void ReadNext(std::vector* out) override { bool success; *out = queue_->Pop(&success); if (!success) out->clear(); } - ~CTRReader() { queue_->Close(); } - void Shutdown() override { queue_->Close(); } - void Start() override { queue_->ReOpen(); } + void Start() override { + queue_->ReOpen(); + for (int i = 0; i < thread_num_; i++) { + read_threads_.emplace_back( + new std::thread(std::bind(&CTRReader::ReadThread, this, file_list_, + slots_, batch_size_, queue_))); + } + } + + private: + void ReadThread(const std::vector& file_list, + const std::vector& slots, int batch_size, + std::shared_ptr* queue); private: std::shared_ptr queue_; + std::vector> read_threads_; + int thread_num_; + int batch_size_; + std::vector slots_; + std::vector file_list_; }; } // namespace reader diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index e7f634c4a622b48e97040987836406cf73cb23b6..5ef519367427d9eb4398093e933193c6ecff64d6 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,5 +1,5 @@ -set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder) +set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder boost) set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc) if(NOT WIN32) list(APPEND PYBIND_DEPS parallel_executor profiler)