diff --git a/paddle/fluid/operators/reader/create_ctr_reader_op.cc b/paddle/fluid/operators/reader/create_ctr_reader_op.cc index e66263fee11dfcadb2058f868930b9d9706a0939..5b9e2ba693f07ff95598ec9e2c0eae084023c242 100644 --- a/paddle/fluid/operators/reader/create_ctr_reader_op.cc +++ b/paddle/fluid/operators/reader/create_ctr_reader_op.cc @@ -43,14 +43,16 @@ class CreateCTRReaderOp : public framework::OperatorBase { auto thread_num = Attr("thread_num"); auto sparse_slots = Attr>("sparse_slots"); - auto dense_slots = Attr>("dense_slots"); + auto dense_slot_index = Attr>("dense_slot_index"); + auto sparse_slot_index = Attr>("sparse_slot_index"); auto batch_size = Attr("batch_size"); auto file_type = Attr("file_type"); auto file_format = Attr("file_format"); auto file_list = Attr>("file_list"); - out->Reset(std::make_shared( - queue_holder->GetQueue(), batch_size, thread_num, file_type, - file_format, dense_slots, sparse_slots, file_list)); + DataDesc data_desc(batch_size, file_list, file_type, file_format, + dense_slot_index, sparse_slot_index, sparse_slots); + out->Reset(std::make_shared(queue_holder->GetQueue(), thread_num, + data_desc)); } }; @@ -65,11 +67,18 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase { AddAttr("file_format", "svm or csv").SetDefault("csv"); AddAttr>("file_list", "The list of files that need to read"); - AddAttr>( - "dense_slots", "the sparse slots id that should be extract from file") + AddAttr>( + "dense_slot_index", + "the sparse slots id that should be extract from file") .SetDefault({}); - AddAttr>( - "sparse_slots", "the sparse slots id that should be extract from file"); + AddAttr>( + "dense_slot_index", + "the sparse slots id that should be extract from file") + .SetDefault({}); + AddAttr>("sparse_slots", + "the sparse slots id that should be " + "extract from file, used when file " + "format is svm"); AddComment(R"DOC( Create CTRReader to support read ctr data with cpp. diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index 09939576568c42168772ba90883c22a4fb74c035..0af55b503e2baced6ca52489a0e60c4f5e9ba8f0 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -73,6 +73,21 @@ static inline void parse_line( } } +// label slot1:fea_sign slot2:fea_sign slot1:fea_sign +static inline void parse_svm_line(const std::string& line) {} + +// label,dense_fea,dense_fea,sparse_fea,sparse_fea +static inline void parse_csv_line(const std::string& line, + const std::vector& dense_slots, + const std::vector& sparse_slots, + int64_t* label, + std::vector* dense_datas, + std::vector* sparse_datas) { + std::vector ret; + string_split(line, ',', &ret); + *label = std::stoi(ret[2]) > 0; +} + class Reader { public: virtual ~Reader() {} @@ -160,10 +175,8 @@ void MonitorThread(std::vector* thread_status, } void ReadThread(const std::vector& file_list, - const std::string& file_type, const std::string& file_format, - const std::vector& dense_slots, - const std::vector& sparse_slots, int batch_size, - int thread_id, std::vector* thread_status, + const DataDesc& data_desc, int thread_id, + std::vector* thread_status, std::shared_ptr queue) { VLOG(3) << "[" << thread_id << "]" << " reader thread start! thread_id = " << thread_id; @@ -175,8 +188,8 @@ void ReadThread(const std::vector& file_list, VLOG(3) << "set status to running"; std::unordered_map slot_to_index; - for (size_t i = 0; i < sparse_slots.size(); ++i) { - slot_to_index[sparse_slots[i]] = i; + for (size_t i = 0; i < data_desc.sparse_slot_ids_.size(); ++i) { + slot_to_index[data_desc.sparse_slot_ids_[i]] = i; } std::string line; @@ -185,25 +198,25 @@ void ReadThread(const std::vector& file_list, std::vector batch_label; std::unique_ptr reader; - if (file_type == "gzip") { + if (data_desc.file_type_ == "gzip") { reader.reset(new MultiFileReader(file_list)); - } else if (file_type == "plain") { + } else if (data_desc.file_type_ == "plain") { reader.reset(new MultiFileReader(file_list)); } else { - PADDLE_THROW("do not support file format %s", file_type); + PADDLE_THROW("do not support file format %s", data_desc.file_type_); } VLOG(3) << "reader inited"; while (reader->HasNext()) { batch_data.clear(); - batch_data.reserve(batch_size); + batch_data.reserve(data_desc.batch_size_); batch_label.clear(); - batch_label.reserve(batch_size); + batch_label.reserve(data_desc.batch_size_); // read batch_size data - for (int i = 0; i < batch_size; ++i) { + for (int i = 0; i < data_desc.batch_size_; ++i) { if (reader->HasNext()) { reader->NextLine(&line); std::unordered_map> slot_to_data; @@ -219,7 +232,7 @@ void ReadThread(const std::vector& file_list, std::vector lod_datas; // first insert tensor for each sparse_slots - for (auto& slot : sparse_slots) { + for (auto& slot : data_desc.sparse_slot_ids_) { std::vector lod_data{0}; std::vector batch_feasign; diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index 68d587bbfc4a60fa31401b0e4fd662f6e8deb4a9..1f4663e3b899e588981503f40e413aa816f5fbe0 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -35,11 +35,34 @@ namespace reader { enum ReaderThreadStatus { Running, Stopped }; +struct DataDesc { + DataDesc(int batch_size, const std::vector& file_names, + const std::string& file_type, const std::string& file_format, + const std::vector& dense_slot_index, + const std::vector& sparse_slot_index, + const std::vector& sparse_slot_ids) + : batch_size_(batch_size), + file_names_(file_names), + file_type_(file_type), + file_format_(file_format), + dense_slot_index_(dense_slot_index), + sparse_slot_index_(sparse_slot_index), + sparse_slot_ids_(sparse_slot_ids) {} + + const int batch_size_; + const std::vector file_names_; + const std::string file_type_; // gzip or plain + const std::string file_format_; // csv or svm + // used for csv data format + const std::vector dense_slot_index_; + const std::vector sparse_slot_index_; + // used for svm data format + const std::vector sparse_slot_ids_; +}; + void ReadThread(const std::vector& file_list, - const std::string& file_type, const std::string& file_format, - const std::vector& dense_slots, - const std::vector& sparse_slots, int batch_size, - int thread_id, std::vector* thread_status, + const DataDesc& data_desc, int thread_id, + std::vector* thread_status, std::shared_ptr queue); // monitor all running thread, if they are all stopped, @@ -50,22 +73,15 @@ void MonitorThread(std::vector* thread_status, class CTRReader : public framework::FileReader { public: CTRReader(const std::shared_ptr& queue, - int batch_size, int thread_num, const std::string& file_type, - const std::string& file_format, - const std::vector& dense_slots, - const std::vector& sparse_slots, - const std::vector& file_list) - : batch_size_(batch_size), - file_type_(file_type), - file_format_(file_format), - dense_slots_(dense_slots), - sparse_slots_(sparse_slots), - file_list_(file_list) { + int thread_num, const DataDesc& data_desc) + : data_desc_(data_desc) { PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); - PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); - thread_num_ = - file_list_.size() > thread_num ? thread_num : file_list_.size(); + PADDLE_ENFORCE_GT(data_desc_.file_names_.size(), 0, + "file list should not be empty"); + thread_num_ = data_desc_.file_names_.size() > thread_num + ? thread_num + : data_desc_.file_names_.size(); queue_ = queue; SplitFiles(); for (size_t i = 0; i < thread_num_; ++i) { @@ -106,9 +122,8 @@ class CTRReader : public framework::FileReader { VLOG(3) << "thread_num " << thread_num_; for (int thread_id = 0; thread_id < thread_num_; thread_id++) { read_threads_.emplace_back(new std::thread( - std::bind(&ReadThread, file_groups_[thread_id], file_type_, - file_format_, dense_slots_, sparse_slots_, batch_size_, - thread_id, &read_thread_status_, queue_))); + std::bind(&ReadThread, file_groups_[thread_id], data_desc_, thread_id, + &read_thread_status_, queue_))); } monitor_thread_.reset(new std::thread( std::bind(&MonitorThread, &read_thread_status_, queue_))); @@ -118,8 +133,8 @@ class CTRReader : public framework::FileReader { private: void SplitFiles() { file_groups_.resize(thread_num_); - for (size_t i = 0; i < file_list_.size(); ++i) { - auto& file_name = file_list_[i]; + for (size_t i = 0; i < data_desc_.file_names_.size(); ++i) { + auto& file_name = data_desc_.file_names_[i]; std::ifstream f(file_name.c_str()); PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name); file_groups_[i % thread_num_].push_back(file_name); @@ -128,12 +143,7 @@ class CTRReader : public framework::FileReader { private: size_t thread_num_; - const int batch_size_; - const std::string file_type_; - const std::string file_format_; - const std::vector dense_slots_; - const std::vector sparse_slots_; - const std::vector file_list_; + const DataDesc data_desc_; std::shared_ptr queue_; std::vector> read_threads_; std::unique_ptr monitor_thread_; diff --git a/paddle/fluid/operators/reader/ctr_reader_test.cc b/paddle/fluid/operators/reader/ctr_reader_test.cc index 734bf45383c74acf1f2cd159440f03dc683403b4..a14e21bc8d212ee2c5eaab189786223fbccc96f6 100644 --- a/paddle/fluid/operators/reader/ctr_reader_test.cc +++ b/paddle/fluid/operators/reader/ctr_reader_test.cc @@ -36,6 +36,7 @@ using paddle::framework::LoD; using paddle::framework::DDim; using paddle::platform::CPUPlace; using paddle::framework::make_ddim; +using paddle::operators::reader::DataDesc; static void generatedata(const std::vector& data, const std::string& file_name) { @@ -138,8 +139,10 @@ TEST(CTR_READER, read_data) { file_list.push_back(gz_file_name); } - CTRReader reader(queue, batch_size, thread_num, "gzip", "plain", {}, - sparse_slots, file_list); + DataDesc data_desc(batch_size, file_list, "gzip", "plain", {}, {}, + sparse_slots); + + CTRReader reader(queue, thread_num, data_desc); reader.Start(); size_t batch_num =