From fbd6f50148bb7eaf40ced1964737b2550ab746a1 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 2 Dec 2018 14:55:35 +0800 Subject: [PATCH] add ReadSvmData --- paddle/fluid/operators/reader/ctr_reader.cc | 67 +++++++++++-------- .../fluid/operators/reader/ctr_reader_test.cc | 2 +- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index 0af55b503..9834d7183 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -78,14 +78,18 @@ static inline void parse_svm_line(const std::string& line) {} // label,dense_fea,dense_fea,sparse_fea,sparse_fea static inline void parse_csv_line(const std::string& line, - const std::vector& dense_slots, - const std::vector& sparse_slots, - int64_t* label, + const DataDesc& data_desc, int64_t* label, std::vector* dense_datas, std::vector* sparse_datas) { std::vector ret; string_split(line, ',', &ret); - *label = std::stoi(ret[2]) > 0; + *label = std::stol(ret[2]) > 0; + for (auto& idx : data_desc.dense_slot_index_) { + dense_datas->push_back(std::stof(ret[idx])); + } + for (auto& idx : data_desc.sparse_slot_index_) { + sparse_datas->push_back(std::stol(ret[idx])); + } } class Reader { @@ -174,19 +178,8 @@ void MonitorThread(std::vector* thread_status, VLOG(3) << "monitor thread exited"; } -void ReadThread(const std::vector& file_list, - const DataDesc& data_desc, int thread_id, - std::vector* thread_status, - std::shared_ptr queue) { - VLOG(3) << "[" << thread_id << "]" - << " reader thread start! thread_id = " << thread_id; - for (auto& file : file_list) { - VLOG(3) << "[" << thread_id << "]" - << " file " << file; - } - (*thread_status)[thread_id] = Running; - VLOG(3) << "set status to running"; - +void ReadSvmData(const DataDesc& data_desc, std::shared_ptr reader, + std::shared_ptr queue) { std::unordered_map slot_to_index; for (size_t i = 0; i < data_desc.sparse_slot_ids_.size(); ++i) { slot_to_index[data_desc.sparse_slot_ids_[i]] = i; @@ -197,17 +190,6 @@ void ReadThread(const std::vector& file_list, std::vector>> batch_data; std::vector batch_label; - std::unique_ptr reader; - if (data_desc.file_type_ == "gzip") { - reader.reset(new MultiFileReader(file_list)); - } else if (data_desc.file_type_ == "plain") { - reader.reset(new MultiFileReader(file_list)); - } else { - PADDLE_THROW("do not support file format %s", data_desc.file_type_); - } - - VLOG(3) << "reader inited"; - while (reader->HasNext()) { batch_data.clear(); batch_data.reserve(data_desc.batch_size_); @@ -266,6 +248,35 @@ void ReadThread(const std::vector& file_list, queue->Push(lod_datas); VLOG(4) << "push one data, queue_size=" << queue->Size(); } +} + +void ReadThread(const std::vector& file_list, + const DataDesc& data_desc, int thread_id, + std::vector* thread_status, + std::shared_ptr queue) { + VLOG(3) << "[" << thread_id << "]" + << " reader thread start! thread_id = " << thread_id; + for (auto& file : file_list) { + VLOG(3) << "[" << thread_id << "]" + << " file " << file; + } + (*thread_status)[thread_id] = Running; + VLOG(3) << "set status to running"; + + std::shared_ptr reader; + if (data_desc.file_type_ == "gzip") { + reader.reset(new MultiFileReader(file_list)); + } else if (data_desc.file_type_ == "plain") { + reader.reset(new MultiFileReader(file_list)); + } else { + PADDLE_THROW("do not support file format %s", data_desc.file_type_); + } + + VLOG(3) << "reader inited"; + + if (data_desc.file_format_ == "svm") { + ReadSvmData(data_desc, reader, queue); + } (*thread_status)[thread_id] = Stopped; VLOG(3) << "set status to stopped, thread " << thread_id << " exited"; diff --git a/paddle/fluid/operators/reader/ctr_reader_test.cc b/paddle/fluid/operators/reader/ctr_reader_test.cc index a14e21bc8..dfdaae3a0 100644 --- a/paddle/fluid/operators/reader/ctr_reader_test.cc +++ b/paddle/fluid/operators/reader/ctr_reader_test.cc @@ -139,7 +139,7 @@ TEST(CTR_READER, read_data) { file_list.push_back(gz_file_name); } - DataDesc data_desc(batch_size, file_list, "gzip", "plain", {}, {}, + DataDesc data_desc(batch_size, file_list, "gzip", "svm", {}, {}, sparse_slots); CTRReader reader(queue, thread_num, data_desc); -- GitLab