diff --git a/paddle/fluid/operators/reader/create_ctr_reader_op.cc b/paddle/fluid/operators/reader/create_ctr_reader_op.cc index 58a465d87a8c0da50e3eb80fefe32d50217f6990..e66263fee11dfcadb2058f868930b9d9706a0939 100644 --- a/paddle/fluid/operators/reader/create_ctr_reader_op.cc +++ b/paddle/fluid/operators/reader/create_ctr_reader_op.cc @@ -41,13 +41,16 @@ class CreateCTRReaderOp : public framework::OperatorBase { auto* queue_holder = queue_holder_var->template GetMutable(); - int thread_num = Attr("thread_num"); - std::vector slots = Attr>("slots"); - int batch_size = Attr("batch_size"); - std::vector file_list = - Attr>("file_list"); - out->Reset(std::make_shared(queue_holder->GetQueue(), batch_size, - thread_num, slots, file_list)); + auto thread_num = Attr("thread_num"); + auto sparse_slots = Attr>("sparse_slots"); + auto dense_slots = Attr>("dense_slots"); + auto batch_size = Attr("batch_size"); + auto file_type = Attr("file_type"); + auto file_format = Attr("file_format"); + auto file_list = Attr>("file_list"); + out->Reset(std::make_shared( + queue_holder->GetQueue(), batch_size, thread_num, file_type, + file_format, dense_slots, sparse_slots, file_list)); } }; @@ -58,10 +61,15 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase { "Name of the `LoDTensorBlockingQueueHolder` variable"); AddAttr("thread_num", "the thread num to read data"); AddAttr("batch_size", "the batch size of read data"); + AddAttr("file_type", "plain or gzip").SetDefault("plain"); + AddAttr("file_format", "svm or csv").SetDefault("csv"); AddAttr>("file_list", "The list of files that need to read"); AddAttr>( - "slots", "the slots that should be extract from file"); + "dense_slots", "the sparse slots id that should be extract from file") + .SetDefault({}); + AddAttr>( + "sparse_slots", "the sparse slots id that should be extract from file"); AddComment(R"DOC( Create CTRReader to support read ctr data with cpp. diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index e2f8788a9a83c7a080492c3460d84d22da46b0f5..09939576568c42168772ba90883c22a4fb74c035 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -141,40 +141,42 @@ class MultiFileReader : public Reader { void MonitorThread(std::vector* thread_status, std::shared_ptr queue) { - VLOG(30) << "monitor thread in"; + VLOG(3) << "monitor thread in"; bool reader_thread_is_running = true; while (reader_thread_is_running) { - VLOG(30) << "reader_thread_is_running"; + VLOG(3) << "reader_thread_is_running"; reader_thread_is_running = false; for (size_t i = 0; i < (*thread_status).size(); ++i) { if ((*thread_status)[i] == Running) { - VLOG(30) << "reader is running!"; + VLOG(3) << "reader is running!"; reader_thread_is_running = true; } } std::this_thread::sleep_for(std::chrono::milliseconds(1000)); } - VLOG(30) << "all reader thread is stopped, push empty data into queue"; + VLOG(3) << "all reader thread is stopped, push empty data into queue"; queue->Push({}); - VLOG(30) << "monitor thread exited"; + VLOG(3) << "monitor thread exited"; } void ReadThread(const std::vector& file_list, - const std::vector& slots, int batch_size, + const std::string& file_type, const std::string& file_format, + const std::vector& dense_slots, + const std::vector& sparse_slots, int batch_size, int thread_id, std::vector* thread_status, std::shared_ptr queue) { - VLOG(30) << "[" << thread_id << "]" - << " reader thread start! thread_id = " << thread_id; + VLOG(3) << "[" << thread_id << "]" + << " reader thread start! thread_id = " << thread_id; for (auto& file : file_list) { - VLOG(30) << "[" << thread_id << "]" - << " file " << file; + VLOG(3) << "[" << thread_id << "]" + << " file " << file; } (*thread_status)[thread_id] = Running; - VLOG(30) << "set status to running"; + VLOG(3) << "set status to running"; std::unordered_map slot_to_index; - for (size_t i = 0; i < slots.size(); ++i) { - slot_to_index[slots[i]] = i; + for (size_t i = 0; i < sparse_slots.size(); ++i) { + slot_to_index[sparse_slots[i]] = i; } std::string line; @@ -182,11 +184,18 @@ void ReadThread(const std::vector& file_list, std::vector>> batch_data; std::vector batch_label; - MultiFileReader reader(file_list); + std::unique_ptr reader; + if (file_type == "gzip") { + reader.reset(new MultiFileReader(file_list)); + } else if (file_type == "plain") { + reader.reset(new MultiFileReader(file_list)); + } else { + PADDLE_THROW("do not support file format %s", file_type); + } - VLOG(30) << "reader inited"; + VLOG(3) << "reader inited"; - while (reader.HasNext()) { + while (reader->HasNext()) { batch_data.clear(); batch_data.reserve(batch_size); @@ -195,8 +204,8 @@ void ReadThread(const std::vector& file_list, // read batch_size data for (int i = 0; i < batch_size; ++i) { - if (reader.HasNext()) { - reader.NextLine(&line); + if (reader->HasNext()) { + reader->NextLine(&line); std::unordered_map> slot_to_data; int64_t label; parse_line(line, slot_to_index, &label, &slot_to_data); @@ -209,8 +218,8 @@ void ReadThread(const std::vector& file_list, std::vector lod_datas; - // first insert tensor for each slots - for (auto& slot : slots) { + // first insert tensor for each sparse_slots + for (auto& slot : sparse_slots) { std::vector lod_data{0}; std::vector batch_feasign; @@ -242,11 +251,11 @@ void ReadThread(const std::vector& file_list, lod_datas.push_back(label_tensor); queue->Push(lod_datas); - VLOG(40) << "push one data, queue_size=" << queue->Size(); + VLOG(4) << "push one data, queue_size=" << queue->Size(); } (*thread_status)[thread_id] = Stopped; - VLOG(30) << "set status to stopped, thread " << thread_id << " exited"; + VLOG(3) << "set status to stopped, thread " << thread_id << " exited"; } } // namespace reader diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index 9b2a11bae12d242880829628faa089e1638424b0..68d587bbfc4a60fa31401b0e4fd662f6e8deb4a9 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -36,7 +36,9 @@ namespace reader { enum ReaderThreadStatus { Running, Stopped }; void ReadThread(const std::vector& file_list, - const std::vector& slots, int batch_size, + const std::string& file_type, const std::string& file_format, + const std::vector& dense_slots, + const std::vector& sparse_slots, int batch_size, int thread_id, std::vector* thread_status, std::shared_ptr queue); @@ -47,11 +49,18 @@ void MonitorThread(std::vector* thread_status, class CTRReader : public framework::FileReader { public: - explicit CTRReader(const std::shared_ptr& queue, - int batch_size, int thread_num, - const std::vector& slots, - const std::vector& file_list) - : batch_size_(batch_size), slots_(slots), file_list_(file_list) { + CTRReader(const std::shared_ptr& queue, + int batch_size, int thread_num, const std::string& file_type, + const std::string& file_format, + const std::vector& dense_slots, + const std::vector& sparse_slots, + const std::vector& file_list) + : batch_size_(batch_size), + file_type_(file_type), + file_format_(file_format), + dense_slots_(dense_slots), + sparse_slots_(sparse_slots), + file_list_(file_list) { PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); @@ -97,7 +106,8 @@ class CTRReader : public framework::FileReader { VLOG(3) << "thread_num " << thread_num_; for (int thread_id = 0; thread_id < thread_num_; thread_id++) { read_threads_.emplace_back(new std::thread( - std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_, + std::bind(&ReadThread, file_groups_[thread_id], file_type_, + file_format_, dense_slots_, sparse_slots_, batch_size_, thread_id, &read_thread_status_, queue_))); } monitor_thread_.reset(new std::thread( @@ -119,7 +129,10 @@ class CTRReader : public framework::FileReader { private: size_t thread_num_; const int batch_size_; - const std::vector slots_; + const std::string file_type_; + const std::string file_format_; + const std::vector dense_slots_; + const std::vector sparse_slots_; const std::vector file_list_; std::shared_ptr queue_; std::vector> read_threads_; diff --git a/paddle/fluid/operators/reader/ctr_reader_test.cc b/paddle/fluid/operators/reader/ctr_reader_test.cc index 5e672e9aa186bc3b7b4801f986b80e6ee6dddb1d..734bf45383c74acf1f2cd159440f03dc683403b4 100644 --- a/paddle/fluid/operators/reader/ctr_reader_test.cc +++ b/paddle/fluid/operators/reader/ctr_reader_test.cc @@ -132,24 +132,27 @@ TEST(CTR_READER, read_data) { int batch_size = 3; int thread_num = 1; - std::vector slots = {"6002", "6003"}; + std::vector sparse_slots = {"6002", "6003"}; std::vector file_list; for (int i = 0; i < thread_num; ++i) { file_list.push_back(gz_file_name); } - CTRReader reader(queue, batch_size, thread_num, slots, file_list); + CTRReader reader(queue, batch_size, thread_num, "gzip", "plain", {}, + sparse_slots, file_list); reader.Start(); size_t batch_num = std::ceil(static_cast(ctr_data.size()) / batch_size) * thread_num; - check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, - data_slot_6003, batch_num, batch_size, queue, &reader); + check_all_data(ctr_data, sparse_slots, label_dims, label_value, + data_slot_6002, data_slot_6003, batch_num, batch_size, queue, + &reader); reader.Shutdown(); reader.Start(); - check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, - data_slot_6003, batch_num, batch_size, queue, &reader); + check_all_data(ctr_data, sparse_slots, label_dims, label_value, + data_slot_6002, data_slot_6003, batch_num, batch_size, queue, + &reader); reader.Shutdown(); }