From 803e2ed9f47302b84024af89fe0b50f5b24818ba Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 19 Oct 2018 11:34:33 +0800 Subject: [PATCH] add ctr_reader_test and fix bug --- paddle/fluid/operators/reader/CMakeLists.txt | 1 + paddle/fluid/operators/reader/ctr_reader.cc | 68 ++++++++++++++----- paddle/fluid/operators/reader/ctr_reader.h | 16 +++-- .../fluid/operators/reader/ctr_reader_test.cc | 45 ++++++++++++ 4 files changed, 108 insertions(+), 22 deletions(-) create mode 100644 paddle/fluid/operators/reader/ctr_reader_test.cc diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 4ad376c61..2e019f3c1 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -17,6 +17,7 @@ endfunction() cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool) cc_library(ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool boost gzstream) +cc_test(ctr_reader_test SRCS ctr_reader_test.cc DEPS ctr_reader) reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader) reader_library(create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index 60e8d1250..55e4975b3 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -46,32 +46,47 @@ static inline void string_split(const std::string& s, const char delimiter, } static inline void parse_line( - const std::string& line, const std::vector& slots, + const std::string& line, + const std::unordered_map& slot_to_index, int64_t* label, - std::unordered_map>* slots_to_data) { + std::unordered_map>* slot_to_data) { std::vector ret; string_split(line, ' ', &ret); *label = std::stoi(ret[2]) > 0; for (size_t i = 3; i < ret.size(); ++i) { const std::string& item = ret[i]; - std::vector slot_and_feasign; - string_split(item, ':', &slot_and_feasign); - if (slot_and_feasign.size() == 2) { - const std::string& slot = slot_and_feasign[1]; - int64_t feasign = std::strtoll(slot_and_feasign[0].c_str(), NULL, 10); - (*slots_to_data)[slot_and_feasign[1]].push_back(feasign); + std::vector feasign_and_slot; + string_split(item, ':', &feasign_and_slot); + auto& slot = feasign_and_slot[1]; + if (feasign_and_slot.size() == 2 && + slot_to_index.find(slot) != slot_to_index.end()) { + const std::string& slot = feasign_and_slot[1]; + int64_t feasign = std::strtoll(feasign_and_slot[0].c_str(), NULL, 10); + (*slot_to_data)[feasign_and_slot[1]].push_back(feasign); } } // NOTE:: if the slot has no value, then fill [0] as it's data. - for (auto& slot : slots) { - if (slots_to_data->find(slot) == slots_to_data->end()) { - (*slots_to_data)[slot].push_back(0); + for (auto& item : slot_to_index) { + if (slot_to_data->find(item.first) == slot_to_data->end()) { + (*slot_to_data)[item.first].push_back(0); } } } +static void print_map( + std::unordered_map>* map) { + for (auto it = map->begin(); it != map->end(); ++it) { + std::cout << it->first << " -> "; + std::cout << "["; + for (auto& i : it->second) { + std::cout << i << " "; + } + std::cout << "]\n"; + } +} + class Reader { public: virtual ~Reader() {} @@ -126,7 +141,14 @@ void ReadThread(const std::vector& file_list, const std::vector& slots, int batch_size, int thread_id, std::vector* thread_status, std::shared_ptr queue) { + VLOG(3) << "reader thread start! thread_id = " << thread_id; (*thread_status)[thread_id] = Running; + VLOG(3) << "set status to running"; + + std::unordered_map slot_to_index; + for (size_t i = 0; i < slots.size(); ++i) { + slot_to_index[slots[i]] = i; + } std::string line; @@ -135,21 +157,29 @@ void ReadThread(const std::vector& file_list, MultiGzipReader reader(file_list); + VLOG(3) << "reader inited"; + while (reader.HasNext()) { - // read all files + batch_data.clear(); + batch_label.clear(); + + // read batch_size data for (int i = 0; i < batch_size; ++i) { if (reader.HasNext()) { reader.NextLine(&line); - std::unordered_map> slots_to_data; + std::unordered_map> slot_to_data; int64_t label; - parse_line(line, slots, &label, &slots_to_data); - batch_data.push_back(slots_to_data); + parse_line(line, slot_to_index, &label, &slot_to_data); + batch_data.push_back(slot_to_data); batch_label.push_back(label); } else { break; } } + VLOG(3) << "read one batch, batch_size = " << batch_data.size(); + print_map(&batch_data[0]); + std::vector lod_datas; // first insert tensor for each slots @@ -159,9 +189,9 @@ void ReadThread(const std::vector& file_list, for (size_t i = 0; i < batch_data.size(); ++i) { auto& feasign = batch_data[i][slot]; - lod_data.push_back(lod_data.back() + feasign.size()); - batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end()); + batch_feasign.insert(batch_feasign.end(), feasign.begin(), + feasign.end()); } framework::LoDTensor lod_tensor; @@ -174,6 +204,8 @@ void ReadThread(const std::vector& file_list, lod_datas.push_back(lod_tensor); } + VLOG(3) << "convert data to tensor"; + // insert label tensor framework::LoDTensor label_tensor; int64_t* label_tensor_data = label_tensor.mutable_data( @@ -182,10 +214,12 @@ void ReadThread(const std::vector& file_list, memcpy(label_tensor_data, batch_label.data(), batch_label.size()); lod_datas.push_back(label_tensor); + VLOG(3) << "push one data"; queue->Push(lod_datas); } (*thread_status)[thread_id] = Stopped; + VLOG(3) << "thread " << thread_id << " exited"; } } // namespace reader diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index 1006ea96c..9469d86c6 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -47,15 +47,15 @@ class CTRReader : public framework::FileReader { PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); thread_num_ = - file_list_.size() > thread_num_ ? thread_num_ : file_list_.size(); + file_list_.size() > thread_num ? thread_num : file_list_.size(); queue_ = queue; SplitFiles(); - for (int i = 0; i < thread_num; ++i) { + for (int i = 0; i < thread_num_; ++i) { read_thread_status_.push_back(Stopped); } } - ~CTRReader() { queue_->Close(); } + ~CTRReader() { Shutdown(); } void ReadNext(std::vector* out) override { bool success; @@ -74,8 +74,11 @@ class CTRReader : public framework::FileReader { void Start() override { VLOG(3) << "Start reader"; + PADDLE_ENFORCE_EQ(read_threads_.size(), 0, "read thread should be empty!"); queue_->ReOpen(); - for (int thread_id = 0; thread_id < file_groups_.size(); thread_id++) { + VLOG(3) << "reopen success"; + VLOG(3) << "thread_num " << thread_num_; + for (int thread_id = 0; thread_id < thread_num_; thread_id++) { read_threads_.emplace_back(new std::thread( std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_, thread_id, &read_thread_status_, queue_))); @@ -86,7 +89,10 @@ class CTRReader : public framework::FileReader { void SplitFiles() { file_groups_.resize(thread_num_); for (int i = 0; i < file_list_.size(); ++i) { - file_groups_[i % thread_num_].push_back(file_list_[i]); + auto& file_name = file_list_[i]; + std::ifstream f(file_name.c_str()); + PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name); + file_groups_[i % thread_num_].push_back(file_name); } } diff --git a/paddle/fluid/operators/reader/ctr_reader_test.cc b/paddle/fluid/operators/reader/ctr_reader_test.cc new file mode 100644 index 000000000..404da3c6c --- /dev/null +++ b/paddle/fluid/operators/reader/ctr_reader_test.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/operators/reader/blocking_queue.h" +#include "paddle/fluid/operators/reader/ctr_reader.h" + +using paddle::operators::reader::LoDTensorBlockingQueue; +using paddle::operators::reader::LoDTensorBlockingQueueHolder; +using paddle::operators::reader::CTRReader; + +TEST(CTR_READER, read_data) { + LoDTensorBlockingQueueHolder queue_holder; + int capacity = 64; + queue_holder.InitOnce(capacity, {}, false); + + std::shared_ptr queue = queue_holder.GetQueue(); + + int batch_size = 10; + int thread_num = 1; + std::vector slots = {"6003", "6004"}; + std::vector file_list = { + "/Users/qiaolongfei/project/gzip_test/part-00000-A.gz", + "/Users/qiaolongfei/project/gzip_test/part-00000-A.gz"}; + + CTRReader reader(queue, batch_size, thread_num, slots, file_list); + + reader.Start(); + // + // std::vector out; + // reader.ReadNext(&out); +} -- GitLab