提交 a05a948d 编写于 作者: Q Qiao Longfei

update readthread

上级 2cd25794
......@@ -41,13 +41,16 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
int thread_num = Attr<int>("thread_num");
std::vector<std::string> slots = Attr<std::vector<std::string>>("slots");
int batch_size = Attr<int>("batch_size");
std::vector<std::string> file_list =
Attr<std::vector<std::string>>("file_list");
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), batch_size,
thread_num, slots, file_list));
auto thread_num = Attr<int>("thread_num");
auto sparse_slots = Attr<std::vector<std::string>>("sparse_slots");
auto dense_slots = Attr<std::vector<std::string>>("dense_slots");
auto batch_size = Attr<int>("batch_size");
auto file_type = Attr<std::string>("file_type");
auto file_format = Attr<std::string>("file_format");
auto file_list = Attr<std::vector<std::string>>("file_list");
out->Reset(std::make_shared<CTRReader>(
queue_holder->GetQueue(), batch_size, thread_num, file_type,
file_format, dense_slots, sparse_slots, file_list));
}
};
......@@ -58,10 +61,15 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
"Name of the `LoDTensorBlockingQueueHolder` variable");
AddAttr<int>("thread_num", "the thread num to read data");
AddAttr<int>("batch_size", "the batch size of read data");
AddAttr<std::string>("file_type", "plain or gzip").SetDefault("plain");
AddAttr<std::string>("file_format", "svm or csv").SetDefault("csv");
AddAttr<std::vector<std::string>>("file_list",
"The list of files that need to read");
AddAttr<std::vector<std::string>>(
"slots", "the slots that should be extract from file");
"dense_slots", "the sparse slots id that should be extract from file")
.SetDefault({});
AddAttr<std::vector<std::string>>(
"sparse_slots", "the sparse slots id that should be extract from file");
AddComment(R"DOC(
Create CTRReader to support read ctr data with cpp.
......
......@@ -141,40 +141,42 @@ class MultiFileReader : public Reader {
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(30) << "monitor thread in";
VLOG(3) << "monitor thread in";
bool reader_thread_is_running = true;
while (reader_thread_is_running) {
VLOG(30) << "reader_thread_is_running";
VLOG(3) << "reader_thread_is_running";
reader_thread_is_running = false;
for (size_t i = 0; i < (*thread_status).size(); ++i) {
if ((*thread_status)[i] == Running) {
VLOG(30) << "reader is running!";
VLOG(3) << "reader is running!";
reader_thread_is_running = true;
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(30) << "all reader thread is stopped, push empty data into queue";
VLOG(3) << "all reader thread is stopped, push empty data into queue";
queue->Push({});
VLOG(30) << "monitor thread exited";
VLOG(3) << "monitor thread exited";
}
void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size,
const std::string& file_type, const std::string& file_format,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(30) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
for (auto& file : file_list) {
VLOG(30) << "[" << thread_id << "]"
<< " file " << file;
VLOG(3) << "[" << thread_id << "]"
<< " file " << file;
}
(*thread_status)[thread_id] = Running;
VLOG(30) << "set status to running";
VLOG(3) << "set status to running";
std::unordered_map<std::string, size_t> slot_to_index;
for (size_t i = 0; i < slots.size(); ++i) {
slot_to_index[slots[i]] = i;
for (size_t i = 0; i < sparse_slots.size(); ++i) {
slot_to_index[sparse_slots[i]] = i;
}
std::string line;
......@@ -182,11 +184,18 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
std::vector<int64_t> batch_label;
MultiFileReader<GzipReader> reader(file_list);
std::unique_ptr<Reader> reader;
if (file_type == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (file_type == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else {
PADDLE_THROW("do not support file format %s", file_type);
}
VLOG(30) << "reader inited";
VLOG(3) << "reader inited";
while (reader.HasNext()) {
while (reader->HasNext()) {
batch_data.clear();
batch_data.reserve(batch_size);
......@@ -195,8 +204,8 @@ void ReadThread(const std::vector<std::string>& file_list,
// read batch_size data
for (int i = 0; i < batch_size; ++i) {
if (reader.HasNext()) {
reader.NextLine(&line);
if (reader->HasNext()) {
reader->NextLine(&line);
std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
int64_t label;
parse_line(line, slot_to_index, &label, &slot_to_data);
......@@ -209,8 +218,8 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<framework::LoDTensor> lod_datas;
// first insert tensor for each slots
for (auto& slot : slots) {
// first insert tensor for each sparse_slots
for (auto& slot : sparse_slots) {
std::vector<size_t> lod_data{0};
std::vector<int64_t> batch_feasign;
......@@ -242,11 +251,11 @@ void ReadThread(const std::vector<std::string>& file_list,
lod_datas.push_back(label_tensor);
queue->Push(lod_datas);
VLOG(40) << "push one data, queue_size=" << queue->Size();
VLOG(4) << "push one data, queue_size=" << queue->Size();
}
(*thread_status)[thread_id] = Stopped;
VLOG(30) << "set status to stopped, thread " << thread_id << " exited";
VLOG(3) << "set status to stopped, thread " << thread_id << " exited";
}
} // namespace reader
......
......@@ -36,7 +36,9 @@ namespace reader {
enum ReaderThreadStatus { Running, Stopped };
void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size,
const std::string& file_type, const std::string& file_format,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue);
......@@ -47,11 +49,18 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
class CTRReader : public framework::FileReader {
public:
explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
int batch_size, int thread_num,
const std::vector<std::string>& slots,
const std::vector<std::string>& file_list)
: batch_size_(batch_size), slots_(slots), file_list_(file_list) {
CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
int batch_size, int thread_num, const std::string& file_type,
const std::string& file_format,
const std::vector<std::string>& dense_slots,
const std::vector<std::string>& sparse_slots,
const std::vector<std::string>& file_list)
: batch_size_(batch_size),
file_type_(file_type),
file_format_(file_format),
dense_slots_(dense_slots),
sparse_slots_(sparse_slots),
file_list_(file_list) {
PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
......@@ -97,7 +106,8 @@ class CTRReader : public framework::FileReader {
VLOG(3) << "thread_num " << thread_num_;
for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread(
std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_,
std::bind(&ReadThread, file_groups_[thread_id], file_type_,
file_format_, dense_slots_, sparse_slots_, batch_size_,
thread_id, &read_thread_status_, queue_)));
}
monitor_thread_.reset(new std::thread(
......@@ -119,7 +129,10 @@ class CTRReader : public framework::FileReader {
private:
size_t thread_num_;
const int batch_size_;
const std::vector<std::string> slots_;
const std::string file_type_;
const std::string file_format_;
const std::vector<std::string> dense_slots_;
const std::vector<std::string> sparse_slots_;
const std::vector<std::string> file_list_;
std::shared_ptr<LoDTensorBlockingQueue> queue_;
std::vector<std::unique_ptr<std::thread>> read_threads_;
......
......@@ -132,24 +132,27 @@ TEST(CTR_READER, read_data) {
int batch_size = 3;
int thread_num = 1;
std::vector<std::string> slots = {"6002", "6003"};
std::vector<std::string> sparse_slots = {"6002", "6003"};
std::vector<std::string> file_list;
for (int i = 0; i < thread_num; ++i) {
file_list.push_back(gz_file_name);
}
CTRReader reader(queue, batch_size, thread_num, slots, file_list);
CTRReader reader(queue, batch_size, thread_num, "gzip", "plain", {},
sparse_slots, file_list);
reader.Start();
size_t batch_num =
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
data_slot_6003, batch_num, batch_size, queue, &reader);
check_all_data(ctr_data, sparse_slots, label_dims, label_value,
data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
&reader);
reader.Shutdown();
reader.Start();
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
data_slot_6003, batch_num, batch_size, queue, &reader);
check_all_data(ctr_data, sparse_slots, label_dims, label_value,
data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
&reader);
reader.Shutdown();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册