未验证 提交 c5855506 编写于 作者: 乔龙飞 Qiao Longfei 提交者: GitHub

Merge pull request #14731 from jacquesqiao/optimize-cpp-reader

Optimize cpp reader
...@@ -359,6 +359,7 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b ...@@ -359,6 +359,7 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b
paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.contrib.reader.ctr_reader.ctr_reader ArgSpec(args=['feed_dict', 'file_type', 'file_format', 'dense_slot_index', 'sparse_slot_index', 'capacity', 'thread_num', 'batch_size', 'file_list', 'slots', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.build_compressor ArgSpec(args=['place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'config'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None)) paddle.fluid.contrib.build_compressor ArgSpec(args=['place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'config'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None))
paddle.fluid.contrib.CompressPass.__init__ ArgSpec(args=['self', 'place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'program_exe'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None)) paddle.fluid.contrib.CompressPass.__init__ ArgSpec(args=['self', 'place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'program_exe'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None))
paddle.fluid.contrib.CompressPass.add_strategy ArgSpec(args=['self', 'strategy'], varargs=None, keywords=None, defaults=None) paddle.fluid.contrib.CompressPass.add_strategy ArgSpec(args=['self', 'strategy'], varargs=None, keywords=None, defaults=None)
......
...@@ -41,13 +41,19 @@ class CreateCTRReaderOp : public framework::OperatorBase { ...@@ -41,13 +41,19 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto* queue_holder = auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>(); queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
int thread_num = Attr<int>("thread_num"); auto thread_num = Attr<int>("thread_num");
std::vector<std::string> slots = Attr<std::vector<std::string>>("slots"); auto sparse_slots = Attr<std::vector<std::string>>("sparse_slots");
int batch_size = Attr<int>("batch_size"); auto dense_slot_index = Attr<std::vector<int>>("dense_slot_index");
std::vector<std::string> file_list = auto sparse_slot_index = Attr<std::vector<int>>("sparse_slot_index");
Attr<std::vector<std::string>>("file_list"); auto batch_size = Attr<int>("batch_size");
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), batch_size, auto file_type = Attr<std::string>("file_type");
thread_num, slots, file_list)); auto file_format = Attr<std::string>("file_format");
auto file_list = Attr<std::vector<std::string>>("file_list");
DataDesc data_desc(batch_size, file_list, file_type, file_format,
dense_slot_index, sparse_slot_index, sparse_slots);
VLOG(1) << data_desc;
out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(), thread_num,
data_desc));
} }
}; };
...@@ -58,10 +64,22 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase { ...@@ -58,10 +64,22 @@ class CreateCTRReaderOpMaker : public FileReaderMakerBase {
"Name of the `LoDTensorBlockingQueueHolder` variable"); "Name of the `LoDTensorBlockingQueueHolder` variable");
AddAttr<int>("thread_num", "the thread num to read data"); AddAttr<int>("thread_num", "the thread num to read data");
AddAttr<int>("batch_size", "the batch size of read data"); AddAttr<int>("batch_size", "the batch size of read data");
AddAttr<std::string>("file_type", "plain or gzip").SetDefault("plain");
AddAttr<std::string>("file_format", "svm or csv").SetDefault("csv");
AddAttr<std::vector<std::string>>("file_list", AddAttr<std::vector<std::string>>("file_list",
"The list of files that need to read"); "The list of files that need to read");
AddAttr<std::vector<std::string>>( AddAttr<std::vector<int>>(
"slots", "the slots that should be extract from file"); "dense_slot_index",
"the dense slots id that should be extract from file")
.SetDefault({});
AddAttr<std::vector<int>>(
"sparse_slot_index",
"the sparse slots id that should be extract from file")
.SetDefault({});
AddAttr<std::vector<std::string>>("sparse_slots",
"the sparse slots id that should be "
"extract from file, used when file "
"format is svm");
AddComment(R"DOC( AddComment(R"DOC(
Create CTRReader to support read ctr data with cpp. Create CTRReader to support read ctr data with cpp.
......
...@@ -73,6 +73,9 @@ static inline void parse_line( ...@@ -73,6 +73,9 @@ static inline void parse_line(
} }
} }
// Expected svm-style record layout:
//   label slot1:fea_sign slot2:fea_sign slot1:fea_sign
// Currently a no-op: the body performs no parsing and produces no output.
static inline void parse_svm_line(const std::string& line) {
  (void)line;  // intentionally unused; avoids an unused-parameter warning
}
class Reader { class Reader {
public: public:
virtual ~Reader() {} virtual ~Reader() {}
...@@ -95,11 +98,27 @@ class GzipReader : public Reader { ...@@ -95,11 +98,27 @@ class GzipReader : public Reader {
igzstream gzstream_; igzstream gzstream_;
}; };
// PlainFileReader: a Reader that serves lines from a single file through a
// std::ifstream.
class MultiGzipReader : public Reader { class PlainFileReader : public Reader {
public:
// Opens `file_name` for reading; the stream is owned by this object.
explicit PlainFileReader(const std::string& file_name)
: stream_(file_name.c_str()) {}
~PlainFileReader() {}
// Returns true as long as peeking at the stream does not yield EOF.
bool HasNext() override { return stream_.peek() != EOF; }
// Reads the next line of the file into `*line` via std::getline.
void NextLine(std::string* line) override { std::getline(stream_, *line); }
private:
// Input stream over the file given to the constructor.
std::ifstream stream_;
};
template <typename SingleFileReader>
class MultiFileReader : public Reader {
public: public:
explicit MultiGzipReader(const std::vector<std::string>& file_list) { explicit MultiFileReader(const std::vector<std::string>& file_list) {
for (auto& file : file_list) { for (auto& file : file_list) {
readers_.emplace_back(std::make_shared<GzipReader>(file)); readers_.emplace_back(std::make_shared<SingleFileReader>(file));
} }
} }
...@@ -119,46 +138,35 @@ class MultiGzipReader : public Reader { ...@@ -119,46 +138,35 @@ class MultiGzipReader : public Reader {
} }
private: private:
std::vector<std::shared_ptr<GzipReader>> readers_; std::vector<std::shared_ptr<SingleFileReader>> readers_;
size_t current_reader_index_ = 0; size_t current_reader_index_ = 0;
}; };
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status, void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) { std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(30) << "monitor thread in"; VLOG(3) << "monitor thread in";
bool reader_thread_is_running = true; bool reader_thread_is_running = true;
while (reader_thread_is_running) { while (reader_thread_is_running) {
VLOG(30) << "reader_thread_is_running"; VLOG(3) << "reader_thread_is_running";
reader_thread_is_running = false; reader_thread_is_running = false;
for (size_t i = 0; i < (*thread_status).size(); ++i) { for (size_t i = 0; i < (*thread_status).size(); ++i) {
if ((*thread_status)[i] == Running) { if ((*thread_status)[i] == Running) {
VLOG(30) << "reader is running!"; VLOG(3) << "reader is running!";
reader_thread_is_running = true; reader_thread_is_running = true;
} }
} }
std::this_thread::sleep_for(std::chrono::milliseconds(1000)); std::this_thread::sleep_for(std::chrono::milliseconds(1000));
} }
VLOG(30) << "all reader thread is stopped, push empty data into queue"; VLOG(3) << "all reader thread is stopped, close the queue";
queue->Push({}); queue->Close();
VLOG(30) << "monitor thread exited"; VLOG(3) << "monitor thread exited";
} }
void ReadThread(const std::vector<std::string>& file_list, void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
const std::vector<std::string>& slots, int batch_size,
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) { std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(30) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
for (auto& file : file_list) {
VLOG(30) << "[" << thread_id << "]"
<< " file " << file;
}
(*thread_status)[thread_id] = Running;
VLOG(30) << "set status to running";
std::unordered_map<std::string, size_t> slot_to_index; std::unordered_map<std::string, size_t> slot_to_index;
for (size_t i = 0; i < slots.size(); ++i) { for (size_t i = 0; i < data_desc.sparse_slot_ids_.size(); ++i) {
slot_to_index[slots[i]] = i; slot_to_index[data_desc.sparse_slot_ids_[i]] = i;
} }
std::string line; std::string line;
...@@ -166,21 +174,17 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -166,21 +174,17 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data; std::vector<std::unordered_map<std::string, std::vector<int64_t>>> batch_data;
std::vector<int64_t> batch_label; std::vector<int64_t> batch_label;
MultiGzipReader reader(file_list); while (reader->HasNext()) {
VLOG(30) << "reader inited";
while (reader.HasNext()) {
batch_data.clear(); batch_data.clear();
batch_data.reserve(batch_size); batch_data.reserve(data_desc.batch_size_);
batch_label.clear(); batch_label.clear();
batch_label.reserve(batch_size); batch_label.reserve(data_desc.batch_size_);
// read batch_size data // read batch_size data
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < data_desc.batch_size_; ++i) {
if (reader.HasNext()) { if (reader->HasNext()) {
reader.NextLine(&line); reader->NextLine(&line);
std::unordered_map<std::string, std::vector<int64_t>> slot_to_data; std::unordered_map<std::string, std::vector<int64_t>> slot_to_data;
int64_t label; int64_t label;
parse_line(line, slot_to_index, &label, &slot_to_data); parse_line(line, slot_to_index, &label, &slot_to_data);
...@@ -193,8 +197,8 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -193,8 +197,8 @@ void ReadThread(const std::vector<std::string>& file_list,
std::vector<framework::LoDTensor> lod_datas; std::vector<framework::LoDTensor> lod_datas;
// first insert tensor for each slots // first insert tensor for each sparse_slots
for (auto& slot : slots) { for (auto& slot : data_desc.sparse_slot_ids_) {
std::vector<size_t> lod_data{0}; std::vector<size_t> lod_data{0};
std::vector<int64_t> batch_feasign; std::vector<int64_t> batch_feasign;
...@@ -226,11 +230,167 @@ void ReadThread(const std::vector<std::string>& file_list, ...@@ -226,11 +230,167 @@ void ReadThread(const std::vector<std::string>& file_list,
lod_datas.push_back(label_tensor); lod_datas.push_back(label_tensor);
queue->Push(lod_datas); queue->Push(lod_datas);
VLOG(40) << "push one data, queue_size=" << queue->Size(); VLOG(4) << "push one data, queue_size=" << queue->Size();
}
}
// Parses one space-separated record of the form
//   label dense_fea,dense_fea sparse_fea,sparse_fea
//
// line:         the raw text of one record.
// data_desc:    dense_slot_index_ / sparse_slot_index_ give, for each output
//               slot, the index of the space-separated column to read.
// label:        receives column 0 converted to an integer.
// dense_datas:  resized to one entry per dense slot; the comma-separated
//               floats of the selected column are appended to each entry.
// sparse_datas: resized to one entry per sparse slot; the comma-separated
//               integer ids of the selected column are appended to each entry.
//
// Columns are accessed with at(), so a record with too few columns (or a
// negative / too large slot index) raises std::out_of_range instead of reading
// out of bounds. std::stoll / std::stof raise on malformed numbers.
static inline void parse_csv_line(
    const std::string& line, const DataDesc& data_desc, int64_t* label,
    std::vector<std::vector<float>>* dense_datas,
    std::vector<std::vector<int64_t>>* sparse_datas) {
  std::vector<std::string> columns;
  string_split(line, ' ', &columns);
  // stoll (long long) is at least 64 bits everywhere, matching int64_t.
  *label = std::stoll(columns.at(0));

  const auto& dense_index = data_desc.dense_slot_index_;
  dense_datas->resize(dense_index.size());
  for (size_t i = 0; i < dense_index.size(); ++i) {
    // A negative index wraps to a huge value and is rejected by at().
    const std::string& slot_data =
        columns.at(static_cast<size_t>(dense_index[i]));
    std::vector<std::string> values;
    string_split(slot_data, ',', &values);
    auto& slot_out = (*dense_datas)[i];
    slot_out.reserve(slot_out.size() + values.size());
    for (const auto& value : values) {
      slot_out.push_back(std::stof(value));
    }
  }

  const auto& sparse_index = data_desc.sparse_slot_index_;
  sparse_datas->resize(sparse_index.size());
  for (size_t i = 0; i < sparse_index.size(); ++i) {
    const std::string& slot_data =
        columns.at(static_cast<size_t>(sparse_index[i]));
    std::vector<std::string> values;
    string_split(slot_data, ',', &values);
    auto& slot_out = (*sparse_datas)[i];
    slot_out.reserve(slot_out.size() + values.size());
    for (const auto& value : values) {
      slot_out.push_back(std::stoll(value));
    }
  }
}
// Reads csv-format records from `reader` until it is exhausted, groups them
// into batches of at most data_desc.batch_size_ rows and pushes every batch to
// `queue` as a vector of LoDTensors ordered as:
//   label, one tensor per dense slot, one tensor per sparse slot.
//
// Dense slots become [rows, width] float tensors. Sparse slots become
// [total_ids, 1] int64 tensors whose LoD records the per-row id counts.
// Every row of a batch must have the same width in each dense slot; this is
// enforced per slot because the dense rows are copied with a fixed-width
// memcpy below.
void ReadCsvData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
                 std::shared_ptr<LoDTensorBlockingQueue> queue) {
  std::string line;
  while (reader->HasNext()) {
    std::vector<int64_t> batch_label;
    batch_label.reserve(data_desc.batch_size_);

    std::vector<std::vector<std::vector<float>>> batch_dense_data;
    batch_dense_data.reserve(data_desc.batch_size_);

    std::vector<std::vector<std::vector<int64_t>>> batch_sparse_data;
    batch_sparse_data.reserve(data_desc.batch_size_);

    // read batch_size data
    for (int i = 0; i < data_desc.batch_size_; ++i) {
      if (!reader->HasNext()) {
        break;
      }
      reader->NextLine(&line);
      int64_t label;
      // Parse directly into the batch containers so the nested vectors are
      // not copied a second time.
      batch_dense_data.emplace_back();
      batch_sparse_data.emplace_back();
      parse_csv_line(line, data_desc, &label, &batch_dense_data.back(),
                     &batch_sparse_data.back());
      batch_label.push_back(label);

      // Every dense slot of this row must be as wide as in the first row of
      // the batch; otherwise the memcpy below would read past the row buffer.
      const auto& first_row = batch_dense_data.front();
      const auto& this_row = batch_dense_data.back();
      PADDLE_ENFORCE_EQ(first_row.size(), this_row.size(),
                        "dense data should have the same shape");
      for (size_t slot = 0; slot < this_row.size(); ++slot) {
        PADDLE_ENFORCE_EQ(first_row[slot].size(), this_row[slot].size(),
                          "dense data should have the same shape");
      }
    }

    // the order of output data is label, dense_datas, sparse_datas
    std::vector<framework::LoDTensor> lod_datas;

    // insert label tensor
    framework::LoDTensor label_tensor;
    auto* label_tensor_data = label_tensor.mutable_data<int64_t>(
        framework::make_ddim({static_cast<int64_t>(batch_label.size()), 1}),
        platform::CPUPlace());
    memcpy(label_tensor_data, batch_label.data(),
           batch_label.size() * sizeof(int64_t));
    lod_datas.push_back(label_tensor);

    // insert tensor for each dense_slots
    for (size_t i = 0; i < data_desc.dense_slot_index_.size(); ++i) {
      framework::LoDTensor lod_tensor;
      size_t width = batch_dense_data[0][i].size();
      auto* tensor_data = lod_tensor.mutable_data<float>(
          framework::make_ddim(
              {static_cast<int64_t>(batch_dense_data.size()),  // batch_size
               static_cast<int64_t>(width)}),
          platform::CPUPlace());

      for (size_t j = 0; j < batch_dense_data.size(); ++j) {
        auto& dense_data_row = batch_dense_data[j][i];
        memcpy(tensor_data + j * width, dense_data_row.data(),
               width * sizeof(float));
      }

      lod_datas.push_back(lod_tensor);
    }

    // insert tensor for each sparse_slots
    for (size_t i = 0; i < data_desc.sparse_slot_index_.size(); ++i) {
      // lod_data holds cumulative id counts: row r owns the id range
      // [lod_data[r], lod_data[r + 1]).
      std::vector<size_t> lod_data{0};
      std::vector<int64_t> batch_feasign;

      for (size_t row_idx = 0; row_idx < batch_sparse_data.size(); ++row_idx) {
        auto& sparse_ids = batch_sparse_data[row_idx][i];
        lod_data.push_back(lod_data.back() + sparse_ids.size());
        batch_feasign.insert(batch_feasign.end(), sparse_ids.begin(),
                             sparse_ids.end());
      }

      framework::LoDTensor lod_tensor;
      framework::LoD lod{lod_data};
      lod_tensor.set_lod(lod);
      int64_t* tensor_data = lod_tensor.mutable_data<int64_t>(
          framework::make_ddim({static_cast<int64_t>(batch_feasign.size()), 1}),
          platform::CPUPlace());
      memcpy(tensor_data, batch_feasign.data(),
             batch_feasign.size() * sizeof(int64_t));
      lod_datas.push_back(lod_tensor);
    }

    queue->Push(lod_datas);
    VLOG(4) << "push one data, queue_size=" << queue->Size();
  }
}
// Body of one reader thread: marks slot `thread_id` of `thread_status` as
// Running, builds a multi-file reader over `file_list` according to
// data_desc.file_type_, dispatches to the svm or csv batch reader according to
// data_desc.file_format_, and finally marks the slot as Stopped.
void ReadThread(const std::vector<std::string>& file_list,
const DataDesc& data_desc, int thread_id,
std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue) {
VLOG(3) << "[" << thread_id << "]"
<< " reader thread start! thread_id = " << thread_id;
for (auto& file : file_list) {
VLOG(3) << "[" << thread_id << "]"
<< " file " << file;
}
// NOTE(review): this slot is also read by MonitorThread; no synchronization
// around thread_status is visible in this chunk -- confirm it is safe.
(*thread_status)[thread_id] = Running;
VLOG(3) << "set status to running";
// Pick the per-file reader implementation from the declared file type.
std::shared_ptr<Reader> reader;
if (data_desc.file_type_ == "gzip") {
reader.reset(new MultiFileReader<GzipReader>(file_list));
} else if (data_desc.file_type_ == "plain") {
reader.reset(new MultiFileReader<PlainFileReader>(file_list));
} else {
// NOTE(review): the message says "file format" but the value being rejected
// is file_type_.
PADDLE_THROW("do not support file format %s", data_desc.file_type_);
}
VLOG(3) << "reader inited";
// NOTE(review): there is no else branch here, so a file_format_ other than
// "svm" or "csv" reads nothing and falls through to the Stopped status.
if (data_desc.file_format_ == "svm") {
ReadSvmData(data_desc, reader, queue);
} else if (data_desc.file_format_ == "csv") {
ReadCsvData(data_desc, reader, queue);
} }
(*thread_status)[thread_id] = Stopped; (*thread_status)[thread_id] = Stopped;
VLOG(30) << "set status to stopped, thread " << thread_id << " exited"; VLOG(3) << "set status to stopped, thread " << thread_id << " exited";
} }
} // namespace reader } // namespace reader
......
...@@ -36,9 +36,63 @@ namespace reader { ...@@ -36,9 +36,63 @@ namespace reader {
enum ReaderThreadStatus { Running, Stopped }; enum ReaderThreadStatus { Running, Stopped };
// Immutable description of the data a CTR reader consumes: batch size, input
// files, how the files are stored and how each record is laid out.
struct DataDesc {
  // Number of records per batch.
  const int batch_size_;
  // Files to read.
  const std::vector<std::string> file_names_;
  // gzip or plain
  const std::string file_type_;
  // csv or svm
  const std::string file_format_;

  // used for csv data format
  const std::vector<int> dense_slot_index_;
  const std::vector<int> sparse_slot_index_;

  // used for svm data format
  const std::vector<std::string> sparse_slot_ids_;

  // Copies every argument into the same-named member.
  DataDesc(int batch_size, const std::vector<std::string>& file_names,
           const std::string& file_type, const std::string& file_format,
           const std::vector<int>& dense_slot_index,
           const std::vector<int>& sparse_slot_index,
           const std::vector<std::string>& sparse_slot_ids)
      : batch_size_(batch_size),
        file_names_(file_names),
        file_type_(file_type),
        file_format_(file_format),
        dense_slot_index_(dense_slot_index),
        sparse_slot_index_(sparse_slot_index),
        sparse_slot_ids_(sparse_slot_ids) {}
};
// Writes a multi-line, human-readable dump of `data_desc` to `os`: the scalar
// fields first, then each list field as a brace-enclosed sequence in which
// every element is followed by a comma. Returns `os` to allow chaining.
inline std::ostream& operator<<(std::ostream& os, const DataDesc& data_desc) {
  os << "data_desc:\n"
     << "\tbatch_size -> " << data_desc.batch_size_ << "\n"
     << "\tfile_type -> " << data_desc.file_type_ << "\n"
     << "\tfile_format -> " << data_desc.file_format_ << "\n";

  os << "\tfile_names -> {";
  for (auto it = data_desc.file_names_.begin();
       it != data_desc.file_names_.end(); ++it) {
    os << *it << ",";
  }
  os << "}\n";

  os << "\tdense_slot_index -> {";
  for (auto it = data_desc.dense_slot_index_.begin();
       it != data_desc.dense_slot_index_.end(); ++it) {
    os << *it << ",";
  }
  os << "}\n";

  os << "\tsparse_slot_index_ -> {";
  for (auto it = data_desc.sparse_slot_index_.begin();
       it != data_desc.sparse_slot_index_.end(); ++it) {
    os << *it << ",";
  }
  os << "}\n";

  os << "\tsparse_slot_ids_ -> {";
  for (auto it = data_desc.sparse_slot_ids_.begin();
       it != data_desc.sparse_slot_ids_.end(); ++it) {
    os << *it << ",";
  }
  os << "}\n";

  return os;
}
void ReadThread(const std::vector<std::string>& file_list, void ReadThread(const std::vector<std::string>& file_list,
const std::vector<std::string>& slots, int batch_size, const DataDesc& data_desc, int thread_id,
int thread_id, std::vector<ReaderThreadStatus>* thread_status, std::vector<ReaderThreadStatus>* thread_status,
std::shared_ptr<LoDTensorBlockingQueue> queue); std::shared_ptr<LoDTensorBlockingQueue> queue);
// monitor all running thread, if they are all stopped, // monitor all running thread, if they are all stopped,
...@@ -48,15 +102,15 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status, ...@@ -48,15 +102,15 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
class CTRReader : public framework::FileReader { class CTRReader : public framework::FileReader {
public: public:
explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue, CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
int batch_size, size_t thread_num, int thread_num, const DataDesc& data_desc)
const std::vector<std::string>& slots, : data_desc_(data_desc) {
const std::vector<std::string>& file_list)
: batch_size_(batch_size), slots_(slots), file_list_(file_list) {
PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!"); PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); PADDLE_ENFORCE_GT(data_desc_.file_names_.size(), 0,
thread_num_ = std::min<size_t>(file_list_.size(), thread_num); "file list should not be empty");
thread_num_ = std::min<size_t>(data_desc_.file_names_.size(), thread_num);
queue_ = queue; queue_ = queue;
SplitFiles(); SplitFiles();
for (size_t i = 0; i < thread_num_; ++i) { for (size_t i = 0; i < thread_num_; ++i) {
...@@ -64,7 +118,7 @@ class CTRReader : public framework::FileReader { ...@@ -64,7 +118,7 @@ class CTRReader : public framework::FileReader {
} }
} }
~CTRReader() {} ~CTRReader() { Shutdown(); }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
bool success; bool success;
...@@ -81,7 +135,10 @@ class CTRReader : public framework::FileReader { ...@@ -81,7 +135,10 @@ class CTRReader : public framework::FileReader {
for (auto& read_thread : read_threads_) { for (auto& read_thread : read_threads_) {
read_thread->join(); read_thread->join();
} }
if (monitor_thread_) {
monitor_thread_->join(); monitor_thread_->join();
}
read_threads_.clear(); read_threads_.clear();
monitor_thread_.reset(nullptr); monitor_thread_.reset(nullptr);
...@@ -95,9 +152,9 @@ class CTRReader : public framework::FileReader { ...@@ -95,9 +152,9 @@ class CTRReader : public framework::FileReader {
queue_->ReOpen(); queue_->ReOpen();
VLOG(3) << "reopen success"; VLOG(3) << "reopen success";
VLOG(3) << "thread_num " << thread_num_; VLOG(3) << "thread_num " << thread_num_;
for (size_t thread_id = 0; thread_id < thread_num_; thread_id++) { for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread(std::bind( read_threads_.emplace_back(new std::thread(std::bind(
&ReadThread, file_groups_[thread_id], slots_, batch_size_, &ReadThread, file_groups_[thread_id], data_desc_,
static_cast<int>(thread_id), &read_thread_status_, queue_))); static_cast<int>(thread_id), &read_thread_status_, queue_)));
} }
monitor_thread_.reset(new std::thread( monitor_thread_.reset(new std::thread(
...@@ -108,8 +165,8 @@ class CTRReader : public framework::FileReader { ...@@ -108,8 +165,8 @@ class CTRReader : public framework::FileReader {
private: private:
void SplitFiles() { void SplitFiles() {
file_groups_.resize(thread_num_); file_groups_.resize(thread_num_);
for (size_t i = 0; i < file_list_.size(); ++i) { for (size_t i = 0; i < data_desc_.file_names_.size(); ++i) {
auto& file_name = file_list_[i]; auto& file_name = data_desc_.file_names_[i];
std::ifstream f(file_name.c_str()); std::ifstream f(file_name.c_str());
PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name); PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name);
file_groups_[i % thread_num_].push_back(file_name); file_groups_[i % thread_num_].push_back(file_name);
...@@ -118,9 +175,7 @@ class CTRReader : public framework::FileReader { ...@@ -118,9 +175,7 @@ class CTRReader : public framework::FileReader {
private: private:
size_t thread_num_; size_t thread_num_;
const int batch_size_; const DataDesc data_desc_;
const std::vector<std::string> slots_;
const std::vector<std::string> file_list_;
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
std::vector<std::unique_ptr<std::thread>> read_threads_; std::vector<std::unique_ptr<std::thread>> read_threads_;
std::unique_ptr<std::thread> monitor_thread_; std::unique_ptr<std::thread> monitor_thread_;
......
...@@ -36,6 +36,7 @@ using paddle::framework::LoD; ...@@ -36,6 +36,7 @@ using paddle::framework::LoD;
using paddle::framework::DDim; using paddle::framework::DDim;
using paddle::platform::CPUPlace; using paddle::platform::CPUPlace;
using paddle::framework::make_ddim; using paddle::framework::make_ddim;
using paddle::operators::reader::DataDesc;
static void generatedata(const std::vector<std::string>& data, static void generatedata(const std::vector<std::string>& data,
const std::string& file_name) { const std::string& file_name) {
...@@ -126,30 +127,103 @@ TEST(CTR_READER, read_data) { ...@@ -126,30 +127,103 @@ TEST(CTR_READER, read_data) {
LoDTensorBlockingQueueHolder queue_holder; LoDTensorBlockingQueueHolder queue_holder;
int capacity = 64; int capacity = 64;
queue_holder.InitOnce(capacity, {}, false); queue_holder.InitOnce(capacity, false);
std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue(); std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
int batch_size = 3; int batch_size = 3;
int thread_num = 1; int thread_num = 1;
std::vector<std::string> slots = {"6002", "6003"}; std::vector<std::string> sparse_slots = {"6002", "6003"};
std::vector<std::string> file_list; std::vector<std::string> file_list;
for (int i = 0; i < thread_num; ++i) { for (int i = 0; i < thread_num; ++i) {
file_list.push_back(gz_file_name); file_list.push_back(gz_file_name);
} }
CTRReader reader(queue, batch_size, thread_num, slots, file_list); DataDesc data_desc(batch_size, file_list, "gzip", "svm", {}, {},
sparse_slots);
CTRReader reader(queue, thread_num, data_desc);
reader.Start(); reader.Start();
size_t batch_num = size_t batch_num =
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num; std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, check_all_data(ctr_data, sparse_slots, label_dims, label_value,
data_slot_6003, batch_num, batch_size, queue, &reader); data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
&reader);
reader.Shutdown(); reader.Shutdown();
reader.Start(); reader.Start();
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, check_all_data(ctr_data, sparse_slots, label_dims, label_value,
data_slot_6003, batch_num, batch_size, queue, &reader); data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
&reader);
reader.Shutdown();
}
// Test helper: writes every string in `data`, in order and without adding
// separators, to the file `file_name`. Enforces that the stream is good
// after opening and again after closing.
static void GenereteCsvData(const std::string& file_name,
                            const std::vector<std::string>& data) {
  std::ofstream csv_file(file_name.c_str());
  PADDLE_ENFORCE(csv_file.good(), "open file %s failed!", file_name);

  for (size_t idx = 0; idx < data.size(); ++idx) {
    csv_file << data[idx];
  }

  csv_file.close();
  PADDLE_ENFORCE(csv_file.good(), "save file %s failed!", file_name);
}
static void CheckReadCsvOut(const std::vector<LoDTensor>& out) {
ASSERT_EQ(out.size(), 3);
ASSERT_EQ(out[0].dims()[1], 1);
ASSERT_EQ(out[1].dims()[1], 2);
ASSERT_EQ(out[2].dims()[1], 1);
for (size_t i = 0; i < out[0].numel(); ++i) {
int64_t label = out[0].data<int64_t>()[i];
auto& dense_dim = out[1].dims();
for (size_t j = 0; j < dense_dim[1]; ++j) {
ASSERT_EQ(out[1].data<float>()[i * dense_dim[1] + j],
static_cast<float>(label + 0.1));
}
auto& sparse_lod = out[2].lod();
for (size_t j = sparse_lod[0][i]; j < sparse_lod[0][i + 1]; ++j) {
ASSERT_EQ(out[2].data<int64_t>()[j], label);
}
}
}
// End-to-end check of the plain-file / csv path: writes a small csv file,
// reads it back through CTRReader and validates every batch with
// CheckReadCsvOut. The Start/read/Shutdown cycle is run twice on the same
// reader object.
TEST(CTR_READER, read_csv_data) {
std::string file_name = "test_ctr_reader_data.csv";
// Each record is "label dense,dense sparse,sparse,sparse,sparse": the dense
// values are label + 0.1 and the sparse ids repeat the label, which is what
// CheckReadCsvOut asserts.
const std::vector<std::string> csv_data = {
"0 0.1,0.1 0,0,0,0\n", "1 1.1,1.1 1,1,1,1\n", "2 2.1,2.1 2,2,2,2\n",
"3 3.1,3.1 3,3,3,3\n",
};
GenereteCsvData(file_name, csv_data);
LoDTensorBlockingQueueHolder queue_holder;
int capacity = 64;
queue_holder.InitOnce(capacity, false);
std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
int batch_size = 3;
int thread_num = 1;
// One copy of the file per reader thread.
std::vector<std::string> file_list;
for (int i = 0; i < thread_num; ++i) {
file_list.push_back(file_name);
}
// Column 1 is the only dense slot, column 2 the only sparse slot; the svm
// slot-id list is left empty.
DataDesc data_desc(batch_size, file_list, "plain", "csv", {1}, {2}, {});
CTRReader reader(queue, thread_num, data_desc);
for (size_t i = 0; i < 2; ++i) {
reader.Start();
std::vector<LoDTensor> out;
// Drain batches until ReadNext leaves `out` empty.
while (true) {
reader.ReadNext(&out);
if (out.empty()) {
break;
}
CheckReadCsvOut(out);
}
reader.Shutdown(); reader.Shutdown();
}
} }
...@@ -32,10 +32,8 @@ class LoDTensorBlockingQueue { ...@@ -32,10 +32,8 @@ class LoDTensorBlockingQueue {
friend class LoDTensorBlockingQueueHolder; friend class LoDTensorBlockingQueueHolder;
private: private:
LoDTensorBlockingQueue(size_t capacity, explicit LoDTensorBlockingQueue(size_t capacity, bool speed_test_mode = false)
const std::vector<framework::DDim>& dims, : queue_(capacity, speed_test_mode) {}
bool speed_test_mode = false)
: queue_(capacity, speed_test_mode), dims_(dims) {}
public: public:
bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) { bool Push(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
...@@ -65,17 +63,15 @@ class LoDTensorBlockingQueue { ...@@ -65,17 +63,15 @@ class LoDTensorBlockingQueue {
private: private:
BlockingQueue<std::vector<framework::LoDTensor>> queue_; BlockingQueue<std::vector<framework::LoDTensor>> queue_;
std::vector<framework::DDim> dims_;
}; };
class LoDTensorBlockingQueueHolder { class LoDTensorBlockingQueueHolder {
public: public:
void InitOnce(size_t capacity, const std::vector<framework::DDim>& dims, void InitOnce(size_t capacity, bool speed_test_mode = false) {
bool speed_test_mode = false) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
queue_ == nullptr, queue_ == nullptr,
"LoDTensorBlockingQueueHolder::InitOnce() can only be called once"); "LoDTensorBlockingQueueHolder::InitOnce() can only be called once");
queue_.reset(new LoDTensorBlockingQueue(capacity, dims, speed_test_mode)); queue_.reset(new LoDTensorBlockingQueue(capacity, speed_test_mode));
} }
inline const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue() const { inline const std::shared_ptr<LoDTensorBlockingQueue>& GetQueue() const {
......
...@@ -27,13 +27,13 @@ class ReadInferShape : public framework::InferShapeBase { ...@@ -27,13 +27,13 @@ class ReadInferShape : public framework::InferShapeBase {
"The ReadOp must take a reader as input."); "The ReadOp must take a reader as input.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"), PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"The ReadOp should be assigned with output."); "The ReadOp should be assigned with output.");
if (!ctx->IsRuntime() && ctx->Attrs().Get<bool>("infer_out")) {
std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader"); std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
std::vector<std::string> out_names = ctx->Outputs("Out"); std::vector<std::string> out_names = ctx->Outputs("Out");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
reader_dims.size(), out_names.size(), reader_dims.size(), out_names.size(),
"The reader's dim number doesn't match the output number."); "The reader's dim number doesn't match the output number.");
ctx->SetOutputsDim("Out", reader_dims); ctx->SetOutputsDim("Out", reader_dims);
if (!ctx->IsRuntime()) {
auto in_desc = auto in_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Reader")[0]); boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Reader")[0]);
auto in_lod_levels = in_desc->GetLoDLevels(); auto in_lod_levels = in_desc->GetLoDLevels();
...@@ -53,6 +53,8 @@ class ReadInferVarType : public framework::VarTypeInference { ...@@ -53,6 +53,8 @@ class ReadInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc* block) const override {
bool infer_out = boost::get<bool>(op_desc.GetAttr("infer_out"));
if (infer_out) {
std::string reader_name = op_desc.Input("Reader")[0]; std::string reader_name = op_desc.Input("Reader")[0];
std::vector<std::string> out_names = op_desc.Output("Out"); std::vector<std::string> out_names = op_desc.Output("Out");
framework::VarDesc* reader = block->FindVarRecursive(reader_name); framework::VarDesc* reader = block->FindVarRecursive(reader_name);
...@@ -64,6 +66,7 @@ class ReadInferVarType : public framework::VarTypeInference { ...@@ -64,6 +66,7 @@ class ReadInferVarType : public framework::VarTypeInference {
out.SetDataType(dtypes[i]); out.SetDataType(dtypes[i]);
} }
} }
}
}; };
class ReadOp : public framework::OperatorBase { class ReadOp : public framework::OperatorBase {
...@@ -73,6 +76,7 @@ class ReadOp : public framework::OperatorBase { ...@@ -73,6 +76,7 @@ class ReadOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
VLOG(3) << "read op in";
framework::ReaderHolder* reader = framework::ReaderHolder* reader =
detail::Ref(scope.FindVar(Input("Reader")), detail::Ref(scope.FindVar(Input("Reader")),
"Cannot find reader variable %s", Input("Reader")) "Cannot find reader variable %s", Input("Reader"))
...@@ -87,7 +91,9 @@ class ReadOp : public framework::OperatorBase { ...@@ -87,7 +91,9 @@ class ReadOp : public framework::OperatorBase {
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (ins.empty()) { if (ins.empty()) {
VLOG(3) << "read empty data in";
if (Attr<bool>("throw_eof_exp")) { if (Attr<bool>("throw_eof_exp")) {
VLOG(3) << "throw_eof_exp";
PADDLE_THROW_EOF(); PADDLE_THROW_EOF();
} else { } else {
ins.resize(out_arg_names.size()); ins.resize(out_arg_names.size());
...@@ -96,6 +102,7 @@ class ReadOp : public framework::OperatorBase { ...@@ -96,6 +102,7 @@ class ReadOp : public framework::OperatorBase {
tensor.mutable_data<float>(framework::make_ddim({0}), dev_place); tensor.mutable_data<float>(framework::make_ddim({0}), dev_place);
} }
} }
VLOG(3) << "read empty data out";
} }
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < out_arg_names.size(); ++i) { for (size_t i = 0; i < out_arg_names.size(); ++i) {
...@@ -120,6 +127,7 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -120,6 +127,7 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
" only when the data-balance is enabled in ParallelExecutor" " only when the data-balance is enabled in ParallelExecutor"
" and it is set by ParallelExecutor instance, not users.") " and it is set by ParallelExecutor instance, not users.")
.SetDefault(true); .SetDefault(true);
AddAttr<bool>("infer_out", "").SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Read Operator Read Operator
......
...@@ -65,6 +65,10 @@ void FileReaderMakerBase::Make() { ...@@ -65,6 +65,10 @@ void FileReaderMakerBase::Make() {
"It means the reader will generate two data each time," "It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."); "whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data."); AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
AddAttr<bool>(
"use_data_config",
"Use the config of all datas like shape_concat/ranks/lod_levels")
.SetDefault(true);
Apply(); Apply();
} }
...@@ -75,7 +79,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { ...@@ -75,7 +79,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null."); "The output file reader should not be null.");
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat"); bool use_data_config = ctx->Attrs().Get<bool>("use_data_config");
if (use_data_config) {
const auto shape_concat =
ctx->Attrs().Get<std::vector<int>>("shape_concat");
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks"); const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes); ctx->SetReaderDims("Out", shapes);
...@@ -88,6 +95,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { ...@@ -88,6 +95,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
framework::VarDesc* reader = framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]); boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels); reader->SetLoDLevels(lod_levels);
}
} }
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc, void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
......
...@@ -485,6 +485,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -485,6 +485,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "") py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("start", &framework::ReaderHolder::Start)
.def("reset", &framework::ReaderHolder::ResetAll); .def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue = using LoDTensorBlockingQueue =
...@@ -505,17 +506,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -505,17 +506,10 @@ All parameter, weight, gradient are variables in Paddle.
.def("is_closed", &LoDTensorBlockingQueue::IsClosed); .def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue", m.def("init_lod_tensor_blocking_queue",
[](Variable &var, size_t capacity, [](Variable &var,
const std::vector<std::vector<int64_t>> &shapes) size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
-> std::shared_ptr<LoDTensorBlockingQueue> {
std::vector<DDim> dims(shapes.size());
std::transform(shapes.begin(), shapes.end(), dims.begin(),
[](const std::vector<int64_t> &shape) {
return make_ddim(shape);
});
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>(); auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, dims, holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue(); return holder->GetQueue();
}, },
py::return_value_policy::copy); py::return_value_policy::copy);
......
...@@ -22,6 +22,8 @@ from . import op_frequence ...@@ -22,6 +22,8 @@ from . import op_frequence
from .op_frequence import * from .op_frequence import *
from . import quantize from . import quantize
from .quantize import * from .quantize import *
from . import reader
from .reader import *
from . import slim from . import slim
from .slim import * from .slim import *
from . import utils from . import utils
...@@ -32,5 +34,6 @@ __all__ += decoder.__all__ ...@@ -32,5 +34,6 @@ __all__ += decoder.__all__
__all__ += memory_usage_calc.__all__ __all__ += memory_usage_calc.__all__
__all__ += op_frequence.__all__ __all__ += op_frequence.__all__
__all__ += quantize.__all__ __all__ += quantize.__all__
__all__ += reader.__all__
__all__ += slim.__all__ __all__ += slim.__all__
__all__ += utils.__all__ __all__ += utils.__all__
## CTR READER
A multi-threaded C++ reader that has the same interface as py_reader. It
uses C++ threads to read files and is much faster than the Python reading
thread in py_reader.
Currently, it supports two types of files:
- gzip
- plain text file
and two types of data format:
- the csv data format is:
* label dense_fea,dense_fea sparse_fea,sparse_fea
- the svm data format is:
* label slot1:fea_sign slot2:fea_sign slot1:fea_sign
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import ctr_reader
__all__ = ctr_reader.__all__
...@@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \ ...@@ -20,6 +20,8 @@ from paddle.fluid.framework import default_main_program, \
default_startup_program, Variable default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name from paddle.fluid.unique_name import generate as unique_name
__all__ = ['ctr_reader']
def monkey_patch_reader_methods(reader): def monkey_patch_reader_methods(reader):
def __get_reader__(): def __get_reader__():
...@@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader): ...@@ -30,7 +32,11 @@ def monkey_patch_reader_methods(reader):
def reset(): def reset():
return __get_reader__().reset() return __get_reader__().reset()
def start():
return __get_reader__().start()
reader.reset = reset reader.reset = reset
reader.start = start
reader.stop_gradient = True reader.stop_gradient = True
reader.persistable = True reader.persistable = True
return reader return reader
...@@ -44,7 +50,12 @@ def _copy_reader_var_(block, var): ...@@ -44,7 +50,12 @@ def _copy_reader_var_(block, var):
return new_var return new_var
def ctr_reader(feed_data, def ctr_reader(
feed_dict,
file_type, # gzip or plain
file_format, # csv or svm
dense_slot_index,
sparse_slot_index,
capacity, capacity,
thread_num, thread_num,
batch_size, batch_size,
...@@ -67,12 +78,21 @@ def ctr_reader(feed_data, ...@@ -67,12 +78,21 @@ def ctr_reader(feed_data,
Note that :code:`Program.clone()` method cannot clone :code:`py_reader`. Note that :code:`Program.clone()` method cannot clone :code:`py_reader`.
Args: Args:
feed_dict(list(variable)): a list of data variable.
file_type('gzip'|'plain'): the type of the data file
file_format('csv'|'svm'): csv data or svm data format.
the csv data format is :
label dense_fea,dense_fea sparse_fea,sparse_fea
the svm data format is :
label slot1:fea_sign slot2:fea_sign slot1:fea_sign
dense_slot_index(list(int)): the index of dense slots
sparse_slot_index(list(int)): the index of sparse slots
capacity(int): The buffer capacity maintained by :code:`py_reader`. capacity(int): The buffer capacity maintained by :code:`py_reader`.
thread_num(list|tuple): List of tuples which declaring data shapes. thread_num(int): the thread num to read files by cpp reader.
batch_size(list|tuple): List of strs which declaring data type. batch_size(int): batch size of data.
file_list(list|tuple): List of ints which declaring data lod_level. file_list(list(str)): List of file names that need to read.
slots(bool): Whether use double buffer or not. slots(list(int64)): list of slot id.
name(basestring): The prefix Python queue name and Reader name. None will name(string): The prefix Python queue name and Reader name. None will
be generated automatically. be generated automatically.
Returns: Returns:
...@@ -80,7 +100,15 @@ def ctr_reader(feed_data, ...@@ -80,7 +100,15 @@ def ctr_reader(feed_data,
Examples: Examples:
1. The basic usage of :code:`py_reader` is as follows: 1. The basic usage of :code:`ctr_reader` is as follows:
.. code-block:: python
py_reader = fluid.contrib.ctr_reader.ctr_reader(
feed_dict=datas, file_type='plain', file_format='csv',
file_list=file_list, dense_slot_index=[1, 2, 3, 4], sparse_slot_index=[],
capacity=64, thread_num=20, batch_size=1000, slots=[], name='ctr_reader')
""" """
if name is None: if name is None:
queue_name = unique_name('lod_tensor_blocking_queue') queue_name = unique_name('lod_tensor_blocking_queue')
...@@ -90,7 +118,7 @@ def ctr_reader(feed_data, ...@@ -90,7 +118,7 @@ def ctr_reader(feed_data,
reader_name = "_".join([name, "reader"]) reader_name = "_".join([name, "reader"])
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes) feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
reader_var = startup_blk.create_var(name=reader_name) reader_var = startup_blk.create_var(name=reader_name)
...@@ -99,12 +127,22 @@ def ctr_reader(feed_data, ...@@ -99,12 +127,22 @@ def ctr_reader(feed_data,
inputs={'blocking_queue': [queue_name]}, inputs={'blocking_queue': [queue_name]},
outputs={'Out': [reader_var]}, outputs={'Out': [reader_var]},
attrs={ attrs={
'use_data_config': False,
'thread_num': thread_num, 'thread_num': thread_num,
'batch_size': batch_size, 'batch_size': batch_size,
'file_list': file_list, 'file_list': file_list,
'slots': slots, 'file_type': file_type,
'file_format': file_format,
'dense_slot_index': dense_slot_index,
'sparse_slot_index': sparse_slot_index,
'sparse_slots': slots,
'ranks': [],
'lod_levels': [],
'shape_concat': []
}) })
dtypes = [data.dtype for data in feed_dict]
reader_var.desc.set_dtypes(dtypes)
reader_var.persistable = True reader_var.persistable = True
main_prog_reader_var = _copy_reader_var_( main_prog_reader_var = _copy_reader_var_(
...@@ -118,6 +156,9 @@ def ctr_reader(feed_data, ...@@ -118,6 +156,9 @@ def ctr_reader(feed_data,
main_blk = default_main_program().current_block() main_blk = default_main_program().current_block()
main_blk.append_op( main_blk.append_op(
type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data}) type='read',
inputs={'Reader': [reader]},
attrs={'infer_out': False},
outputs={'Out': feed_dict})
return reader return reader
...@@ -523,7 +523,7 @@ def _py_reader(capacity, ...@@ -523,7 +523,7 @@ def _py_reader(capacity,
double_buffer_name = "_".join([name, "double_buffer"]) double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes) feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=reader_name) startup_var = startup_blk.create_var(name=reader_name)
......
...@@ -109,6 +109,7 @@ packages=['paddle', ...@@ -109,6 +109,7 @@ packages=['paddle',
'paddle.fluid.contrib', 'paddle.fluid.contrib',
'paddle.fluid.contrib.decoder', 'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.quantize', 'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.reader',
'paddle.fluid.contrib.slim', 'paddle.fluid.contrib.slim',
'paddle.fluid.contrib.slim.core', 'paddle.fluid.contrib.slim.core',
'paddle.fluid.contrib.slim.graph', 'paddle.fluid.contrib.slim.graph',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册