diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 1254783d69a87b8b13650449fdb84174f7aef91e..4a43fc02d2189ec2eb4bff12769c11cb8dae8193 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -15,11 +15,11 @@ function(reader_library TARGET_NAME) PARENT_SCOPE) endfunction() +reader_library(open_files_op SRCS open_files_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) -reader_library(open_files_op SRCS open_files_op.cc) # Export local libraries to parent set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 473c002e93a6db65c5a47943e8b5c820abd19b34..6b62e1db49076008255c15845bdd1d0dd27d297f 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -161,31 +161,20 @@ class OpenFilesOp : public framework::OperatorBase { } }; -class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { +class OpenFilesOpMaker : public FileReaderMakerBase { public: OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(op_proto, op_checker) { + : FileReaderMakerBase(op_proto, op_checker) { + AddAttr>("file_names", "Files to be read."); + AddAttr("thread_num", "The maximal concurrent prefetch thread number.") + .GreaterThan(0); + AddComment(R"DOC( OpenFiles Operator An OpenFilesOp creates a MultipleReader, which is able to read data multi-threaded from multiple files. )DOC"); - AddOutput("Out", "(ReaderHolder) The created MultipleReader."); - AddAttr>("shape_concat", - "The concat of all data's shapes."); - AddAttr>( - "ranks", - "The ranks of each data." - "e.g." - "shape_concat = [2,3,4,5,6]" - "ranks = [3,2]" - "It means the reader will generate two data each time," - "whose shapes are [2,3,4] and [5,6] respectively."); - AddAttr>("lod_levels", "The LoD levels of each data."); - AddAttr>("file_names", "Files to be read."); - AddAttr("thread_num", "The maximal concurrent prefetch thread number.") - .GreaterThan(0); } }; @@ -196,4 +185,4 @@ class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { namespace reader = paddle::operators::reader; REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp, - reader::OpenFilesOpMaker); \ No newline at end of file + reader::OpenFilesOpMaker); diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 0ba4f3854431742eb354f8c90eb395f5d7b32b2e..05d79c76d5ab0e48f441a7cc8a470bd99eb80ca8 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -36,6 +36,22 @@ std::unordered_map& FileReaderRegistry() { return regs; } +std::unique_ptr CreateReaderByFileName( + const std::string& file_name, const std::vector& dims) { + size_t separator_pos = file_name.find(kFileFormatSeparator); + PADDLE_ENFORCE_NE(separator_pos, std::string::npos, + "File name illegal! A legal file name should be like: " + "[file_format]:[file_name] (e.g., 'recordio:data_file')."); + std::string filetype = file_name.substr(0, separator_pos); + std::string f_name = file_name.substr(separator_pos + 1); + + auto itor = FileReaderRegistry().find(filetype); + PADDLE_ENFORCE(itor != FileReaderRegistry().end(), + "No file reader registered for '%s' format.", filetype); + framework::ReaderBase* reader = (itor->second)(f_name, dims); + return std::unique_ptr(reader); +} + FileReaderMakerBase::FileReaderMakerBase( framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpAttrChecker* op_checker) diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index feab7c63a3eeea6da78dca4c752c33f76df25a80..dd19b982dad8622c2c9cfd3395e6812acba26982 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -38,20 +38,7 @@ int RegisterFileReader(const std::string& filetype) { } std::unique_ptr CreateReaderByFileName( - const std::string& file_name, const std::vector& dims) { - size_t separator_pos = file_name.find(kFileFormatSeparator); - PADDLE_ENFORCE_NE(separator_pos, std::string::npos, - "File name illegal! A legal file name should be like: " - "[file_format]:[file_name] (e.g., 'recordio:data_file')."); - std::string filetype = file_name.substr(0, separator_pos); - std::string f_name = file_name.substr(separator_pos + 1); - - auto itor = FileReaderRegistry().find(filetype); - PADDLE_ENFORCE(itor != FileReaderRegistry().end(), - "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(f_name, dims); - return std::unique_ptr(reader); -} + const std::string& file_name, const std::vector& dims); extern std::vector RestoreShapes( const std::vector& shape_concat, const std::vector& ranks);