提交 71c2ad41 编写于 作者: Q Qiao Longfei

complete read thread

上级 0f3ece77
...@@ -52,6 +52,7 @@ static inline void parse_line( ...@@ -52,6 +52,7 @@ static inline void parse_line(
std::vector<std::string> ret; std::vector<std::string> ret;
string_split(line, ' ', &ret); string_split(line, ' ', &ret);
*label = std::stoi(ret[2]) > 0; *label = std::stoi(ret[2]) > 0;
for (size_t i = 3; i < ret.size(); ++i) { for (size_t i = 3; i < ret.size(); ++i) {
const std::string& item = ret[i]; const std::string& item = ret[i];
std::vector<std::string> slot_and_feasign; std::vector<std::string> slot_and_feasign;
...@@ -62,6 +63,13 @@ static inline void parse_line( ...@@ -62,6 +63,13 @@ static inline void parse_line(
(*slots_to_data)[slot_and_feasign[1]].push_back(feasign); (*slots_to_data)[slot_and_feasign[1]].push_back(feasign);
} }
} }
// NOTE:: if the slot has no value, then fill [0] as it's data.
for (auto& slot : slots) {
if (slots_to_data->find(slot) == slots_to_data->end()) {
(*slots_to_data)[slot].push_back(0);
}
}
} }
// class Reader { // class Reader {
...@@ -80,9 +88,7 @@ class GzipReader { ...@@ -80,9 +88,7 @@ class GzipReader {
bool HasNext() { return gzstream_.peek() != EOF; } bool HasNext() { return gzstream_.peek() != EOF; }
void NextLine(std::string* line) { // NOLINT void NextLine(std::string* line) { std::getline(gzstream_, *line); }
std::getline(gzstream_, line);
}
private: private:
igzstream gzstream_; igzstream gzstream_;
...@@ -108,7 +114,7 @@ class MultiGzipReader { ...@@ -108,7 +114,7 @@ class MultiGzipReader {
} }
void NextLine(std::string* line) { void NextLine(std::string* line) {
readers_[current_reader_index_]->NextLine(*line); readers_[current_reader_index_]->NextLine(line);
} }
private: private:
...@@ -119,16 +125,49 @@ class MultiGzipReader { ...@@ -119,16 +125,49 @@ class MultiGzipReader {
void CTRReader::ReadThread(const std::vector<std::string>& file_list, void CTRReader::ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, const std::vector<std::string>& slots,
int batch_size, int batch_size,
std::shared_ptr<LoDTensorBlockingQueue>* queue) { std::shared_ptr<LoDTensorBlockingQueue> queue) {
std::string line; std::string line;
std::vector<framework::LoDTensor> read_data;
std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
std::vector<int64_t> batch_label;
// read all files
MultiGzipReader reader(file_list); MultiGzipReader reader(file_list);
reader.NextLine(&line); // read all files
for (int i = 0; i < batch_size; ++i) {
if (reader.HasNext()) {
reader.NextLine(&line);
std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
int64_t label;
parse_line(line, slots, &label, &slots_to_data);
batch_data.push_back(slots_to_data);
batch_label.push_back(label);
} else {
break;
}
}
std::unordered_map<std::string, std::vector<int64_t>> slots_to_data; std::vector<framework::LoDTensor> lod_datas;
int64_t label; for (auto& slot : slots) {
parse_line(line, slots, &label, &slots_to_data); for (auto& slots_to_data : batch_data) {
std::vector<size_t> lod_data{0};
std::vector<int64_t> batch_feasign;
auto& feasign = slots_to_data[slot];
lod_data.push_back(lod_data.back() + feasign.size());
batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end());
framework::LoDTensor lod_tensor;
framework::LoD lod{lod_data};
lod_tensor.set_lod(lod);
int64_t* tensor_data = lod_tensor.mutable_data<int64_t>(
framework::make_ddim({1, static_cast<int64_t>(batch_feasign.size())}),
platform::CPUPlace());
memcpy(tensor_data, batch_feasign.data(), batch_feasign.size());
lod_datas.push_back(lod_tensor);
}
}
queue->Push(lod_datas);
} }
} // namespace reader } // namespace reader
......
...@@ -68,7 +68,7 @@ class CTRReader : public framework::FileReader { ...@@ -68,7 +68,7 @@ class CTRReader : public framework::FileReader {
private: private:
void ReadThread(const std::vector<std::string>& file_list, void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size, const std::vector<std::string>& slots, int batch_size,
std::shared_ptr<LoDTensorBlockingQueue>* queue); std::shared_ptr<LoDTensorBlockingQueue> queue);
private: private:
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册