提交 d7c8ebac 编写于 作者: Q Qiao Longfei

add datadesc

上级 a05a948d
......@@ -43,14 +43,16 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto thread_num = Attr<int>("thread_num");
auto sparse_slots = Attr<std::vector<std::string>>("sparse_slots");
auto dense_slots = Attr<std::vector<std::string>>("dense_slots");
auto dense_slot_index = Attr<std::vector<int>>("dense_slot_index");
auto sparse_slot_index = Attr<std::vector<int>>("sparse_slot_index");
auto batch_size = Attr<int>("batch_size");
auto file_type = Attr<std::string>("file_type");
auto file_format = Attr<std::string>("file_format");
auto file_list = Attr<std::vector<std::string>>("file_list");
out->Reset(std::make_shared<CTRReader>(
queue_holder->GetQueue(), batch_size, thread_num, file_type,
file_format, dense_slots, sparse_slots, file_list));
DataDesc data_desc(batch_size, file_list, file_type, file_format,
dense_slot_index, sparse_slot_index, sparse_slots);
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), thread_num,
data_desc));
}
};
......@@ -65,11 +67,18 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
AddAttr<std::string>("file_format", "svm or csv").SetDefault("csv");
AddAttr<std::vector<std::string>>("file_list",
"The list of files that need to read");
AddAttr<std::vector<std::string>>(
"dense_slots", "the sparse slots id that should be extract from file")
AddAttr<std::vector<int>>(
"dense_slot_index",
"the sparse slots id that should be extract from file")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"sparse_slots", "the sparse slots id that should be extract from file");
AddAttr<std::vector<int>>(
"dense_slot_index",
"the sparse slots id that should be extract from file")
.SetDefault({});
AddAttr<std::vector<std::string>>("sparse_slots",
"the sparse slots id that should be "
"extract from file, used when file "
"format is svm");
AddComment(R"DOC(
Create CTRReader to support read ctr data with cpp.
......
......@@ -73,6 +73,21 @@ static inline void parse_line(
}
}
// label slot1:fea_sign slot2:fea_sign slot1:fea_sign
static inline void parse_svm_line(const std::string& line) {}
// label,dense_fea,dense_fea,sparse_fea,sparse_fea
static inline void parse_csv_line(const std::string& line,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots,
int64_t* label,
std::vector<float>* dense_datas,
std::vector<int64_t>* sparse_datas) {
std::vector<std::string> ret;
string_split(line, ',', &ret);
*label = std::stoi(ret[2]) > 0;
}
class Reader {
public:
virtual ~Reader() {}
......@@ -160,10 +175,8 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
}
void ReadThread(const std::vector<std::string>& file_list,
const std::string& file_type, const std::string& file_format,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
......@@ -175,8 +188,8 @@ void ReadThread(const std::vector<std::string>& file_list,
VLOG(3) << "set status to running";
std::unordered_map<std::string, size_t> slot_to_index;
for (size_t i = 0; i < sparse_slots.size(); ++i) {
slot_to_index[sparse_slots[i]] = i;
for (size_t i = 0; i < data_desc.sparse_slot_ids_.size(); ++i) {
slot_to_index[data_desc.sparse_slot_ids_[i]] = i;
}
std::string line;
......@@ -185,25 +198,25 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<int64_t> batch_label;
std::unique_ptr<Reader> reader;
if (file_type == "gzip") {
if (data_desc.file_type_ == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (file_type == "plain") {
} else if (data_desc.file_type_ == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else {
PADDLE_THROW("do not support file format %s", file_type);
PADDLE_THROW("do not support file format %s", data_desc.file_type_);
}
VLOG(3) << "reader inited";
while (reader->HasNext()) {
batch_data.clear();
batch_data.reserve(batch_size);
batch_data.reserve(data_desc.batch_size_);
batch_label.clear();
batch_label.reserve(batch_size);
batch_label.reserve(data_desc.batch_size_);
// read batch_size data
for (int i = 0; i < batch_size; ++i) {
for (int i = 0; i < data_desc.batch_size_; ++i) {
if (reader->HasNext()) {
reader->NextLine(&line);
std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
......@@ -219,7 +232,7 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<framework::LoDTensor> lod_datas;
// first insert tensor for each sparse_slots
for (auto& slot : sparse_slots) {
for (auto& slot : data_desc.sparse_slot_ids_) {
std::vector<size_t> lod_data{0};
std::vector<int64_t> batch_feasign;
......
......@@ -35,11 +35,34 @@ namespace reader {
enum ReaderThreadStatus { Running, Stopped };
void ReadThread(const std::vector<std::string>& file_list,
struct DataDesc {
DataDesc(int batch_size, const std::vector<std::string>& file_names,
const std::string& file_type, const std::string& file_format,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
const std::vector<int>& dense_slot_index,
const std::vector<int>& sparse_slot_index,
const std::vector<std::string>& sparse_slot_ids)
: batch_size_(batch_size),
file_names_(file_names),
file_type_(file_type),
file_format_(file_format),
dense_slot_index_(dense_slot_index),
sparse_slot_index_(sparse_slot_index),
sparse_slot_ids_(sparse_slot_ids) {}
const int batch_size_;
const std::vector<std::string> file_names_;
const std::string file_type_; // gzip or plain
const std::string file_format_; // csv or svm
// used for csv data format
const std::vector<int> dense_slot_index_;
const std::vector<int> sparse_slot_index_;
// used for svm data format
const std::vector<std::string> sparse_slot_ids_;
};
void ReadThread(const std::vector<std::string>& file_list,
const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue);
// monitor all running thread, if they are all stopped,
......@@ -50,22 +73,15 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
class CTRReader : public framework::FileReader {
public:
CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
int batch_size, int thread_num, const std::string& file_type,
const std::string& file_format,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots,
const std::vector<std::string>& file_list)
: batch_size_(batch_size),
file_type_(file_type),
file_format_(file_format),
dense_slots_(dense_slots),
sparse_slots_(sparse_slots),
file_list_(file_list) {
int thread_num, const DataDesc& data_desc)
: data_desc_(data_desc) {
PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
thread_num_ =
file_list_.size() > thread_num ? thread_num : file_list_.size();
PADDLE_ENFORCE_GT(data_desc_.file_names_.size(), 0,
"file list should not be empty");
thread_num_ = data_desc_.file_names_.size() > thread_num
? thread_num
: data_desc_.file_names_.size();
queue_ = queue;
SplitFiles();
for (size_t i = 0; i < thread_num_; ++i) {
......@@ -106,9 +122,8 @@ class CTRReader : public framework::FileReader {
VLOG(3) << "thread_num " << thread_num_;
for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread(
std::bind(&ReadThread, file_groups_[thread_id], file_type_,
file_format_, dense_slots_, sparse_slots_, batch_size_,
thread_id, &read_thread_status_, queue_)));
std::bind(&ReadThread, file_groups_[thread_id], data_desc_, thread_id,
&read_thread_status_, queue_)));
}
monitor_thread_.reset(new std::thread(
std::bind(&MonitorThread, &read_thread_status_, queue_)));
......@@ -118,8 +133,8 @@ class CTRReader : public framework::FileReader {
private:
void SplitFiles() {
file_groups_.resize(thread_num_);
for (size_t i = 0; i < file_list_.size(); ++i) {
auto& file_name = file_list_[i];
for (size_t i = 0; i < data_desc_.file_names_.size(); ++i) {
auto& file_name = data_desc_.file_names_[i];
std::ifstream f(file_name.c_str());
PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name);
file_groups_[i % thread_num_].push_back(file_name);
......@@ -128,12 +143,7 @@ class CTRReader : public framework::FileReader {
private:
size_t thread_num_;
const int batch_size_;
const std::string file_type_;
const std::string file_format_;
const std::vector<std::string> dense_slots_;
const std::vector<std::string> sparse_slots_;
const std::vector<std::string> file_list_;
const DataDesc data_desc_;
std::shared_ptr<LoDTensorBlockingQueue> queue_;
std::vector<std::unique_ptr<std::thread>> read_threads_;
std::unique_ptr<std::thread> monitor_thread_;
......
......@@ -36,6 +36,7 @@ using paddle::framework::LoD;
using paddle::framework::DDim;
using paddle::platform::CPUPlace;
using paddle::framework::make_ddim;
using paddle::operators::reader::DataDesc;
static void generatedata(const std::vector<std::string>& data,
const std::string& file_name) {
......@@ -138,8 +139,10 @@ TEST(CTR_READER, read_data) {
file_list.push_back(gz_file_name);
}
CTRReader reader(queue, batch_size, thread_num, "gzip", "plain", {},
sparse_slots, file_list);
DataDesc data_desc(batch_size, file_list, "gzip", "plain", {}, {},
sparse_slots);
CTRReader reader(queue, thread_num, data_desc);
reader.Start();
size_t batch_num =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册