diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 0b36f1116d15004b355e854e101abb9ad3297836..f288b90b4dbaff9b286b7b74d8cdf7d5a5963650 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -18,24 +18,6 @@ namespace paddle { namespace framework { ReaderBase::~ReaderBase() {} -FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {} - -void FileReader::ReadNext(std::vector<LoDTensor> *out) { - ReadNextImpl(out); - if (out->empty()) { - return; - } - - PADDLE_ENFORCE_EQ(out->size(), dims_.size()); - for (size_t i = 0; i < dims_.size(); ++i) { - auto &actual = (*out)[i].dims(); - auto &expect = dims_[i]; - - PADDLE_ENFORCE_EQ(actual.size(), expect.size()); - for (int j = 0; j < actual.size(); ++j) { - // PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1); - } - } -} +void FileReader::ReadNext(std::vector<LoDTensor> *out) { ReadNextImpl(out); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 64d4ceab624312ed366d7e835072899f1f033a88..823d58af5ea010c797dd2078759509fbe30f3a6a 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -48,15 +48,12 @@ class DecoratedReader : public ReaderBase { class FileReader : public ReaderBase { public: - explicit FileReader(const std::vector<DDim>& dims); + FileReader() : ReaderBase() {} void ReadNext(std::vector<LoDTensor>* out) override; protected: virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0; - - private: - std::vector<DDim> dims_; }; // The ReaderHolder is used as reader' unified wrapper, diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 5f734489a81764875988f440696682570ff4d1d7..0d2ff2e8e4882dc8c6572d173e85e695c42898b4 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -151,8 +151,8 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { } void DoubleBufferReader::ReInit() { - reader_->ReInit(); EndPrefetcher(); + reader_->ReInit(); StartPrefetcher(); } diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc index 36587360f7347a10e01d4e994482027d9a9bb5d0..84ea72379b406da2c91504c985a523ce663cc75d 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -19,14 +19,15 @@ namespace paddle { namespace operators { namespace reader { -class PyReader : public framework::ReaderBase { +class PyReader : public framework::FileReader { public: - explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) { + explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) + : framework::FileReader() { PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); queue_ = queue; } - void ReadNext(std::vector<framework::LoDTensor>* out) override { + void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { bool success; *out = queue_->Pop(&success); if (!success) out->clear(); diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 5b7e8a063a034f0be056065826fca0fe807bc9a7..7cbc2882fdd6aa4ab80c9360014343d670734749 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -19,11 +19,11 @@ namespace operators { namespace reader { template <typename T> -class RandomDataGenerator : public framework::ReaderBase { +class RandomDataGenerator : public framework::FileReader { public: RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low, float high) - : framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) { + : framework::FileReader(), low_(low), high_(high), shapes_(shapes) { PADDLE_ENFORCE_LE(low, high, "'low' shouldn't be greater than 'high'.(%f vs %f)", low, high); @@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase { dist_ = std::uniform_real_distribution<float>(low_, high_); } - void ReadNext(std::vector<framework::LoDTensor>* out) override { + void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { out->clear(); out->reserve(shapes_.size()); for (const framework::DDim& shape : shapes_) { diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 559827f08494af6730aafa1e67c46a47c21dedf6..c032acdffdaacc15d99dd36048d32faf0468e254 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -21,9 +21,8 @@ namespace reader { template <bool ThreadSafe> class RecordIOFileReader : public framework::FileReader { public: - explicit RecordIOFileReader(const std::string& filename, - const std::vector<framework::DDim>& dims) - : FileReader(dims), + explicit RecordIOFileReader(const std::string& filename) + : FileReader(), scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) { @@ -58,20 +57,10 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { - const auto& shape_concat = Attr<std::vector<int>>("shape_concat"); - const auto& ranks = Attr<std::vector<int>>("ranks"); - PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); - PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), - static_cast<int>(shape_concat.size()), - "The accumulate of all ranks should be equal to the " - "shape concat's length."); std::string filename = Attr<std::string>("filename"); - auto* out = scope.FindVar(Output("Out")) ->template GetMutable<framework::ReaderHolder>(); - - out->Reset(new RecordIOFileReader<true>( - filename, RestoreShapes(shape_concat, ranks))); + out->Reset(new RecordIOFileReader<true>(filename)); } }; diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 31e5d81e55ed9703eb3a9ef2595fa2a280f1a734..6e34a5bd008a0d24df25a93a49a965c31cbb758c 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -23,13 +23,12 @@ namespace reader { class MultiFileReader : public framework::ReaderBase { public: - MultiFileReader(const std::vector<std::string>& file_names, - const std::vector<framework::DDim>& dims, size_t thread_num, + MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num, size_t buffer_size) : buffer_size_(buffer_size) { readers_.reserve(file_names.size()); for (const std::string& f_name : file_names) { - readers_.emplace_back(CreateReaderByFileName(f_name, dims)); + readers_.emplace_back(CreateReaderByFileName(f_name)); } prefetchers_.resize(thread_num); StartNewScheduler(); @@ -180,9 +179,7 @@ class OpenFilesOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable<framework::ReaderHolder>(); - out->Reset(new MultiFileReader(file_names, - RestoreShapes(shape_concat, ranks), - thread_num, buffer_size)); + out->Reset(new MultiFileReader(file_names, thread_num, buffer_size)); } }; diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index e11256a49ffa6adc9410376cc8a71fa017df7e9c..b82aab1214992be73d876a42424234e3cea46455 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() { } std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( - const std::string& file_name, const std::vector<framework::DDim>& dims) { + const std::string& file_name) { size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); PADDLE_ENFORCE_NE(separator_pos, std::string::npos, "File name illegal! A legal file name should be like: " @@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( auto itor = FileReaderRegistry().find(filetype); PADDLE_ENFORCE(itor != FileReaderRegistry().end(), "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(file_name, dims); + framework::ReaderBase* reader = (itor->second)(file_name); return std::unique_ptr<framework::ReaderBase>(reader); } diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index 244bf15f068a47efc29ee54492cdbdeb10025020..25c3e7d77b788d38daf6dee1fc79e5c1c97e8842 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -25,22 +25,21 @@ namespace reader { static constexpr char kFileFormatSeparator[] = "."; -using FileReaderCreator = std::function<framework::ReaderBase*( - const std::string&, const std::vector<framework::DDim>&)>; +using FileReaderCreator = + std::function<framework::ReaderBase*(const std::string&)>; std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry(); template <typename Reader> int RegisterFileReader(const std::string& filetype) { - FileReaderRegistry()[filetype] = []( - const std::string& fn, const std::vector<framework::DDim>& dims) { - return new Reader(fn, dims); + FileReaderRegistry()[filetype] = [](const std::string& fn) { + return new Reader(fn); }; return 0; } std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( - const std::string& file_name, const std::vector<framework::DDim>& dims); + const std::string& file_name); extern std::vector<framework::DDim> RestoreShapes( const std::vector<int>& shape_concat, const std::vector<int>& ranks);