diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index da109733da812f61f11abb8939d05fcc5c8e16bd..97426412977766796b8bc14f12cf1feb5be06302 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -132,31 +132,36 @@ void CTRReader::ReadThread(const std::vector& file_list, std::vector batch_label; MultiGzipReader reader(file_list); - // read all files - for (int i = 0; i < batch_size; ++i) { - if (reader.HasNext()) { - reader.NextLine(&line); - std::unordered_map> slots_to_data; - int64_t label; - parse_line(line, slots, &label, &slots_to_data); - batch_data.push_back(slots_to_data); - batch_label.push_back(label); - } else { - break; + + while (reader.HasNext()) { + // read all files + for (int i = 0; i < batch_size; ++i) { + if (reader.HasNext()) { + reader.NextLine(&line); + std::unordered_map> slots_to_data; + int64_t label; + parse_line(line, slots, &label, &slots_to_data); + batch_data.push_back(slots_to_data); + batch_label.push_back(label); + } else { + break; + } } - } - std::vector lod_datas; - for (auto& slot : slots) { - for (auto& slots_to_data : batch_data) { + std::vector lod_datas; + + // first insert tensor for each slots + for (auto& slot : slots) { std::vector lod_data{0}; std::vector batch_feasign; - std::vector batch_label; - auto& feasign = slots_to_data[slot]; + for (size_t i = 0; i < batch_data.size(); ++i) { + auto& feasign = batch_data[i][slot]; + + lod_data.push_back(lod_data.back() + feasign.size()); + batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end()); + } - lod_data.push_back(lod_data.back() + feasign.size()); - batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end()); framework::LoDTensor lod_tensor; framework::LoD lod{lod_data}; lod_tensor.set_lod(lod); @@ -166,8 +171,17 @@ void CTRReader::ReadThread(const std::vector& file_list, memcpy(tensor_data, batch_feasign.data(), batch_feasign.size()); lod_datas.push_back(lod_tensor); } + + // insert label tensor + framework::LoDTensor label_tensor; + int64_t* label_tensor_data = label_tensor.mutable_data( + framework::make_ddim({1, static_cast(batch_label.size())}), + platform::CPUPlace()); + memcpy(label_tensor_data, batch_label.data(), batch_label.size()); + lod_datas.push_back(label_tensor); + + queue->Push(lod_datas); } - queue->Push(lod_datas); } } // namespace reader