From 3d677b1eca75733adbc1939dd0a50cbacead6718 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 18 Mar 2018 20:29:48 +0800 Subject: [PATCH] fix compile errors and make OpenFilesOpMaker derived from FileReaderMakerBase --- paddle/fluid/operators/reader/CMakeLists.txt | 2 +- .../fluid/operators/reader/open_files_op.cc | 25 ++++++------------- .../operators/reader/reader_op_registry.cc | 16 ++++++++++++ .../operators/reader/reader_op_registry.h | 15 +---------- 4 files changed, 25 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 1254783d69a..4a43fc02d21 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -15,11 +15,11 @@ function(reader_library TARGET_NAME) PARENT_SCOPE) endfunction() +reader_library(open_files_op SRCS open_files_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) -reader_library(open_files_op SRCS open_files_op.cc) # Export local libraries to parent set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 473c002e93a..6b62e1db490 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -161,31 +161,20 @@ class OpenFilesOp : public framework::OperatorBase { } }; -class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { +class OpenFilesOpMaker : public FileReaderMakerBase { public: OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(op_proto, op_checker) { + : FileReaderMakerBase(op_proto, op_checker) { + AddAttr>("file_names", "Files to be read."); + AddAttr("thread_num", "The maximal concurrent prefetch thread number.") + .GreaterThan(0); + AddComment(R"DOC( OpenFiles Operator An OpenFilesOp creates a MultipleReader, which is able to read data multi-threaded from multiple files. )DOC"); - AddOutput("Out", "(ReaderHolder) The created MultipleReader."); - AddAttr>("shape_concat", - "The concat of all data's shapes."); - AddAttr>( - "ranks", - "The ranks of each data." - "e.g." - "shape_concat = [2,3,4,5,6]" - "ranks = [3,2]" - "It means the reader will generate two data each time," - "whose shapes are [2,3,4] and [5,6] respectively."); - AddAttr>("lod_levels", "The LoD levels of each data."); - AddAttr>("file_names", "Files to be read."); - AddAttr("thread_num", "The maximal concurrent prefetch thread number.") - .GreaterThan(0); } }; @@ -196,4 +185,4 @@ class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { namespace reader = paddle::operators::reader; REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp, - reader::OpenFilesOpMaker); \ No newline at end of file + reader::OpenFilesOpMaker); diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 0ba4f385443..05d79c76d5a 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -36,6 +36,22 @@ std::unordered_map& FileReaderRegistry() { return regs; } +std::unique_ptr CreateReaderByFileName( + const std::string& file_name, const std::vector& dims) { + size_t separator_pos = file_name.find(kFileFormatSeparator); + PADDLE_ENFORCE_NE(separator_pos, std::string::npos, + "File name illegal! A legal file name should be like: " + "[file_format]:[file_name] (e.g., 'recordio:data_file')."); + std::string filetype = file_name.substr(0, separator_pos); + std::string f_name = file_name.substr(separator_pos + 1); + + auto itor = FileReaderRegistry().find(filetype); + PADDLE_ENFORCE(itor != FileReaderRegistry().end(), + "No file reader registered for '%s' format.", filetype); + framework::ReaderBase* reader = (itor->second)(f_name, dims); + return std::unique_ptr(reader); +} + FileReaderMakerBase::FileReaderMakerBase( framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpAttrChecker* op_checker) diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index feab7c63a3e..dd19b982dad 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -38,20 +38,7 @@ int RegisterFileReader(const std::string& filetype) { } std::unique_ptr CreateReaderByFileName( - const std::string& file_name, const std::vector& dims) { - size_t separator_pos = file_name.find(kFileFormatSeparator); - PADDLE_ENFORCE_NE(separator_pos, std::string::npos, - "File name illegal! A legal file name should be like: " - "[file_format]:[file_name] (e.g., 'recordio:data_file')."); - std::string filetype = file_name.substr(0, separator_pos); - std::string f_name = file_name.substr(separator_pos + 1); - - auto itor = FileReaderRegistry().find(filetype); - PADDLE_ENFORCE(itor != FileReaderRegistry().end(), - "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(f_name, dims); - return std::unique_ptr(reader); -} + const std::string& file_name, const std::vector& dims); extern std::vector RestoreShapes( const std::vector& shape_concat, const std::vector& ranks); -- GitLab