提交 3d677b1e 编写于 作者: F fengjiayi

fix compile errors and make OpenFilesOpMaker derived from FileReaderMakerBase

上级 55062252
...@@ -15,11 +15,11 @@ function(reader_library TARGET_NAME) ...@@ -15,11 +15,11 @@ function(reader_library TARGET_NAME)
PARENT_SCOPE) PARENT_SCOPE)
endfunction() endfunction()
reader_library(open_files_op SRCS open_files_op.cc)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(open_files_op SRCS open_files_op.cc)
# Export local libraries to parent # Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
...@@ -161,31 +161,20 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -161,31 +161,20 @@ class OpenFilesOp : public framework::OperatorBase {
} }
}; };
class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { class OpenFilesOpMaker : public FileReaderMakerBase {
public: public:
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) { : FileReaderMakerBase(op_proto, op_checker) {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
AddComment(R"DOC( AddComment(R"DOC(
OpenFiles Operator OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to An OpenFilesOp creates a MultipleReader, which is able to
read data multi-threaded from multiple files. read data multi-threaded from multiple files.
)DOC"); )DOC");
AddOutput("Out", "(ReaderHolder) The created MultipleReader.");
AddAttr<std::vector<int>>("shape_concat",
"The concat of all data's shapes.");
AddAttr<std::vector<int>>(
"ranks",
"The ranks of each data."
"e.g."
"shape_concat = [2,3,4,5,6]"
"ranks = [3,2]"
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
} }
}; };
......
...@@ -36,6 +36,22 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() { ...@@ -36,6 +36,22 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
return regs; return regs;
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) {
size_t separator_pos = file_name.find(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
"[file_format]:[file_name] (e.g., 'recordio:data_file').");
std::string filetype = file_name.substr(0, separator_pos);
std::string f_name = file_name.substr(separator_pos + 1);
auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(f_name, dims);
return std::unique_ptr<framework::ReaderBase>(reader);
}
FileReaderMakerBase::FileReaderMakerBase( FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
......
...@@ -38,20 +38,7 @@ int RegisterFileReader(const std::string& filetype) { ...@@ -38,20 +38,7 @@ int RegisterFileReader(const std::string& filetype) {
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) { const std::string& file_name, const std::vector<framework::DDim>& dims);
size_t separator_pos = file_name.find(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
"[file_format]:[file_name] (e.g., 'recordio:data_file').");
std::string filetype = file_name.substr(0, separator_pos);
std::string f_name = file_name.substr(separator_pos + 1);
auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(f_name, dims);
return std::unique_ptr<framework::ReaderBase>(reader);
}
extern std::vector<framework::DDim> RestoreShapes( extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks); const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册