提交 d7c8ebac 编写于 作者: Q Qiao Longfei

add datadesc

上级 a05a948d
...@@ -43,14 +43,16 @@ class CreateCTRReaderOp : public framework::OperatorBase { ...@@ -43,14 +43,16 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto thread_num = Attr<int>("thread_num"); auto thread_num = Attr<int>("thread_num");
auto sparse_slots = Attr<std::vector<std::string>>("sparse_slots"); auto sparse_slots = Attr<std::vector<std::string>>("sparse_slots");
auto dense_slots = Attr<std::vector<std::string>>("dense_slots"); auto dense_slot_index = Attr<std::vector<int>>("dense_slot_index");
auto sparse_slot_index = Attr<std::vector<int>>("sparse_slot_index");
auto batch_size = Attr<int>("batch_size"); auto batch_size = Attr<int>("batch_size");
auto file_type = Attr<std::string>("file_type"); auto file_type = Attr<std::string>("file_type");
auto file_format = Attr<std::string>("file_format"); auto file_format = Attr<std::string>("file_format");
auto file_list = Attr<std::vector<std::string>>("file_list"); auto file_list = Attr<std::vector<std::string>>("file_list");
out->Reset(std::make_shared<CTRReader>( DataDesc data_desc(batch_size, file_list, file_type, file_format,
queue_holder->GetQueue(), batch_size, thread_num, file_type, dense_slot_index, sparse_slot_index, sparse_slots);
file_format, dense_slots, sparse_slots, file_list)); out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), thread_num,
data_desc));
} }
}; };
...@@ -65,11 +67,18 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase { ...@@ -65,11 +67,18 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
AddAttr<std::string>("file_format", "svm or csv").SetDefault("csv"); AddAttr<std::string>("file_format", "svm or csv").SetDefault("csv");
AddAttr<std::vector<std::string>>("file_list", AddAttr<std::vector<std::string>>("file_list",
"The list of files that need to read"); "The list of files that need to read");
AddAttr<std::vector<std::string>>( AddAttr<std::vector<int>>(
"dense_slots", "the sparse slots id that should be extract from file") "dense_slot_index",
"the sparse slots id that should be extract from file")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<std::string>>( AddAttr<std::vector<int>>(
"sparse_slots", "the sparse slots id that should be extract from file"); "dense_slot_index",
"the sparse slots id that should be extract from file")
.SetDefault({});
AddAttr<std::vector<std::string>>("sparse_slots",
"the sparse slots id that should be "
"extract from file, used when file "
"format is svm");
AddComment(R"DOC( AddComment(R"DOC(
Create CTRReader to support read ctr data with cpp. Create CTRReader to support read ctr data with cpp.
......
...@@ -73,6 +73,21 @@ static inline void parse_line( ...@@ -73,6 +73,21 @@ static inline void parse_line(
} }
} }
// label slot1:fea_sign slot2:fea_sign slot1:fea_sign
static inline void parse_svm_line(const std::string& line) {}
// label,dense_fea,dense_fea,sparse_fea,sparse_fea
static inline void parse_csv_line(const std::string& line,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots,
int64_t* label,
std::vector<float>* dense_datas,
std::vector<int64_t>* sparse_datas) {
std::vector<std::string> ret;
string_split(line, ',', &ret);
*label = std::stoi(ret[2]) > 0;
}
class Reader { class Reader {
public: public:
virtual ~Reader() {} virtual ~Reader() {}
...@@ -160,10 +175,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status, ...@@ -160,10 +175,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
} }
void ReadThread(const std::vector<std::string>& file_list, void ReadThread(const std::vector<std::string>& file_list,
const std::string& file_type, const std::string& file_format, const DataDesc& data_desc, int thread_id,
const std::vector<std::string>& dense_slots, std::vector<ReaderThreadStatus>* thread_status,
const std::vector<std::string>& sparse_slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) { std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "[" << thread_id << "]" VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id; << " reader thread start! thread_id = " << thread_id;
...@@ -175,8 +188,8 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -175,8 +188,8 @@ void ReadThread(const std::vector<std::string>& file_list,
VLOG(3) << "set status to running"; VLOG(3) << "set status to running";
std::unordered_map<std::string, size_t> slot_to_index; std::unordered_map<std::string, size_t> slot_to_index;
for (size_t i = 0; i < sparse_slots.size(); ++i) { for (size_t i = 0; i < data_desc.sparse_slot_ids_.size(); ++i) {
slot_to_index[sparse_slots[i]] = i; slot_to_index[data_desc.sparse_slot_ids_[i]] = i;
} }
std::string line; std::string line;
...@@ -185,25 +198,25 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -185,25 +198,25 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<int64_t> batch_label; std::vector<int64_t> batch_label;
std::unique_ptr<Reader> reader; std::unique_ptr<Reader> reader;
if (file_type == "gzip") { if (data_desc.file_type_ == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list)); reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (file_type == "plain") { } else if (data_desc.file_type_ == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list)); reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else { } else {
PADDLE_THROW("do not support file format %s", file_type); PADDLE_THROW("do not support file format %s", data_desc.file_type_);
} }
VLOG(3) << "reader inited"; VLOG(3) << "reader inited";
while (reader->HasNext()) { while (reader->HasNext()) {
batch_data.clear(); batch_data.clear();
batch_data.reserve(batch_size); batch_data.reserve(data_desc.batch_size_);
batch_label.clear(); batch_label.clear();
batch_label.reserve(batch_size); batch_label.reserve(data_desc.batch_size_);
// read batch_size data // read batch_size data
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < data_desc.batch_size_; ++i) {
if (reader->HasNext()) { if (reader->HasNext()) {
reader->NextLine(&line); reader->NextLine(&line);
std::unordered_map<std::string, std::vector<int64_t>> slot_to_data; std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
...@@ -219,7 +232,7 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -219,7 +232,7 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<framework::LoDTensor> lod_datas; std::vector<framework::LoDTensor> lod_datas;
// first insert tensor for each sparse_slots // first insert tensor for each sparse_slots
for (auto& slot : sparse_slots) { for (auto& slot : data_desc.sparse_slot_ids_) {
std::vector<size_t> lod_data{0}; std::vector<size_t> lod_data{0};
std::vector<int64_t> batch_feasign; std::vector<int64_t> batch_feasign;
......
...@@ -35,11 +35,34 @@ namespace reader { ...@@ -35,11 +35,34 @@ namespace reader {
enum ReaderThreadStatus { Running, Stopped }; enum ReaderThreadStatus { Running, Stopped };
void ReadThread(const std::vector<std::string>& file_list, struct DataDesc {
DataDesc(int batch_size, const std::vector<std::string>& file_names,
const std::string& file_type, const std::string& file_format, const std::string& file_type, const std::string& file_format,
const std::vector<std::string>& dense_slots, const std::vector<int>& dense_slot_index,
const std::vector<std::string>& sparse_slots, int batch_size, const std::vector<int>& sparse_slot_index,
int thread_id, std::vector<ReaderThreadStatus>* thread_status, const std::vector<std::string>& sparse_slot_ids)
: batch_size_(batch_size),
file_names_(file_names),
file_type_(file_type),
file_format_(file_format),
dense_slot_index_(dense_slot_index),
sparse_slot_index_(sparse_slot_index),
sparse_slot_ids_(sparse_slot_ids) {}
const int batch_size_;
const std::vector<std::string> file_names_;
const std::string file_type_; // gzip or plain
const std::string file_format_; // csv or svm
// used for csv data format
const std::vector<int> dense_slot_index_;
const std::vector<int> sparse_slot_index_;
// used for svm data format
const std::vector<std::string> sparse_slot_ids_;
};
void ReadThread(const std::vector<std::string>& file_list,
const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue); std::shared_ptr<LoDTensorBlockingQueue> queue);
// monitor all running thread, if they are all stopped, // monitor all running thread, if they are all stopped,
...@@ -50,22 +73,15 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status, ...@@ -50,22 +73,15 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
class CTRReader : public framework::FileReader { class CTRReader : public framework::FileReader {
public: public:
CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue, CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
int batch_size, int thread_num, const std::string& file_type, int thread_num, const DataDesc& data_desc)
const std::string& file_format, : data_desc_(data_desc) {
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots,
const std::vector<std::string>& file_list)
: batch_size_(batch_size),
file_type_(file_type),
file_format_(file_format),
dense_slots_(dense_slots),
sparse_slots_(sparse_slots),
file_list_(file_list) {
PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!"); PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); PADDLE_ENFORCE_GT(data_desc_.file_names_.size(), 0,
thread_num_ = "file list should not be empty");
file_list_.size() > thread_num ? thread_num : file_list_.size(); thread_num_ = data_desc_.file_names_.size() > thread_num
? thread_num
: data_desc_.file_names_.size();
queue_ = queue; queue_ = queue;
SplitFiles(); SplitFiles();
for (size_t i = 0; i < thread_num_; ++i) { for (size_t i = 0; i < thread_num_; ++i) {
...@@ -106,9 +122,8 @@ class CTRReader : public framework::FileReader { ...@@ -106,9 +122,8 @@ class CTRReader : public framework::FileReader {
VLOG(3) << "thread_num " << thread_num_; VLOG(3) << "thread_num " << thread_num_;
for (int thread_id = 0; thread_id < thread_num_; thread_id++) { for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread( read_threads_.emplace_back(new std::thread(
std::bind(&ReadThread, file_groups_[thread_id], file_type_, std::bind(&ReadThread, file_groups_[thread_id], data_desc_, thread_id,
file_format_, dense_slots_, sparse_slots_, batch_size_, &read_thread_status_, queue_)));
thread_id, &read_thread_status_, queue_)));
} }
monitor_thread_.reset(new std::thread( monitor_thread_.reset(new std::thread(
std::bind(&MonitorThread, &read_thread_status_, queue_))); std::bind(&MonitorThread, &read_thread_status_, queue_)));
...@@ -118,8 +133,8 @@ class CTRReader : public framework::FileReader { ...@@ -118,8 +133,8 @@ class CTRReader : public framework::FileReader {
private: private:
void SplitFiles() { void SplitFiles() {
file_groups_.resize(thread_num_); file_groups_.resize(thread_num_);
for (size_t i = 0; i < file_list_.size(); ++i) { for (size_t i = 0; i < data_desc_.file_names_.size(); ++i) {
auto& file_name = file_list_[i]; auto& file_name = data_desc_.file_names_[i];
std::ifstream f(file_name.c_str()); std::ifstream f(file_name.c_str());
PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name); PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name);
file_groups_[i % thread_num_].push_back(file_name); file_groups_[i % thread_num_].push_back(file_name);
...@@ -128,12 +143,7 @@ class CTRReader : public framework::FileReader { ...@@ -128,12 +143,7 @@ class CTRReader : public framework::FileReader {
private: private:
size_t thread_num_; size_t thread_num_;
const int batch_size_; const DataDesc data_desc_;
const std::string file_type_;
const std::string file_format_;
const std::vector<std::string> dense_slots_;
const std::vector<std::string> sparse_slots_;
const std::vector<std::string> file_list_;
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
std::vector<std::unique_ptr<std::thread>> read_threads_; std::vector<std::unique_ptr<std::thread>> read_threads_;
std::unique_ptr<std::thread> monitor_thread_; std::unique_ptr<std::thread> monitor_thread_;
......
...@@ -36,6 +36,7 @@ using paddle::framework::LoD; ...@@ -36,6 +36,7 @@ using paddle::framework::LoD;
using paddle::framework::DDim; using paddle::framework::DDim;
using paddle::platform::CPUPlace; using paddle::platform::CPUPlace;
using paddle::framework::make_ddim; using paddle::framework::make_ddim;
using paddle::operators::reader::DataDesc;
static void generatedata(const std::vector<std::string>& data, static void generatedata(const std::vector<std::string>& data,
const std::string& file_name) { const std::string& file_name) {
...@@ -138,8 +139,10 @@ TEST(CTR_READER, read_data) { ...@@ -138,8 +139,10 @@ TEST(CTR_READER, read_data) {
file_list.push_back(gz_file_name); file_list.push_back(gz_file_name);
} }
CTRReader reader(queue, batch_size, thread_num, "gzip", "plain", {}, DataDesc data_desc(batch_size, file_list, "gzip", "plain", {}, {},
sparse_slots, file_list); sparse_slots);
CTRReader reader(queue, thread_num, data_desc);
reader.Start(); reader.Start();
size_t batch_num = size_t batch_num =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册