提交 694e8945 编写于 作者: Q Qiao Longfei

add a base class for reader

上级 d981333e
...@@ -132,6 +132,8 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list, ...@@ -132,6 +132,8 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list,
std::vector<int64_t> batch_label; std::vector<int64_t> batch_label;
MultiGzipReader reader(file_list); MultiGzipReader reader(file_list);
while (reader.HasNext()) {
// read all files // read all files
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
if (reader.HasNext()) { if (reader.HasNext()) {
...@@ -147,16 +149,19 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list, ...@@ -147,16 +149,19 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list,
} }
std::vector<framework::LoDTensor> lod_datas; std::vector<framework::LoDTensor> lod_datas;
// first insert tensor for each slots
for (auto& slot : slots) { for (auto& slot : slots) {
for (auto& slots_to_data : batch_data) {
std::vector<size_t> lod_data{0}; std::vector<size_t> lod_data{0};
std::vector<int64_t> batch_feasign; std::vector<int64_t> batch_feasign;
std::vector<int64_t> batch_label;
auto& feasign = slots_to_data[slot]; for (size_t i = 0; i < batch_data.size(); ++i) {
auto& feasign = batch_data[i][slot];
lod_data.push_back(lod_data.back() + feasign.size()); lod_data.push_back(lod_data.back() + feasign.size());
batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end()); batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end());
}
framework::LoDTensor lod_tensor; framework::LoDTensor lod_tensor;
framework::LoD lod{lod_data}; framework::LoD lod{lod_data};
lod_tensor.set_lod(lod); lod_tensor.set_lod(lod);
...@@ -166,8 +171,17 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list, ...@@ -166,8 +171,17 @@ void CTRReader::ReadThread(const std::vector<std::string>& file_list,
memcpy(tensor_data, batch_feasign.data(), batch_feasign.size()); memcpy(tensor_data, batch_feasign.data(), batch_feasign.size());
lod_datas.push_back(lod_tensor); lod_datas.push_back(lod_tensor);
} }
}
// insert label tensor
framework::LoDTensor label_tensor;
int64_t* label_tensor_data = label_tensor.mutable_data<int64_t>(
framework::make_ddim({1, static_cast<int64_t>(batch_label.size())}),
platform::CPUPlace());
memcpy(label_tensor_data, batch_label.data(), batch_label.size());
lod_datas.push_back(label_tensor);
queue->Push(lod_datas); queue->Push(lod_datas);
}
} }
} // namespace reader } // namespace reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册