提交 a05a948d 编写于 作者: Q Qiao Longfei

update readthread

上级 2cd25794
...@@ -41,13 +41,16 @@ class CreateCTRReaderOp : public framework::OperatorBase { ...@@ -41,13 +41,16 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto* queue_holder = auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>(); queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
int thread_num = Attr<int>("thread_num"); auto thread_num = Attr<int>("thread_num");
std::vector<std::string> slots = Attr<std::vector<std::string>>("slots"); auto sparse_slots = Attr<std::vector<std::string>>("sparse_slots");
int batch_size = Attr<int>("batch_size"); auto dense_slots = Attr<std::vector<std::string>>("dense_slots");
std::vector<std::string> file_list = auto batch_size = Attr<int>("batch_size");
Attr<std::vector<std::string>>("file_list"); auto file_type = Attr<std::string>("file_type");
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), batch_size, auto file_format = Attr<std::string>("file_format");
thread_num, slots, file_list)); auto file_list = Attr<std::vector<std::string>>("file_list");
out->Reset(std::make_shared<CTRReader>(
queue_holder->GetQueue(), batch_size, thread_num, file_type,
file_format, dense_slots, sparse_slots, file_list));
} }
}; };
...@@ -58,10 +61,15 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase { ...@@ -58,10 +61,15 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
"Name of the `LoDTensorBlockingQueueHolder` variable"); "Name of the `LoDTensorBlockingQueueHolder` variable");
AddAttr<int>("thread_num", "the thread num to read data"); AddAttr<int>("thread_num", "the thread num to read data");
AddAttr<int>("batch_size", "the batch size of read data"); AddAttr<int>("batch_size", "the batch size of read data");
AddAttr<std::string>("file_type", "plain or gzip").SetDefault("plain");
AddAttr<std::string>("file_format", "svm or csv").SetDefault("csv");
AddAttr<std::vector<std::string>>("file_list", AddAttr<std::vector<std::string>>("file_list",
"The list of files that need to read"); "The list of files that need to read");
AddAttr<std::vector<std::string>>( AddAttr<std::vector<std::string>>(
"slots", "the slots that should be extract from file"); "dense_slots", "the sparse slots id that should be extract from file")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"sparse_slots", "the sparse slots id that should be extract from file");
AddComment(R"DOC( AddComment(R"DOC(
Create CTRReader to support read ctr data with cpp. Create CTRReader to support read ctr data with cpp.
......
...@@ -141,40 +141,42 @@ class MultiFileReader : public Reader { ...@@ -141,40 +141,42 @@ class MultiFileReader : public Reader {
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status, void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) { std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(30) << "monitor thread in"; VLOG(3) << "monitor thread in";
bool reader_thread_is_running = true; bool reader_thread_is_running = true;
while (reader_thread_is_running) { while (reader_thread_is_running) {
VLOG(30) << "reader_thread_is_running"; VLOG(3) << "reader_thread_is_running";
reader_thread_is_running = false; reader_thread_is_running = false;
for (size_t i = 0; i < (*thread_status).size(); ++i) { for (size_t i = 0; i < (*thread_status).size(); ++i) {
if ((*thread_status)[i] == Running) { if ((*thread_status)[i] == Running) {
VLOG(30) << "reader is running!"; VLOG(3) << "reader is running!";
reader_thread_is_running = true; reader_thread_is_running = true;
} }
} }
std::this_thread::sleep_for(std::chrono::milliseconds(1000)); std::this_thread::sleep_for(std::chrono::milliseconds(1000));
} }
VLOG(30) << "all reader thread is stopped, push empty data into queue"; VLOG(3) << "all reader thread is stopped, push empty data into queue";
queue->Push({}); queue->Push({});
VLOG(30) << "monitor thread exited"; VLOG(3) << "monitor thread exited";
} }
void ReadThread(const std::vector<std::string>& file_list, void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size, const std::string& file_type, const std::string& file_format,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status, int thread_id, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) { std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(30) << "[" << thread_id << "]" VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id; << " reader thread start! thread_id = " << thread_id;
for (auto& file : file_list) { for (auto& file : file_list) {
VLOG(30) << "[" << thread_id << "]" VLOG(3) << "[" << thread_id << "]"
<< " file " << file; << " file " << file;
} }
(*thread_status)[thread_id] = Running; (*thread_status)[thread_id] = Running;
VLOG(30) << "set status to running"; VLOG(3) << "set status to running";
std::unordered_map<std::string, size_t> slot_to_index; std::unordered_map<std::string, size_t> slot_to_index;
for (size_t i = 0; i < slots.size(); ++i) { for (size_t i = 0; i < sparse_slots.size(); ++i) {
slot_to_index[slots[i]] = i; slot_to_index[sparse_slots[i]] = i;
} }
std::string line; std::string line;
...@@ -182,11 +184,18 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -182,11 +184,18 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data; std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
std::vector<int64_t> batch_label; std::vector<int64_t> batch_label;
MultiFileReader<GzipReader> reader(file_list); std::unique_ptr<Reader> reader;
if (file_type == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (file_type == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else {
PADDLE_THROW("do not support file format %s", file_type);
}
VLOG(30) << "reader inited"; VLOG(3) << "reader inited";
while (reader.HasNext()) { while (reader->HasNext()) {
batch_data.clear(); batch_data.clear();
batch_data.reserve(batch_size); batch_data.reserve(batch_size);
...@@ -195,8 +204,8 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -195,8 +204,8 @@ void ReadThread(const std::vector<std::string>& file_list,
// read batch_size data // read batch_size data
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
if (reader.HasNext()) { if (reader->HasNext()) {
reader.NextLine(&line); reader->NextLine(&line);
std::unordered_map<std::string, std::vector<int64_t>> slot_to_data; std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
int64_t label; int64_t label;
parse_line(line, slot_to_index, &label, &slot_to_data); parse_line(line, slot_to_index, &label, &slot_to_data);
...@@ -209,8 +218,8 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -209,8 +218,8 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<framework::LoDTensor> lod_datas; std::vector<framework::LoDTensor> lod_datas;
// first insert tensor for each slots // first insert tensor for each sparse_slots
for (auto& slot : slots) { for (auto& slot : sparse_slots) {
std::vector<size_t> lod_data{0}; std::vector<size_t> lod_data{0};
std::vector<int64_t> batch_feasign; std::vector<int64_t> batch_feasign;
...@@ -242,11 +251,11 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -242,11 +251,11 @@ void ReadThread(const std::vector<std::string>& file_list,
lod_datas.push_back(label_tensor); lod_datas.push_back(label_tensor);
queue->Push(lod_datas); queue->Push(lod_datas);
VLOG(40) << "push one data, queue_size=" << queue->Size(); VLOG(4) << "push one data, queue_size=" << queue->Size();
} }
(*thread_status)[thread_id] = Stopped; (*thread_status)[thread_id] = Stopped;
VLOG(30) << "set status to stopped, thread " << thread_id << " exited"; VLOG(3) << "set status to stopped, thread " << thread_id << " exited";
} }
} // namespace reader } // namespace reader
......
...@@ -36,7 +36,9 @@ namespace reader { ...@@ -36,7 +36,9 @@ namespace reader {
enum ReaderThreadStatus { Running, Stopped }; enum ReaderThreadStatus { Running, Stopped };
void ReadThread(const std::vector<std::string>& file_list, void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size, const std::string& file_type, const std::string& file_format,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status, int thread_id, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue); std::shared_ptr<LoDTensorBlockingQueue> queue);
...@@ -47,11 +49,18 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status, ...@@ -47,11 +49,18 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
class CTRReader : public framework::FileReader { class CTRReader : public framework::FileReader {
public: public:
explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue, CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
int batch_size, int thread_num, int batch_size, int thread_num, const std::string& file_type,
const std::vector<std::string>& slots, const std::string& file_format,
const std::vector<std::string>& file_list) const std::vector<std::string>& dense_slots,
: batch_size_(batch_size), slots_(slots), file_list_(file_list) { const std::vector<std::string>& sparse_slots,
const std::vector<std::string>& file_list)
: batch_size_(batch_size),
file_type_(file_type),
file_format_(file_format),
dense_slots_(dense_slots),
sparse_slots_(sparse_slots),
file_list_(file_list) {
PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!"); PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
...@@ -97,7 +106,8 @@ class CTRReader : public framework::FileReader { ...@@ -97,7 +106,8 @@ class CTRReader : public framework::FileReader {
VLOG(3) << "thread_num " << thread_num_; VLOG(3) << "thread_num " << thread_num_;
for (int thread_id = 0; thread_id < thread_num_; thread_id++) { for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread( read_threads_.emplace_back(new std::thread(
std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_, std::bind(&ReadThread, file_groups_[thread_id], file_type_,
file_format_, dense_slots_, sparse_slots_, batch_size_,
thread_id, &read_thread_status_, queue_))); thread_id, &read_thread_status_, queue_)));
} }
monitor_thread_.reset(new std::thread( monitor_thread_.reset(new std::thread(
...@@ -119,7 +129,10 @@ class CTRReader : public framework::FileReader { ...@@ -119,7 +129,10 @@ class CTRReader : public framework::FileReader {
private: private:
size_t thread_num_; size_t thread_num_;
const int batch_size_; const int batch_size_;
const std::vector<std::string> slots_; const std::string file_type_;
const std::string file_format_;
const std::vector<std::string> dense_slots_;
const std::vector<std::string> sparse_slots_;
const std::vector<std::string> file_list_; const std::vector<std::string> file_list_;
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
std::vector<std::unique_ptr<std::thread>> read_threads_; std::vector<std::unique_ptr<std::thread>> read_threads_;
......
...@@ -132,24 +132,27 @@ TEST(CTR_READER, read_data) { ...@@ -132,24 +132,27 @@ TEST(CTR_READER, read_data) {
int batch_size = 3; int batch_size = 3;
int thread_num = 1; int thread_num = 1;
std::vector<std::string> slots = {"6002", "6003"}; std::vector<std::string> sparse_slots = {"6002", "6003"};
std::vector<std::string> file_list; std::vector<std::string> file_list;
for (int i = 0; i < thread_num; ++i) { for (int i = 0; i < thread_num; ++i) {
file_list.push_back(gz_file_name); file_list.push_back(gz_file_name);
} }
CTRReader reader(queue, batch_size, thread_num, slots, file_list); CTRReader reader(queue, batch_size, thread_num, "gzip", "plain", {},
sparse_slots, file_list);
reader.Start(); reader.Start();
size_t batch_num = size_t batch_num =
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num; std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, check_all_data(ctr_data, sparse_slots, label_dims, label_value,
data_slot_6003, batch_num, batch_size, queue, &reader); data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
&reader);
reader.Shutdown(); reader.Shutdown();
reader.Start(); reader.Start();
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, check_all_data(ctr_data, sparse_slots, label_dims, label_value,
data_slot_6003, batch_num, batch_size, queue, &reader); data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
&reader);
reader.Shutdown(); reader.Shutdown();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册