diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8dcf9786e36fa8376720c5bac6417ecbd04b27f6..efa68c9ba243af3c7cdca52b915cc14d307ae89f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -214,6 +214,7 @@ if (NOT WIN32) # there is no official support of warpctc, nccl, cupti in windows
 include(external/warpctc)   # download, build, install warpctc
 include(cupti)
+include(external/gzstream)
 endif (NOT WIN32)
 
 if(WITH_DISTRIBUTE)
diff --git a/cmake/external/gzstream.cmake b/cmake/external/gzstream.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..59d8e932459dd49017cb32b27e5f1919272fe387
--- /dev/null
+++ b/cmake/external/gzstream.cmake
@@ -0,0 +1,47 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+IF(MOBILE_INFERENCE)
+    return()
+ENDIF()
+
+include (ExternalProject)
+
+# NOTE: gzstream is needed when linking with ctr reader.
+
+SET(GZSTREAM_SOURCES_DIR ${THIRD_PARTY_PATH}/gzstream)
+SET(GZSTREAM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gzstream)
+SET(GZSTREAM_INCLUDE_DIR "${GZSTREAM_INSTALL_DIR}/include/" CACHE PATH "gzstream include directory." FORCE)
+
+ExternalProject_Add(
+    extern_gzstream
+    GIT_REPOSITORY    "https://github.com/jacquesqiao/gzstream.git"
+    GIT_TAG           ""
+    PREFIX            ${GZSTREAM_SOURCES_DIR}
+    UPDATE_COMMAND    ""
+    CONFIGURE_COMMAND ""
+    BUILD_IN_SOURCE   1
+    BUILD_COMMAND     make -j8
+    INSTALL_COMMAND   mkdir -p ${GZSTREAM_INSTALL_DIR}/lib/ && mkdir -p ${GZSTREAM_INSTALL_DIR}/include/
+                      && cp ${GZSTREAM_SOURCES_DIR}/src/extern_gzstream/libgzstream.a ${GZSTREAM_INSTALL_DIR}/lib
+                      && cp -r ${GZSTREAM_SOURCES_DIR}/src/extern_gzstream/gzstream.h ${GZSTREAM_INSTALL_DIR}/include
+)
+
+ADD_LIBRARY(gzstream STATIC IMPORTED GLOBAL)
+SET_PROPERTY(TARGET gzstream PROPERTY IMPORTED_LOCATION
+             "${GZSTREAM_INSTALL_DIR}/lib/libgzstream.a")
+
+include_directories(${GZSTREAM_INCLUDE_DIR})
+ADD_DEPENDENCIES(gzstream extern_gzstream zlib)
diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt
index 6c919ee1782ebce6d56f7530daa9b748dfb26c47..7c284312df912ad758f6fffc44f111dfe765feb8 100644
--- a/paddle/fluid/operators/reader/CMakeLists.txt
+++ b/paddle/fluid/operators/reader/CMakeLists.txt
@@ -28,6 +28,12 @@ reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
 reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
 reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
 
+if (NOT WIN32 AND NOT ON_INFER)
+  cc_library(ctr_reader SRCS ctr_reader.cc DEPS gzstream reader zlib)
+  cc_test(ctr_reader_test SRCS ctr_reader_test.cc DEPS ctr_reader)
+  reader_library(create_ctr_reader_op SRCS create_ctr_reader_op.cc DEPS ctr_reader)
+endif ()
+
 cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)
 # Export local libraries to parent
 # set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/reader/create_ctr_reader_op.cc b/paddle/fluid/operators/reader/create_ctr_reader_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..58a465d87a8c0da50e3eb80fefe32d50217f6990
--- /dev/null
+++ b/paddle/fluid/operators/reader/create_ctr_reader_op.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/reader/ctr_reader.h"
+
+#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
+#include "paddle/fluid/operators/reader/reader_op_registry.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+class CreateCTRReaderOp : public framework::OperatorBase {
+ public:
+  using framework::OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
+    auto* out = scope.FindVar(Output("Out"))
+                    ->template GetMutable<framework::ReaderHolder>();
+    if (out->Get() != nullptr) return;
+
+    const std::string& queue_name = Input("blocking_queue");
+    auto* queue_holder_var = scope.FindVar(queue_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        queue_holder_var,
+        "No LoDTensorBlockingQueueHolder variable with name %s found",
+        queue_name);
+    auto* queue_holder =
+        queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
+
+    int thread_num = Attr<int>("thread_num");
+    std::vector<std::string> slots = Attr<std::vector<std::string>>("slots");
+    int batch_size = Attr<int>("batch_size");
+    std::vector<std::string> file_list =
+        Attr<std::vector<std::string>>("file_list");
+    out->Reset(std::make_shared<CTRReader>(queue_holder->GetQueue(),
+                                           batch_size, thread_num, slots,
+                                           file_list));
+  }
+};
+
+class CreateCTRReaderOpMaker : public FileReaderMakerBase {
+ protected:
+  void Apply() override {
+    AddInput("blocking_queue",
+             "Name of the `LoDTensorBlockingQueueHolder` variable");
+    AddAttr<int>("thread_num", "the number of threads used to read data");
+    AddAttr<int>("batch_size", "the batch size of the read data");
+    AddAttr<std::vector<std::string>>("file_list",
+                                      "The list of files that need to be read");
+    AddAttr<std::vector<std::string>>(
+        "slots", "the slots that should be extracted from the file");
+
+    AddComment(R"DOC(
+      Create a CTRReader to read CTR data with C++ code.
+      )DOC");
+  }
+};
+
+}  // namespace reader
+}  // namespace operators
+}  // namespace paddle
+
+namespace reader = ::paddle::operators::reader;
+
+REGISTER_FILE_READER_OPERATOR(create_ctr_reader, reader::CreateCTRReaderOp,
+                              reader::CreateCTRReaderOpMaker);
diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d1d3ddc89dc09a185e6a41274cf382b430ec3eeb
--- /dev/null
+++ b/paddle/fluid/operators/reader/ctr_reader.cc
@@ -0,0 +1,238 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/reader/ctr_reader.h"
+
+#include <gzstream.h>
+
+#include <algorithm>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <chrono>  // NOLINT
+#include <thread>  // NOLINT
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+static inline void string_split(const std::string& s, const char delimiter,
+                                std::vector<std::string>* output) {
+  size_t start = 0;
+  size_t end = s.find_first_of(delimiter);
+
+  while (end <= std::string::npos) {
+    output->emplace_back(s.substr(start, end - start));
+    if (end == std::string::npos) {
+      break;
+    }
+    start = end + 1;
+    end = s.find_first_of(delimiter, start);
+  }
+}
+
+static inline void parse_line(
+    const std::string& line,
+    const std::unordered_map<std::string, int>& slot_to_index,
+    int64_t* label,
+    std::unordered_map<std::string, std::vector<int64_t>>* slot_to_data) {
+  std::vector<std::string> ret;
+  string_split(line, ' ', &ret);
+  *label = std::stoi(ret[2]) > 0;
+
+  for (size_t i = 3; i < ret.size(); ++i) {
+    const std::string& item = ret[i];
+    std::vector<std::string> feasign_and_slot;
+    string_split(item, ':', &feasign_and_slot);
+    if (feasign_and_slot.size() == 2 &&
+        slot_to_index.find(feasign_and_slot[1]) != slot_to_index.end()) {
+      int64_t feasign = std::strtoll(feasign_and_slot[0].c_str(), NULL, 10);
+      (*slot_to_data)[feasign_and_slot[1]].push_back(feasign);
+    }
+  }
+
+  // NOTE: if a slot has no value, fill [0] as its data.
+  for (auto& item : slot_to_index) {
+    if (slot_to_data->find(item.first) == slot_to_data->end()) {
+      (*slot_to_data)[item.first].push_back(0);
+    }
+  }
+}
+
+class Reader {
+ public:
+  virtual ~Reader() {}
+  virtual bool HasNext() = 0;
+  virtual void NextLine(std::string* line) = 0;
+};
+
+class GzipReader : public Reader {
+ public:
+  explicit GzipReader(const std::string& file_name)
+      : gzstream_(file_name.c_str()) {}
+
+  ~GzipReader() {}
+
+  bool HasNext() override { return gzstream_.peek() != EOF; }
+
+  void NextLine(std::string* line) override { std::getline(gzstream_, *line); }
+
+ private:
+  igzstream gzstream_;
+};
+
+class MultiGzipReader : public Reader {
+ public:
+  explicit MultiGzipReader(const std::vector<std::string>& file_list) {
+    for (auto& file : file_list) {
+      readers_.emplace_back(std::make_shared<GzipReader>(file));
+    }
+  }
+
+  bool HasNext() override {
+    if (current_reader_index_ >= readers_.size()) {
+      return false;
+    }
+    if (!readers_[current_reader_index_]->HasNext()) {
+      current_reader_index_++;
+      return HasNext();
+    }
+    return true;
+  }
+
+  void NextLine(std::string* line) override {
+    readers_[current_reader_index_]->NextLine(line);
+  }
+
+ private:
+  std::vector<std::shared_ptr<GzipReader>> readers_;
+  size_t current_reader_index_ = 0;
+};
+
+void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
+                   std::shared_ptr<LoDTensorBlockingQueue> queue) {
+  VLOG(30) << "monitor thread started";
+  bool reader_thread_is_running = true;
+  while (reader_thread_is_running) {
+    VLOG(30) << "reader_thread_is_running";
+    reader_thread_is_running = false;
+    for (size_t i = 0; i < (*thread_status).size(); ++i) {
+      if ((*thread_status)[i] == Running) {
+        VLOG(30) << "reader is running!";
+        reader_thread_is_running = true;
+      }
+    }
+    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+  }
+  VLOG(30) << "all reader threads have stopped, push empty data into queue";
queue"; + queue->Push({}); + VLOG(30) << "monitor thread exited"; +} + +void ReadThread(const std::vector& file_list, + const std::vector& slots, int batch_size, + int thread_id, std::vector* thread_status, + std::shared_ptr queue) { + VLOG(30) << "[" << thread_id << "]" + << " reader thread start! thread_id = " << thread_id; + for (auto& file : file_list) { + VLOG(30) << "[" << thread_id << "]" + << " file " << file; + } + (*thread_status)[thread_id] = Running; + VLOG(30) << "set status to running"; + + std::unordered_map slot_to_index; + for (size_t i = 0; i < slots.size(); ++i) { + slot_to_index[slots[i]] = i; + } + + std::string line; + + std::vector>> batch_data; + std::vector batch_label; + + MultiGzipReader reader(file_list); + + VLOG(30) << "reader inited"; + + while (reader.HasNext()) { + batch_data.clear(); + batch_data.reserve(batch_size); + + batch_label.clear(); + batch_label.reserve(batch_size); + + // read batch_size data + for (int i = 0; i < batch_size; ++i) { + if (reader.HasNext()) { + reader.NextLine(&line); + std::unordered_map> slot_to_data; + int64_t label; + parse_line(line, slot_to_index, &label, &slot_to_data); + batch_data.push_back(slot_to_data); + batch_label.push_back(label); + } else { + break; + } + } + + std::vector lod_datas; + + // first insert tensor for each slots + for (auto& slot : slots) { + std::vector lod_data{0}; + std::vector batch_feasign; + + for (size_t i = 0; i < batch_data.size(); ++i) { + auto& feasign = batch_data[i][slot]; + lod_data.push_back(lod_data.back() + feasign.size()); + batch_feasign.insert(batch_feasign.end(), feasign.begin(), + feasign.end()); + } + + framework::LoDTensor lod_tensor; + framework::LoD lod{lod_data}; + lod_tensor.set_lod(lod); + int64_t* tensor_data = lod_tensor.mutable_data( + framework::make_ddim({1, static_cast(batch_feasign.size())}), + platform::CPUPlace()); + memcpy(tensor_data, batch_feasign.data(), + batch_feasign.size() * sizeof(int64_t)); + lod_datas.push_back(lod_tensor); + } + + // insert label tensor + framework::LoDTensor label_tensor; + auto* label_tensor_data = label_tensor.mutable_data( + framework::make_ddim({1, static_cast(batch_label.size())}), + platform::CPUPlace()); + memcpy(label_tensor_data, batch_label.data(), + batch_label.size() * sizeof(int64_t)); + lod_datas.push_back(label_tensor); + + queue->Push(lod_datas); + VLOG(40) << "push one data, queue_size=" << queue->Size(); + } + + (*thread_status)[thread_id] = Stopped; + VLOG(30) << "set status to stopped, thread " << thread_id << " exited"; +} + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h new file mode 100644 index 0000000000000000000000000000000000000000..9b2a11bae12d242880829628faa089e1638424b0 --- /dev/null +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -0,0 +1,133 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h
new file mode 100644
index 0000000000000000000000000000000000000000..9b2a11bae12d242880829628faa089e1638424b0
--- /dev/null
+++ b/paddle/fluid/operators/reader/ctr_reader.h
@@ -0,0 +1,133 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <sys/time.h>
+
+#include <chrono>  // NOLINT
+#include <cstdlib>
+#include <fstream>
+#include <functional>
+#include <memory>
+#include <string>
+#include <thread>  // NOLINT
+#include <vector>
+
+#include "paddle/fluid/framework/reader.h"
+#include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+enum ReaderThreadStatus { Running, Stopped };
+
+void ReadThread(const std::vector<std::string>& file_list,
+                const std::vector<std::string>& slots, int batch_size,
+                int thread_id, std::vector<ReaderThreadStatus>* thread_status,
+                std::shared_ptr<LoDTensorBlockingQueue> queue);
+
+// monitor all running threads; once they have all stopped,
+// push an empty batch into the LoDTensorBlockingQueue to signal the end
+void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
+                   std::shared_ptr<LoDTensorBlockingQueue> queue);
+
+class CTRReader : public framework::FileReader {
+ public:
+  explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
+                     int batch_size, int thread_num,
+                     const std::vector<std::string>& slots,
+                     const std::vector<std::string>& file_list)
+      : batch_size_(batch_size), slots_(slots), file_list_(file_list) {
+    PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger than 0!");
+    PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
+    PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
+    thread_num_ =
+        file_list_.size() > thread_num ? thread_num : file_list_.size();
+    queue_ = queue;
+    SplitFiles();
+    for (size_t i = 0; i < thread_num_; ++i) {
+      read_thread_status_.push_back(Stopped);
+    }
+  }
+
+  ~CTRReader() {}
+
+  void ReadNext(std::vector<framework::LoDTensor>* out) override {
+    bool success;
+    *out = queue_->Pop(&success);
+    if (!success) out->clear();
+  }
+
+  void Shutdown() override {
+    VLOG(3) << "Shutdown reader";
+    if (status_ == ReaderStatus::kStopped) {
+      return;
+    }
+    // shutdown should stop all the reader threads
+    for (auto& read_thread : read_threads_) {
+      read_thread->join();
+    }
+    monitor_thread_->join();
+
+    read_threads_.clear();
+    monitor_thread_.reset(nullptr);
+    queue_->Close();
+    status_ = ReaderStatus::kStopped;
+  }
+
+  void Start() override {
+    VLOG(3) << "Start reader";
+    PADDLE_ENFORCE_EQ(read_threads_.size(), 0, "read thread should be empty!");
+    queue_->ReOpen();
+    VLOG(3) << "reopen success";
+    VLOG(3) << "thread_num " << thread_num_;
+    for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
+      read_threads_.emplace_back(new std::thread(
+          std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_,
+                    thread_id, &read_thread_status_, queue_)));
+    }
+    monitor_thread_.reset(new std::thread(
+        std::bind(&MonitorThread, &read_thread_status_, queue_)));
+    status_ = ReaderStatus::kRunning;
+  }
+
+ private:
+  void SplitFiles() {
+    file_groups_.resize(thread_num_);
+    for (size_t i = 0; i < file_list_.size(); ++i) {
+      auto& file_name = file_list_[i];
+      std::ifstream f(file_name.c_str());
+      PADDLE_ENFORCE(f.good(), "file %s does not exist!", file_name);
+      file_groups_[i % thread_num_].push_back(file_name);
+    }
+  }
+
+ private:
+  size_t thread_num_;
+  const int batch_size_;
+  const std::vector<std::string> slots_;
+  const std::vector<std::string> file_list_;
+  std::shared_ptr<LoDTensorBlockingQueue> queue_;
+  std::vector<std::unique_ptr<std::thread>> read_threads_;
+  std::unique_ptr<std::thread> monitor_thread_;
+  std::vector<ReaderThreadStatus> read_thread_status_;
+  std::vector<std::vector<std::string>> file_groups_;
+};
+
+}  // namespace reader
+}  // namespace operators
+}  // namespace paddle
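Note (not part of the patch): `CTRReader::SplitFiles` above distributes files to the reader threads round-robin; a small sketch of the same assignment, with illustrative file names:

    file_list = ["part-0.gz", "part-1.gz", "part-2.gz", "part-3.gz", "part-4.gz"]
    thread_num = 2
    file_groups = [[] for _ in range(thread_num)]
    for i, name in enumerate(file_list):
        file_groups[i % thread_num].append(name)
    # file_groups == [["part-0.gz", "part-2.gz", "part-4.gz"],
    #                 ["part-1.gz", "part-3.gz"]]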
diff --git a/paddle/fluid/operators/reader/ctr_reader_test.cc b/paddle/fluid/operators/reader/ctr_reader_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8dba9baebce0a82ee2a541fe6ae9f6bcef8e2835
--- /dev/null
+++ b/paddle/fluid/operators/reader/ctr_reader_test.cc
@@ -0,0 +1,155 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/reader/ctr_reader.h"
+
+#include <gzstream.h>
+#include <cmath>
+
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/operators/reader/blocking_queue.h"
+
+using paddle::operators::reader::LoDTensorBlockingQueue;
+using paddle::operators::reader::LoDTensorBlockingQueueHolder;
+using paddle::operators::reader::CTRReader;
+using paddle::framework::LoDTensor;
+using paddle::framework::LoD;
+using paddle::framework::DDim;
+using paddle::platform::CPUPlace;
+using paddle::framework::make_ddim;
+
+static void generatedata(const std::vector<std::string>& data,
+                         const std::string& file_name) {
+  std::ifstream in(file_name.c_str());
+  if (in.good()) {
+    VLOG(3) << "file " << file_name << " exists, delete it first!";
+    remove(file_name.c_str());
+  } else {
+    in.close();
+  }
+
+  ogzstream out(file_name.c_str());
+  PADDLE_ENFORCE(out.good(), "open file %s failed!", file_name);
+  for (auto& c : data) {
+    out << c;
+  }
+  out.close();
+  PADDLE_ENFORCE(out.good(), "save file %s failed!", file_name);
+}
+
+static inline void check_all_data(
+    const std::vector<std::string>& ctr_data,
+    const std::vector<std::string>& slots, const std::vector<DDim>& label_dims,
+    const std::vector<int64_t>& label_value,
+    const std::vector<std::tuple<LoD, std::vector<int64_t>>>& data_slot_6002,
+    const std::vector<std::tuple<LoD, std::vector<int64_t>>>& data_slot_6003,
+    size_t batch_num, size_t batch_size,
+    std::shared_ptr<LoDTensorBlockingQueue> queue, CTRReader* reader) {
+  std::vector<LoDTensor> out;
+  for (size_t i = 0; i < batch_num; ++i) {
+    reader->ReadNext(&out);
+    ASSERT_EQ(out.size(), slots.size() + 1);
+    auto& label_tensor = out.back();
+    ASSERT_EQ(label_tensor.dims(), label_dims[i]);
+    for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size();
+         ++j) {
+      auto& label = label_tensor.data<int64_t>()[j];
+      ASSERT_TRUE(label == 0 || label == 1);
+      ASSERT_EQ(label, label_value[i * batch_size + j]);
+    }
+    auto& tensor_6002 = out[0];
+    ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod());
+    ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(),
+                          tensor_6002.data<int64_t>(),
+                          tensor_6002.dims()[1] * sizeof(int64_t)),
+              0);
+  }
+  reader->ReadNext(&out);
+  ASSERT_EQ(out.size(), 0);
+  ASSERT_EQ(queue->Size(), 0);
+}
+
+TEST(CTR_READER, read_data) {
+  const std::vector<std::string> ctr_data = {
+      "aaaa 1 0 0:6002 1:6003 2:6004 3:6005 4:6006 -1\n",
+      "bbbb 1 0 5:6003 6:6003 7:6003 8:6004 9:6004 -1\n",
+      "cccc 1 1 10:6002 11:6002 12:6002 13:6002 14:6002 -2\n",
+      "dddd 1 0 15:6003 16:6003 17:6003 18:6003 19:6004 -3\n",
+      "1111 1 1 20:6001 21:6001 22:6001 23:6001 24:6001 12\n",
+      "2222 1 1 25:6004 26:6004 27:6004 28:6005 29:6005 aa\n",
+      "3333 1 0 30:6002 31:6003 32:6004 33:6004 34:6005 er\n",
+      "eeee 1 1 35:6003 36:6003 37:6005 38:6005 39:6005 dd\n",
+      "ffff 1 1 40:6002 41:6003 42:6004 43:6004 44:6005 66\n",
+      "gggg 1 1 46:6006 45:6006 47:6003 48:6003 49:6003 ba\n",
+  };
+  std::string gz_file_name = "test_ctr_reader_data.gz";
+  generatedata(ctr_data, gz_file_name);
+
+  std::vector<int64_t> label_value = {0, 0, 1, 0, 1, 1, 0, 1, 1, 1};
+
+  std::tuple<LoD, std::vector<int64_t>> a1({{0, 1, 2, 7}},
+                                           {0, 0, 10, 11, 12, 13, 14});
+  std::tuple<LoD, std::vector<int64_t>> a2({{0, 1, 2, 3}}, {0, 0, 0});
+  std::tuple<LoD, std::vector<int64_t>> a3({{0, 1, 2, 3}}, {30, 0, 40});
+  std::tuple<LoD, std::vector<int64_t>> a4({{0, 1}}, {0});
+  std::vector<std::tuple<LoD, std::vector<int64_t>>> data_slot_6002{a1, a2, a3,
+                                                                    a4};
+
+  std::tuple<LoD, std::vector<int64_t>> b1({{0, 1, 4, 5}}, {1, 5, 6, 7, 0});
+  std::tuple<LoD, std::vector<int64_t>> b2({{0, 4, 5, 6}},
+                                           {15, 16, 17, 18, 0, 0});
+  std::tuple<LoD, std::vector<int64_t>> b3({{0, 1, 3, 4}}, {31, 35, 36, 41});
+  std::tuple<LoD, std::vector<int64_t>> b4({{0, 3}}, {47, 48, 49});
+  std::vector<std::tuple<LoD, std::vector<int64_t>>> data_slot_6003{b1, b2, b3,
+                                                                    b4};
+
+  std::vector<DDim> label_dims = {{1, 3}, {1, 3}, {1, 3}, {1, 1}};
+
+  LoDTensorBlockingQueueHolder queue_holder;
+  int capacity = 64;
+  queue_holder.InitOnce(capacity, {}, false);
+
+  std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
+
+  int batch_size = 3;
+  int thread_num = 1;
+  std::vector<std::string> slots = {"6002", "6003"};
+  std::vector<std::string> file_list;
+  for (int i = 0; i < thread_num; ++i) {
+    file_list.push_back(gz_file_name);
+  }
+
+  CTRReader reader(queue, batch_size, thread_num, slots, file_list);
+
+  reader.Start();
+  size_t batch_num =
+      std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
+  check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
+                 data_slot_6003, batch_num, batch_size, queue, &reader);
+
+  reader.Shutdown();
+
+  reader.Start();
+  check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
+                 data_slot_6003, batch_num, batch_size, queue, &reader);
+  reader.Shutdown();
+}
diff --git a/python/paddle/fluid/contrib/reader/ctr_reader.py b/python/paddle/fluid/contrib/reader/ctr_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8449e8d848670f8262aa01e5654e0e2fc621837
--- /dev/null
+++ b/python/paddle/fluid/contrib/reader/ctr_reader.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+from paddle.fluid import core
+from paddle.fluid.executor import global_scope
+from paddle.fluid.framework import default_main_program, \
+    default_startup_program, Variable
+from paddle.fluid.unique_name import generate as unique_name
+
+
+def monkey_patch_reader_methods(reader):
+    def __get_reader__():
+        scope = global_scope()
+        var = scope.find_var(reader.name)
+        return var.get_reader()
+
+    def reset():
+        return __get_reader__().reset()
+
+    reader.reset = reset
+    reader.stop_gradient = True
+    reader.persistable = True
+    return reader
+
+
+def _copy_reader_var_(block, var):
+    new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
+    new_var.desc.set_shapes(var.desc.shapes())
+    new_var.desc.set_dtypes(var.desc.dtypes())
+    new_var.persistable = True
+    return new_var
+
+
+def ctr_reader(feed_data,
+               capacity,
+               thread_num,
+               batch_size,
+               file_list,
+               slots,
+               name=None):
+    """
+    Create a CTR reader for data feeding in Python.
+
+    This layer returns a Reader Variable. The reader uses `thread_num` C++
+    threads to read the gzip-compressed files in `file_list`, extracts the
+    given `slots` from every line, and pushes batches of `batch_size`
+    samples (one LoDTensor per slot plus a label tensor) into a blocking
+    queue of size `capacity`. When :code:`Executor::Run()` is invoked, one
+    batch is popped from the queue and fed into the `feed_data` variables,
+    so file reading and :code:`Executor::Run()` can overlap. The
+    :code:`reset()` method of the Reader should be called when a pass ends
+    and :code:`fluid.core.EOFException` raises.
+
+    Args:
+        feed_data(list): list of Variables that receive the data produced
+            by the reader, one per slot followed by the label.
+        capacity(int): The capacity of the blocking queue that buffers the
+            read batches.
+        thread_num(int): The number of C++ threads used to read the files.
+        batch_size(int): The batch size of the read data.
+        file_list(list of str): The list of gzip-compressed files to read.
+        slots(list of str): The slot ids to be extracted from each line.
+        name(basestring): The prefix of the Python queue name and Reader
+            name. None will be generated automatically.
+
+    Returns:
+        Variable: A Reader from which we can get feeding data.
+
+    Examples:
+
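+        A minimal usage sketch. The import path, feed variables, slot ids
+        and file names below are illustrative assumptions, not defined by
+        this patch:
+
+        .. code-block:: python
+
+          import paddle.fluid as fluid
+          from paddle.fluid.contrib.reader import ctr_reader
+
+          slots = ['6002', '6003']
+          datas = [
+              fluid.layers.data(
+                  name=slot, shape=[1], dtype='int64', lod_level=1)
+              for slot in slots
+          ]
+          label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+          reader = ctr_reader.ctr_reader(
+              feed_data=datas + [label],
+              capacity=64,
+              thread_num=4,
+              batch_size=1000,
+              file_list=['ctr_data_part-0.gz', 'ctr_data_part-1.gz'],
+              slots=slots)
+
+          # each Executor.run() over the main program consumes one batch;
+          # call reader.reset() after fluid.core.EOFException is raised.
+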
+    """
+    if name is None:
+        queue_name = unique_name('lod_tensor_blocking_queue')
+        reader_name = unique_name('create_ctr_reader')
+    else:
+        queue_name = "_".join([name, "queue"])
+        reader_name = "_".join([name, "reader"])
+
+    var = global_scope().var(queue_name)
+    # NOTE: the CTR reader does not pre-declare per-slot shapes, so an
+    # empty shape list is passed when initializing the blocking queue.
+    feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, [])
+
+    startup_blk = default_startup_program().current_block()
+    reader_var = startup_blk.create_var(name=reader_name)
+    startup_blk.append_op(
+        type='create_ctr_reader',
+        inputs={'blocking_queue': [queue_name]},
+        outputs={'Out': [reader_var]},
+        attrs={
+            'thread_num': thread_num,
+            'batch_size': batch_size,
+            'file_list': file_list,
+            'slots': slots,
+        })
+
+    reader_var.persistable = True
+
+    main_prog_reader_var = _copy_reader_var_(
+        default_main_program().current_block(), reader_var)
+
+    reader = monkey_patch_reader_methods(main_prog_reader_var)
+
+    # monkey patch py_reader special methods
+    reader.queue = feed_queue
+    reader.exited = False
+
+    main_blk = default_main_program().current_block()
+    main_blk.append_op(
+        type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data})
+
+    return reader
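Note (appended, not part of the patch): the gzip-compressed input files consumed by the reader can be produced with Python's gzip module, mirroring `generatedata()` in the test above; the file name and records are illustrative:

    import gzip

    ctr_data = [
        "aaaa 1 0 0:6002 1:6003 2:6004 3:6005 4:6006 -1\n",
        "bbbb 1 0 5:6003 6:6003 7:6003 8:6004 9:6004 -1\n",
    ]

    # text-mode gzip output can be read back by the igzstream in ctr_reader.cc
    with gzip.open("test_ctr_reader_data.gz", "wt") as f:
        f.writelines(ctr_data)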