From de9a411f1cdca85f127f0715e1f9d25ef4c76195 Mon Sep 17 00:00:00 2001 From: fengjiayi <fengjiayi@baidu.com> Date: Sun, 8 Jul 2018 15:59:27 +0800 Subject: [PATCH] adjust readers' inheritance relationships 1. Make PyReader and RandomDataGenerator inherited from FileReader. 2. Remove the memeber variable 'dims_' and realated checks in FileReader. --- paddle/fluid/framework/reader.cc | 20 +------------------ paddle/fluid/framework/reader.h | 5 +---- .../reader/create_double_buffer_reader_op.cc | 2 +- .../operators/reader/create_py_reader_op.cc | 7 ++++--- .../reader/create_random_data_generator_op.cc | 6 +++--- .../reader/create_recordio_file_reader_op.cc | 17 +++------------- .../fluid/operators/reader/open_files_op.cc | 9 +++------ .../operators/reader/reader_op_registry.cc | 4 ++-- .../operators/reader/reader_op_registry.h | 11 +++++----- 9 files changed, 23 insertions(+), 58 deletions(-) diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 0b36f1116d..f288b90b4d 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -18,24 +18,6 @@ namespace paddle { namespace framework { ReaderBase::~ReaderBase() {} -FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {} - -void FileReader::ReadNext(std::vector<LoDTensor> *out) { - ReadNextImpl(out); - if (out->empty()) { - return; - } - - PADDLE_ENFORCE_EQ(out->size(), dims_.size()); - for (size_t i = 0; i < dims_.size(); ++i) { - auto &actual = (*out)[i].dims(); - auto &expect = dims_[i]; - - PADDLE_ENFORCE_EQ(actual.size(), expect.size()); - for (int j = 0; j < actual.size(); ++j) { - // PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1); - } - } -} +void FileReader::ReadNext(std::vector<LoDTensor> *out) { ReadNextImpl(out); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 64d4ceab62..823d58af5e 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -48,15 +48,12 @@ class DecoratedReader : public ReaderBase { class FileReader : public ReaderBase { public: - explicit FileReader(const std::vector<DDim>& dims); + FileReader() : ReaderBase() {} void ReadNext(std::vector<LoDTensor>* out) override; protected: virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0; - - private: - std::vector<DDim> dims_; }; // The ReaderHolder is used as reader' unified wrapper, diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 5f734489a8..0d2ff2e8e4 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -151,8 +151,8 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { } void DoubleBufferReader::ReInit() { - reader_->ReInit(); EndPrefetcher(); + reader_->ReInit(); StartPrefetcher(); } diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc index 36587360f7..84ea72379b 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -19,14 +19,15 @@ namespace paddle { namespace operators { namespace reader { -class PyReader : public framework::ReaderBase { +class PyReader : public framework::FileReader { public: - explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) { + explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) + : framework::FileReader() { PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); queue_ = queue; } - void ReadNext(std::vector<framework::LoDTensor>* out) override { + void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { bool success; *out = queue_->Pop(&success); if (!success) out->clear(); diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 5b7e8a063a..7cbc2882fd 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -19,11 +19,11 @@ namespace operators { namespace reader { template <typename T> -class RandomDataGenerator : public framework::ReaderBase { +class RandomDataGenerator : public framework::FileReader { public: RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low, float high) - : framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) { + : framework::FileReader(), low_(low), high_(high), shapes_(shapes) { PADDLE_ENFORCE_LE(low, high, "'low' shouldn't be greater than 'high'.(%f vs %f)", low, high); @@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase { dist_ = std::uniform_real_distribution<float>(low_, high_); } - void ReadNext(std::vector<framework::LoDTensor>* out) override { + void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { out->clear(); out->reserve(shapes_.size()); for (const framework::DDim& shape : shapes_) { diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 559827f084..c032acdffd 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -21,9 +21,8 @@ namespace reader { template <bool ThreadSafe> class RecordIOFileReader : public framework::FileReader { public: - explicit RecordIOFileReader(const std::string& filename, - const std::vector<framework::DDim>& dims) - : FileReader(dims), + explicit RecordIOFileReader(const std::string& filename) + : FileReader(), scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) { @@ -58,20 +57,10 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { - const auto& shape_concat = Attr<std::vector<int>>("shape_concat"); - const auto& ranks = Attr<std::vector<int>>("ranks"); - PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); - PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), - static_cast<int>(shape_concat.size()), - "The accumulate of all ranks should be equal to the " - "shape concat's length."); std::string filename = Attr<std::string>("filename"); - auto* out = scope.FindVar(Output("Out")) ->template GetMutable<framework::ReaderHolder>(); - - out->Reset(new RecordIOFileReader<true>( - filename, RestoreShapes(shape_concat, ranks))); + out->Reset(new RecordIOFileReader<true>(filename)); } }; diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 31e5d81e55..6e34a5bd00 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -23,13 +23,12 @@ namespace reader { class MultiFileReader : public framework::ReaderBase { public: - MultiFileReader(const std::vector<std::string>& file_names, - const std::vector<framework::DDim>& dims, size_t thread_num, + MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num, size_t buffer_size) : buffer_size_(buffer_size) { readers_.reserve(file_names.size()); for (const std::string& f_name : file_names) { - readers_.emplace_back(CreateReaderByFileName(f_name, dims)); + readers_.emplace_back(CreateReaderByFileName(f_name)); } prefetchers_.resize(thread_num); StartNewScheduler(); @@ -180,9 +179,7 @@ class OpenFilesOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable<framework::ReaderHolder>(); - out->Reset(new MultiFileReader(file_names, - RestoreShapes(shape_concat, ranks), - thread_num, buffer_size)); + out->Reset(new MultiFileReader(file_names, thread_num, buffer_size)); } }; diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index e11256a49f..b82aab1214 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() { } std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( - const std::string& file_name, const std::vector<framework::DDim>& dims) { + const std::string& file_name) { size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); PADDLE_ENFORCE_NE(separator_pos, std::string::npos, "File name illegal! A legal file name should be like: " @@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( auto itor = FileReaderRegistry().find(filetype); PADDLE_ENFORCE(itor != FileReaderRegistry().end(), "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(file_name, dims); + framework::ReaderBase* reader = (itor->second)(file_name); return std::unique_ptr<framework::ReaderBase>(reader); } diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index 244bf15f06..25c3e7d77b 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -25,22 +25,21 @@ namespace reader { static constexpr char kFileFormatSeparator[] = "."; -using FileReaderCreator = std::function<framework::ReaderBase*( - const std::string&, const std::vector<framework::DDim>&)>; +using FileReaderCreator = + std::function<framework::ReaderBase*(const std::string&)>; std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry(); template <typename Reader> int RegisterFileReader(const std::string& filetype) { - FileReaderRegistry()[filetype] = []( - const std::string& fn, const std::vector<framework::DDim>& dims) { - return new Reader(fn, dims); + FileReaderRegistry()[filetype] = [](const std::string& fn) { + return new Reader(fn); }; return 0; } std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( - const std::string& file_name, const std::vector<framework::DDim>& dims); + const std::string& file_name); extern std::vector<framework::DDim> RestoreShapes( const std::vector<int>& shape_concat, const std::vector<int>& ranks); -- GitLab