提交 fbd6f501 编写于 作者: Q Qiao Longfei

add ReadSvmData

上级 d7c8ebac
......@@ -78,14 +78,18 @@ static inline void parse_svm_line(const std::string& line) {}
// Parses one comma-separated text record into a binary label, a list of dense
// (float) features and a list of sparse (int64) features.
// Column layout noted by the original author:
// label,dense_fea,dense_fea,sparse_fea,sparse_fea
//
// Outputs: *label is overwritten; dense_datas and sparse_datas are appended to
// (they are not cleared here).
//
// NOTE(review): this span looks like a flattened diff hunk - both the older
// parameter lines (dense_slots / sparse_slots / label) and the newer
// `data_desc` line are present, as are two `*label` assignments. As shown, the
// parameter `label` is declared twice, so confirm against the real source file
// which lines belong to the current version.
static inline void parse_csv_line(const std::string& line,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots,
int64_t* label,
const DataDesc& data_desc, int64_t* label,
std::vector<float>* dense_datas,
std::vector<int64_t>* sparse_datas) {
std::vector<std::string> ret;
// Presumably splits `line` on ',' and stores the fields in `ret` - string_split
// is defined elsewhere; verify its exact behavior (e.g. empty fields).
string_split(line, ',', &ret);
// The label is collapsed to 0/1: 1 when the parsed integer is positive.
// NOTE(review): the value is read from field index 2, while the layout comment
// above lists the label as the first column - confirm which is intended.
// NOTE(review): `ret` is indexed without a size check, and std::stoi/std::stol
// throw on non-numeric or out-of-range text.
*label = std::stoi(ret[2]) > 0;
*label = std::stol(ret[2]) > 0;
// Dense features: each index in data_desc.dense_slot_index_ selects a field of
// `ret`, which is parsed as float and appended to dense_datas.
for (auto& idx : data_desc.dense_slot_index_) {
dense_datas->push_back(std::stof(ret[idx]));
}
// Sparse features: each index in data_desc.sparse_slot_index_ selects a field
// of `ret`, which is parsed as a long integer and appended to sparse_datas.
for (auto& idx : data_desc.sparse_slot_index_) {
sparse_datas->push_back(std::stol(ret[idx]));
}
}
class Reader {
......@@ -174,19 +178,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
VLOG(3) << "monitor thread exited";
}
void ReadThread(const std::vector<std::string>& file_list,
const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
for (auto& file : file_list) {
VLOG(3) << "[" << thread_id << "]"
<< " file " << file;
}
(*thread_status)[thread_id] = Running;
VLOG(3) << "set status to running";
void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
std::unordered_map<std::string, size_t> slot_to_index;
for (size_t i = 0; i < data_desc.sparse_slot_ids_.size(); ++i) {
slot_to_index[data_desc.sparse_slot_ids_[i]] = i;
......@@ -197,17 +190,6 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
std::vector<int64_t> batch_label;
std::unique_ptr<Reader> reader;
if (data_desc.file_type_ == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (data_desc.file_type_ == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else {
PADDLE_THROW("do not support file format %s", data_desc.file_type_);
}
VLOG(3) << "reader inited";
while (reader->HasNext()) {
batch_data.clear();
batch_data.reserve(data_desc.batch_size_);
......@@ -266,6 +248,35 @@ void ReadThread(const std::vector<std::string>& file_list,
queue->Push(lod_datas);
VLOG(4) << "push one data, queue_size=" << queue->Size();
}
}
void ReadThread(const std::vector<std::string>& file_list,
const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
for (auto& file : file_list) {
VLOG(3) << "[" << thread_id << "]"
<< " file " << file;
}
(*thread_status)[thread_id] = Running;
VLOG(3) << "set status to running";
std::shared_ptr<Reader> reader;
if (data_desc.file_type_ == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (data_desc.file_type_ == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else {
PADDLE_THROW("do not support file format %s", data_desc.file_type_);
}
VLOG(3) << "reader inited";
if (data_desc.file_format_ == "svm") {
ReadSvmData(data_desc, reader, queue);
}
(*thread_status)[thread_id] = Stopped;
VLOG(3) << "set status to stopped, thread " << thread_id << " exited";
......
......@@ -139,7 +139,7 @@ TEST(CTR_READER, read_data) {
file_list.push_back(gz_file_name);
}
DataDesc data_desc(batch_size, file_list, "gzip", "plain", {}, {},
DataDesc data_desc(batch_size, file_list, "gzip", "svm", {}, {},
sparse_slots);
CTRReader reader(queue, thread_num, data_desc);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册