diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc
index 8be9f68c9410ac9dede0a70f8137e552d3009ef8..7c83a7d62c51c0fa37cc4030e65ac3c429cb2cc6 100644
--- a/paddle/fluid/operators/reader/ctr_reader.cc
+++ b/paddle/fluid/operators/reader/ctr_reader.cc
@@ -52,6 +52,7 @@ static inline void parse_line(
   std::vector<std::string> ret;
   string_split(line, ' ', &ret);
   *label = std::stoi(ret[2]) > 0;
+
   for (size_t i = 3; i < ret.size(); ++i) {
     const std::string& item = ret[i];
     std::vector<std::string> slot_and_feasign;
@@ -62,6 +63,13 @@ static inline void parse_line(
       (*slots_to_data)[slot_and_feasign[1]].push_back(feasign);
     }
   }
+
+  // NOTE: if a slot has no value, fill [0] as its data.
+  for (const auto& slot : slots) {
+    if (slots_to_data->find(slot) == slots_to_data->end()) {
+      (*slots_to_data)[slot].push_back(0);
+    }
+  }
 }
 
 // class Reader {
@@ -80,9 +88,7 @@ class GzipReader {
 
   bool HasNext() { return gzstream_.peek() != EOF; }
 
-  void NextLine(std::string* line) {  // NOLINT
-    std::getline(gzstream_, line);
-  }
+  void NextLine(std::string* line) { std::getline(gzstream_, *line); }
 
  private:
   igzstream gzstream_;
@@ -108,7 +114,7 @@ class MultiGzipReader {
   }
 
   void NextLine(std::string* line) {
-    readers_[current_reader_index_]->NextLine(*line);
+    readers_[current_reader_index_]->NextLine(line);
   }
 
  private:
@@ -119,16 +125,59 @@ class MultiGzipReader {
 
 void CTRReader::ReadThread(const std::vector<std::string>& file_list,
                            const std::vector<std::string>& slots, int batch_size,
-                           std::shared_ptr<LoDTensorBlockingQueue>* queue) {
+                           std::shared_ptr<LoDTensorBlockingQueue> queue) {
   std::string line;
+  std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
+  std::vector<int64_t> batch_label;
 
-  // read all files
   MultiGzipReader reader(file_list);
-  reader.NextLine(&line);
+  // read at most batch_size lines; stop early when the files are exhausted
+  for (int i = 0; i < batch_size; ++i) {
+    if (reader.HasNext()) {
+      reader.NextLine(&line);
+      std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
+      int64_t label;
+      parse_line(line, slots, &label, &slots_to_data);
+      batch_data.push_back(slots_to_data);
+      batch_label.push_back(label);
+    } else {
+      break;
+    }
+  }
 
-  std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
-  int64_t label;
-  parse_line(line, slots, &label, &slots_to_data);
+  // Build one LoDTensor per slot for the whole batch: the feasigns of every
+  // sample are concatenated and the level-0 LoD stores each sample's end
+  // offset. An empty batch builds no tensors, so an empty vector is pushed.
+  std::vector<framework::LoDTensor> lod_datas;
+  if (!batch_data.empty()) {
+    for (const auto& slot : slots) {
+      std::vector<size_t> lod_data{0};
+      std::vector<int64_t> batch_feasign;
+
+      for (auto& slots_to_data : batch_data) {
+        auto& feasign = slots_to_data[slot];
+        lod_data.push_back(lod_data.back() + feasign.size());
+        // append to batch_feasign itself; the insert position must be an
+        // iterator into the destination vector
+        batch_feasign.insert(batch_feasign.end(), feasign.begin(),
+                             feasign.end());
+      }
+
+      framework::LoDTensor lod_tensor;
+      framework::LoD lod{lod_data};
+      lod_tensor.set_lod(lod);
+      int64_t* tensor_data = lod_tensor.mutable_data<int64_t>(
+          framework::make_ddim(
+              {1, static_cast<int64_t>(batch_feasign.size())}),
+          platform::CPUPlace());
+      // memcpy takes a byte count, not an element count
+      memcpy(tensor_data, batch_feasign.data(),
+             batch_feasign.size() * sizeof(int64_t));
+      lod_datas.push_back(lod_tensor);
+    }
+  }
+  // NOTE: batch_label is collected above but is not pushed to the queue yet.
+  queue->Push(lod_datas);
 }
 
 }  // namespace reader
diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h
index 11eb4f97864a849942088694840b98a3a808877b..41c520621e46b8a0d1a3ea47d11eff17153b6e17 100644
--- a/paddle/fluid/operators/reader/ctr_reader.h
+++ b/paddle/fluid/operators/reader/ctr_reader.h
@@ -68,7 +68,7 @@ class CTRReader : public framework::FileReader {
  private:
   void ReadThread(const std::vector<std::string>& file_list,
                   const std::vector<std::string>& slots, int batch_size,
-                  std::shared_ptr<LoDTensorBlockingQueue>* queue);
+                  std::shared_ptr<LoDTensorBlockingQueue> queue);
 
  private:
   std::shared_ptr<LoDTensorBlockingQueue> queue_;