Commit 803e2ed9, authored by Qiao Longfei

add ctr_reader_test and fix bug

Parent: c8bd5210
......@@ -17,6 +17,7 @@ endfunction()
cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
cc_library(ctr_reader SRCS ctr_reader.cc DEPS reader simple_threadpool boost gzstream)
cc_test(ctr_reader_test SRCS ctr_reader_test.cc DEPS ctr_reader)
reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader)
reader_library(create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
......
......@@ -46,29 +46,44 @@ static inline void string_split(const std::string& s, const char delimiter,
}
static inline void parse_line(
const std::string& line, const std::vector<std::string>& slots,
const std::string& line,
const std::unordered_map<std::string, size_t>& slot_to_index,
int64_t* label,
std::unordered_map<std::string, std::vector<int64_t>>* slots_to_data) {
std::unordered_map<std::string, std::vector<int64_t>>* slot_to_data) {
std::vector<std::string> ret;
string_split(line, ' ', &ret);
*label = std::stoi(ret[2]) > 0;
for (size_t i = 3; i < ret.size(); ++i) {
const std::string& item = ret[i];
std::vector<std::string> slot_and_feasign;
string_split(item, ':', &slot_and_feasign);
if (slot_and_feasign.size() == 2) {
const std::string& slot = slot_and_feasign[1];
int64_t feasign = std::strtoll(slot_and_feasign[0].c_str(), NULL, 10);
(*slots_to_data)[slot_and_feasign[1]].push_back(feasign);
std::vector<std::string> feasign_and_slot;
string_split(item, ':', &feasign_and_slot);
auto& slot = feasign_and_slot[1];
if (feasign_and_slot.size() == 2 &&
slot_to_index.find(slot) != slot_to_index.end()) {
const std::string& slot = feasign_and_slot[1];
int64_t feasign = std::strtoll(feasign_and_slot[0].c_str(), NULL, 10);
(*slot_to_data)[feasign_and_slot[1]].push_back(feasign);
}
}
// NOTE: if the slot has no value, then fill [0] as its data.
for (auto& slot : slots) {
if (slots_to_data->find(slot) == slots_to_data->end()) {
(*slots_to_data)[slot].push_back(0);
for (auto& item : slot_to_index) {
if (slot_to_data->find(item.first) == slot_to_data->end()) {
(*slot_to_data)[item.first].push_back(0);
}
}
}
static void print_map(
std::unordered_map<std::string, std::vector<int64_t>>* map) {
for (auto it = map->begin(); it != map->end(); ++it) {
std::cout << it->first << " -> ";
std::cout << "[";
for (auto& i : it->second) {
std::cout << i << " ";
}
std::cout << "]\n";
}
}
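For reference, parse_line expects each record to be a space-separated line whose third field (index 2) is the click count used as the label, followed by feasign:slot pairs from index 3 onward; only slots listed in slot_to_index are kept (the check added in this hunk), and any slot that never appears in the line is padded with a single 0. The sample line and slot ids below are hypothetical; this is a minimal standalone sketch of the same parsing rule, not the reader's actual I/O path.

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Hypothetical record: two leading id fields, a click count, then feasign:slot pairs.
  std::string line = "id1 id2 1 12345:6003 67890:6004 11111:9999";
  std::unordered_map<std::string, size_t> slot_to_index = {{"6003", 0}, {"6004", 1}};

  std::vector<std::string> fields;
  std::istringstream iss(line);
  for (std::string f; iss >> f;) fields.push_back(f);

  int64_t label = std::stoi(fields[2]) > 0;  // clicks > 0 -> positive label
  std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
  for (size_t i = 3; i < fields.size(); ++i) {
    auto pos = fields[i].find(':');
    std::string feasign = fields[i].substr(0, pos);
    std::string slot = fields[i].substr(pos + 1);
    if (slot_to_index.count(slot)) {  // slot "9999" is dropped here
      slot_to_data[slot].push_back(std::strtoll(feasign.c_str(), nullptr, 10));
    }
  }
  // Pad slots that never appeared with a single 0, mirroring the NOTE above.
  for (auto& item : slot_to_index) {
    if (!slot_to_data.count(item.first)) slot_to_data[item.first].push_back(0);
  }

  std::cout << "label=" << label << " 6003=" << slot_to_data["6003"][0]
            << " 6004=" << slot_to_data["6004"][0] << "\n";
  return 0;
}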
......@@ -126,7 +141,14 @@ void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "reader thread start! thread_id = " << thread_id;
(*thread_status)[thread_id] = Running;
VLOG(3) << "set status to running";
std::unordered_map<std::string, size_t> slot_to_index;
for (size_t i = 0; i < slots.size(); ++i) {
slot_to_index[slots[i]] = i;
}
std::string line;
......@@ -135,21 +157,29 @@ void ReadThread(const std::vector<std::string>& file_list,
MultiGzipReader reader(file_list);
VLOG(3) << "reader inited";
while (reader.HasNext()) {
// read all files
batch_data.clear();
batch_label.clear();
// read batch_size data
for (int i = 0; i < batch_size; ++i) {
if (reader.HasNext()) {
reader.NextLine(&line);
std::unordered_map<std::string, std::vector<int64_t>> slots_to_data;
std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
int64_t label;
parse_line(line, slots, &label, &slots_to_data);
batch_data.push_back(slots_to_data);
parse_line(line, slot_to_index, &label, &slot_to_data);
batch_data.push_back(slot_to_data);
batch_label.push_back(label);
} else {
break;
}
}
VLOG(3) << "read one batch, batch_size = " << batch_data.size();
print_map(&batch_data[0]);
std::vector<framework::LoDTensor> lod_datas;
// first insert tensor for each slot
......@@ -159,9 +189,9 @@ void ReadThread(const std::vector<std::string>& file_list,
for (size_t i = 0; i < batch_data.size(); ++i) {
auto& feasign = batch_data[i][slot];
lod_data.push_back(lod_data.back() + feasign.size());
batch_feasign.insert(feasign.end(), feasign.begin(), feasign.end());
batch_feasign.insert(batch_feasign.end(), feasign.begin(),
feasign.end());
}
framework::LoDTensor lod_tensor;
......@@ -174,6 +204,8 @@ void ReadThread(const std::vector<std::string>& file_list,
lod_datas.push_back(lod_tensor);
}
VLOG(3) << "convert data to tensor";
// insert label tensor
framework::LoDTensor label_tensor;
int64_t* label_tensor_data = label_tensor.mutable_data<int64_t>(
......@@ -182,10 +214,12 @@ void ReadThread(const std::vector<std::string>& file_list,
memcpy(label_tensor_data, batch_label.data(), batch_label.size());
lod_datas.push_back(label_tensor);
VLOG(3) << "push one data";
queue->Push(lod_datas);
}
(*thread_status)[thread_id] = Stopped;
VLOG(3) << "thread " << thread_id << " exited";
}
} // namespace reader
......
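The batch-to-tensor step in ReadThread flattens every slot's feasigns across the batch and records per-example offsets (the LoD), so example i owns the slice [lod[i], lod[i+1]) of the flattened buffer. Below is a minimal sketch of that offset bookkeeping with plain vectors standing in for framework::LoDTensor and hypothetical batch contents; it also shows the insert-position fix made in this commit (batch_feasign.end() instead of feasign.end()).

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  using SlotMap = std::unordered_map<std::string, std::vector<int64_t>>;
  // Hypothetical parsed batch: two examples, one slot "6003".
  std::vector<SlotMap> batch_data = {
      {{"6003", {101, 102}}},   // example 0 has two feasigns
      {{"6003", {0}}}};         // example 1 was padded with a single 0

  std::vector<size_t> lod = {0};       // LoD offsets, one entry per example boundary
  std::vector<int64_t> batch_feasign;  // flattened feasigns for this slot
  for (size_t i = 0; i < batch_data.size(); ++i) {
    auto& feasign = batch_data[i]["6003"];
    lod.push_back(lod.back() + feasign.size());
    // Append at batch_feasign.end(); the old code mistakenly passed feasign.end().
    batch_feasign.insert(batch_feasign.end(), feasign.begin(), feasign.end());
  }

  // lod = {0, 2, 3}: example 0 -> [0, 2), example 1 -> [2, 3).
  for (size_t v : lod) std::cout << v << " ";
  std::cout << "| ";
  for (int64_t v : batch_feasign) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}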
......@@ -47,15 +47,15 @@ class CTRReader : public framework::FileReader {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
thread_num_ =
file_list_.size() > thread_num_ ? thread_num_ : file_list_.size();
file_list_.size() > thread_num ? thread_num : file_list_.size();
queue_ = queue;
SplitFiles();
for (int i = 0; i < thread_num; ++i) {
for (int i = 0; i < thread_num_; ++i) {
read_thread_status_.push_back(Stopped);
}
}
~CTRReader() { queue_->Close(); }
~CTRReader() { Shutdown(); }
void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success;
......@@ -74,8 +74,11 @@ class CTRReader : public framework::FileReader {
void Start() override {
VLOG(3) << "Start reader";
PADDLE_ENFORCE_EQ(read_threads_.size(), 0, "read thread should be empty!");
queue_->ReOpen();
for (int thread_id = 0; thread_id < file_groups_.size(); thread_id++) {
VLOG(3) << "reopen success";
VLOG(3) << "thread_num " << thread_num_;
for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread(
std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_,
thread_id, &read_thread_status_, queue_)));
......@@ -86,7 +89,10 @@ class CTRReader : public framework::FileReader {
void SplitFiles() {
file_groups_.resize(thread_num_);
for (int i = 0; i < file_list_.size(); ++i) {
file_groups_[i % thread_num_].push_back(file_list_[i]);
auto& file_name = file_list_[i];
std::ifstream f(file_name.c_str());
PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name);
file_groups_[i % thread_num_].push_back(file_name);
}
}
......
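SplitFiles distributes the file list round-robin over the reader threads, after thread_num_ has been clamped to the number of files and each file has been checked for existence. A standalone sketch of the same grouping, with hypothetical file names:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical inputs: thread_num is clamped so no thread gets an empty group.
  std::vector<std::string> file_list = {"part-0.gz", "part-1.gz", "part-2.gz"};
  size_t thread_num = 2;
  thread_num = file_list.size() > thread_num ? thread_num : file_list.size();

  std::vector<std::vector<std::string>> file_groups(thread_num);
  for (size_t i = 0; i < file_list.size(); ++i) {
    file_groups[i % thread_num].push_back(file_list[i]);  // round-robin assignment
  }

  for (size_t t = 0; t < file_groups.size(); ++t) {
    std::cout << "thread " << t << ":";
    for (auto& f : file_groups[t]) std::cout << " " << f;
    std::cout << "\n";
  }
  return 0;
}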
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/ctr_reader.h"
using paddle::operators::reader::LoDTensorBlockingQueue;
using paddle::operators::reader::LoDTensorBlockingQueueHolder;
using paddle::operators::reader::CTRReader;
TEST(CTR_READER, read_data) {
LoDTensorBlockingQueueHolder queue_holder;
int capacity = 64;
queue_holder.InitOnce(capacity, {}, false);
std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
int batch_size = 10;
int thread_num = 1;
std::vector<std::string> slots = {"6003", "6004"};
std::vector<std::string> file_list = {
"/Users/qiaolongfei/project/gzip_test/part-00000-A.gz",
"/Users/qiaolongfei/project/gzip_test/part-00000-A.gz"};
CTRReader reader(queue, batch_size, thread_num, slots, file_list);
reader.Start();
//
// std::vector<LoDTensor> out;
// reader.ReadNext(&out);
}
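The ReadNext check is left commented out in this test. A hedged sketch of how it might be completed inside the test body above, assuming the .gz files at the hard-coded local paths exist and contain data (shapes and values depend entirely on those files):

// Hypothetical continuation after reader.Start(): each successfully read batch
// carries one LoDTensor per slot plus a trailing label tensor.
std::vector<paddle::framework::LoDTensor> out;
for (int i = 0; i < 3; ++i) {
  reader.ReadNext(&out);
  ASSERT_EQ(out.size(), slots.size() + 1);  // {"6003", "6004"} tensors + the label tensor
}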