提交 fbd6f501 作者: Qiao Longfei

add ReadSvmData

上级 d7c8ebac
...@@ -78,14 +78,18 @@ static inline void parse_svm_line(const std::string& line) {} ...@@ -78,14 +78,18 @@ static inline void parse_svm_line(const std::string& line) {}
// label,dense_fea,dense_fea,sparse_fea,sparse_fea // label,dense_fea,dense_fea,sparse_fea,sparse_fea
static inline void parse_csv_line(const std::string& line, static inline void parse_csv_line(const std::string& line,
const std::vector<std::string>& dense_slots, const DataDesc& data_desc, int64_t* label,
const std::vector<std::string>& sparse_slots,
int64_t* label,
std::vector<float>* dense_datas, std::vector<float>* dense_datas,
std::vector<int64_t>* sparse_datas) { std::vector<int64_t>* sparse_datas) {
std::vector<std::string> ret; std::vector<std::string> ret;
string_split(line, ',', &ret); string_split(line, ',', &ret);
*label = std::stoi(ret[2]) > 0; *label = std::stol(ret[2]) > 0;
for (auto& idx : data_desc.dense_slot_index_) {
dense_datas->push_back(std::stof(ret[idx]));
}
for (auto& idx : data_desc.sparse_slot_index_) {
sparse_datas->push_back(std::stol(ret[idx]));
}
} }
class Reader { class Reader {
...@@ -174,19 +178,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status, ...@@ -174,19 +178,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
VLOG(3) << "monitor thread exited"; VLOG(3) << "monitor thread exited";
} }
void ReadThread(const std::vector<std::string>& file_list, void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
const DataDesc& data_desc, int thread_id, std::shared_ptr<LoDTensorBlockingQueue> queue) {
std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
for (auto& file : file_list) {
VLOG(3) << "[" << thread_id << "]"
<< " file " << file;
}
(*thread_status)[thread_id] = Running;
VLOG(3) << "set status to running";
std::unordered_map<std::string, size_t> slot_to_index; std::unordered_map<std::string, size_t> slot_to_index;
for (size_t i = 0; i < data_desc.sparse_slot_ids_.size(); ++i) { for (size_t i = 0; i < data_desc.sparse_slot_ids_.size(); ++i) {
slot_to_index[data_desc.sparse_slot_ids_[i]] = i; slot_to_index[data_desc.sparse_slot_ids_[i]] = i;
...@@ -197,17 +190,6 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -197,17 +190,6 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data; std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
std::vector<int64_t> batch_label; std::vector<int64_t> batch_label;
std::unique_ptr<Reader> reader;
if (data_desc.file_type_ == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (data_desc.file_type_ == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else {
PADDLE_THROW("do not support file format %s", data_desc.file_type_);
}
VLOG(3) << "reader inited";
while (reader->HasNext()) { while (reader->HasNext()) {
batch_data.clear(); batch_data.clear();
batch_data.reserve(data_desc.batch_size_); batch_data.reserve(data_desc.batch_size_);
...@@ -266,6 +248,35 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -266,6 +248,35 @@ void ReadThread(const std::vector<std::string>& file_list,
queue->Push(lod_datas); queue->Push(lod_datas);
VLOG(4) << "push one data, queue_size=" << queue->Size(); VLOG(4) << "push one data, queue_size=" << queue->Size();
} }
}
void ReadThread(const std::vector<std::string>& file_list,
const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
for (auto& file : file_list) {
VLOG(3) << "[" << thread_id << "]"
<< " file " << file;
}
(*thread_status)[thread_id] = Running;
VLOG(3) << "set status to running";
std::shared_ptr<Reader> reader;
if (data_desc.file_type_ == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (data_desc.file_type_ == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else {
PADDLE_THROW("do not support file format %s", data_desc.file_type_);
}
VLOG(3) << "reader inited";
if (data_desc.file_format_ == "svm") {
ReadSvmData(data_desc, reader, queue);
}
(*thread_status)[thread_id] = Stopped; (*thread_status)[thread_id] = Stopped;
VLOG(3) << "set status to stopped, thread " << thread_id << " exited"; VLOG(3) << "set status to stopped, thread " << thread_id << " exited";
......
...@@ -139,7 +139,7 @@ TEST(CTR_READER, read_data) { ...@@ -139,7 +139,7 @@ TEST(CTR_READER, read_data) {
file_list.push_back(gz_file_name); file_list.push_back(gz_file_name);
} }
DataDesc data_desc(batch_size, file_list, "gzip", "plain", {}, {}, DataDesc data_desc(batch_size, file_list, "gzip", "svm", {}, {},
sparse_slots); sparse_slots);
CTRReader reader(queue, thread_num, data_desc); CTRReader reader(queue, thread_num, data_desc);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册