提交 1f9f8ebb 编写于 作者: R rensilin

LineDataReader_filename_prefix

Change-Id: I47c64499e664cd2d4f403a9fb181333fe291acd6
上级 fae7e884
...@@ -86,35 +86,50 @@ public: ...@@ -86,35 +86,50 @@ public:
} }
_done_file_name = config["done_file"].as<std::string>(); _done_file_name = config["done_file"].as<std::string>();
_buffer_size = config["buffer_size"].as<int>(1024); _buffer_size = config["buffer_size"].as<int>(1024);
_filename_prefix = config["filename_prefix"].as<std::string>("");
_buffer.reset(new char[_buffer_size]); _buffer.reset(new char[_buffer_size]);
return 0; return 0;
} }
//判断样本数据是否已就绪,就绪表明可以开始download //判断样本数据是否已就绪,就绪表明可以开始download
virtual bool is_data_ready(const std::string& data_dir) { virtual bool is_data_ready(const std::string& data_dir) {
auto done_file_path = ::paddle::framework::fs_path_join(data_dir, _done_file_name); auto done_file_path = framework::fs_path_join(data_dir, _done_file_name);
if (::paddle::framework::fs_exists(done_file_path)) { if (framework::fs_exists(done_file_path)) {
return true; return true;
} }
return false; return false;
} }
virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
if (_filename_prefix.empty()) {
return framework::fs_list(data_dir);
}
std::vector<std::string> data_files;
for (auto& filepath : framework::fs_list(data_dir)) {
auto filename = framework::fs_path_split(filepath).second;
if (filename.size() >= _filename_prefix.size() && filename.substr(0, _filename_prefix.size()) == _filename_prefix) {
data_files.push_back(std::move(filepath));
}
}
return data_files;
}
//读取数据样本流中 //读取数据样本流中
virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) { virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
::paddle::framework::ChannelWriter<DataItem> writer(data_channel.get()); framework::ChannelWriter<DataItem> writer(data_channel.get());
DataItem data_item; DataItem data_item;
if (_buffer_size <= 0 || _buffer == nullptr) { if (_buffer_size <= 0 || _buffer == nullptr) {
VLOG(2) << "no buffer"; VLOG(2) << "no buffer";
return -1; return -1;
} }
for (const auto& filename : ::paddle::framework::fs_list(data_dir)) { for (const auto& filepath : data_file_list(data_dir)) {
if (::paddle::framework::fs_path_split(filename).second == _done_file_name) { if (framework::fs_path_split(filepath).second == _done_file_name) {
continue; continue;
} }
int err_no = 0; int err_no = 0;
std::shared_ptr<FILE> fin = ::paddle::framework::fs_open_read(filename, &err_no, _pipeline_cmd); std::shared_ptr<FILE> fin = framework::fs_open_read(filepath, &err_no, _pipeline_cmd);
if (err_no != 0) { if (err_no != 0) {
VLOG(2) << "fail to open file: " << filename << ", with cmd: " << _pipeline_cmd; VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
return -1; return -1;
} }
while (fgets(_buffer.get(), _buffer_size, fin.get())) { while (fgets(_buffer.get(), _buffer_size, fin.get())) {
...@@ -124,7 +139,7 @@ public: ...@@ -124,7 +139,7 @@ public:
writer << std::move(data_item); writer << std::move(data_item);
} }
if (ferror(fin.get()) != 0) { if (ferror(fin.get()) != 0) {
VLOG(2) << "fail to read file: " << filename; VLOG(2) << "fail to read file: " << filepath;
return -1; return -1;
} }
} }
...@@ -144,6 +159,7 @@ private: ...@@ -144,6 +159,7 @@ private:
std::string _done_file_name; // without data_dir std::string _done_file_name; // without data_dir
int _buffer_size = 0; int _buffer_size = 0;
std::unique_ptr<char[]> _buffer; std::unique_ptr<char[]> _buffer;
std::string _filename_prefix;
}; };
REGISTER_CLASS(DataReader, LineDataReader); REGISTER_CLASS(DataReader, LineDataReader);
......
...@@ -59,6 +59,7 @@ public: ...@@ -59,6 +59,7 @@ public:
virtual bool is_data_ready(const std::string& data_dir) = 0; virtual bool is_data_ready(const std::string& data_dir) = 0;
//读取数据样本流中 //读取数据样本流中
virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0; virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
virtual std::vector<std::string> data_file_list(const std::string& data_dir) = 0;
virtual const DataParser* get_parser() { virtual const DataParser* get_parser() {
return _parser.get(); return _parser.get();
} }
......
...@@ -105,18 +105,16 @@ TEST_F(DataReaderTest, LineDataReader) { ...@@ -105,18 +105,16 @@ TEST_F(DataReaderTest, LineDataReader) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader")); std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader); ASSERT_NE(nullptr, data_reader);
YAML::Node config = YAML::Load("parser:\n" auto config = YAML::Load("parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
config = YAML::Load("parser:\n"
" class: LineDataParser\n" " class: LineDataParser\n"
"pipeline_cmd: cat\n" "pipeline_cmd: cat\n"
"done_file: done_file\n" "done_file: done_file\n"
"buffer_size: 128"); "buffer_size: 128");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr)); ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
auto data_file_list = data_reader->data_file_list(test_data_dir);
ASSERT_EQ(2, data_file_list.size());
ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "b.txt"), data_file_list[1]);
ASSERT_FALSE(data_reader->is_data_ready(test_data_dir)); ASSERT_FALSE(data_reader->is_data_ready(test_data_dir));
std::ofstream fout(framework::fs_path_join(test_data_dir, "done_file")); std::ofstream fout(framework::fs_path_join(test_data_dir, "done_file"));
...@@ -155,6 +153,40 @@ TEST_F(DataReaderTest, LineDataReader) { ...@@ -155,6 +153,40 @@ TEST_F(DataReaderTest, LineDataReader) {
ASSERT_FALSE(reader); ASSERT_FALSE(reader);
} }
TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load("parser:\n"
" class: LineDataParser\n"
"pipeline_cmd: cat\n"
"done_file: done_file\n"
"filename_prefix: a");
ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
auto data_file_list = data_reader->data_file_list(test_data_dir);
ASSERT_EQ(1, data_file_list.size());
ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
auto channel = framework::MakeChannel<DataItem>(128);
ASSERT_NE(nullptr, channel);
ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
framework::ChannelReader<DataItem> reader(channel.get());
DataItem data_item;
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("abc", data_item.id.c_str());
ASSERT_STREQ("123456", data_item.data.c_str());
reader >> data_item;
ASSERT_TRUE(reader);
ASSERT_STREQ("def", data_item.id.c_str());
ASSERT_STREQ("234567", data_item.data.c_str());
reader >> data_item;
ASSERT_FALSE(reader);
}
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册